Decision trees
Decision trees are widely used models for classification and regression tasks. Essentially, they learn a hierarchy of “if-else” questions, leading to a decision. These questions are similar to the questions you might ask in a game of twenty questions. Imagine you want to distinguish between the following four animals: bears, hawks, penguins and dolphins. Your goal is to get to the right answer by asking as few if-else questions as possible. You might start off by asking whether the animal has feathers, a question that narrows your possible animals down to just two. If the answer is yes, you can ask another question that could help you distinguish
between hawks and penguins. For example, you could ask whether or not the animal can fly. If the animal doesn’t have feathers, your possible animal choices are dolphins and bears, and you will need to ask a question to distinguish between these two animals, for example, asking whether the animal has fins.
This series of questions can be expressed as a decision tree, as shown in Figure animal_tree.
Figure animal_tree
In this illustration, each node in the tree either represents a question or is a terminal node (also called a leaf) that contains the answer. The edges connect the answers to a question with the next question you would ask. In machine learning parlance, we built a model to distinguish between four classes of animals (hawks, penguins, dolphins and bears) using the three features “has feathers”, “can fly” and “has fins”.
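Written out by hand, this tree is just a nest of if/else statements; a toy sketch (the function name and the boolean arguments are only illustrative):
- def classify_animal(has_feathers, can_fly, has_fins):
-     # first question: does the animal have feathers?
-     if has_feathers:
-         # feathered animals: flying ones are hawks, the rest penguins
-         return "hawk" if can_fly else "penguin"
-     # featherless animals: fins distinguish dolphins from bears
-     return "dolphin" if has_fins else "bear"
- print(classify_animal(has_feathers=False, can_fly=False, has_fins=True))  # dolphin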
Instead of building these models by hand, we can learn them from data using supervised learning.
Building Decision Trees
Let’s go through the process of building a decision tree for the 2d classification dataset shown at the top of Figure tree_building. The dataset consists of two half-moon shapes of blue and red points, with 75 data points in each class. We will refer to this dataset as two_moons.
Figure tree_building
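The figures in this section are produced by helper code from the book’s mglearn package, but a comparable dataset can be generated directly with scikit-learn; a minimal sketch, where the n_samples and noise values are assumptions chosen so that each class ends up with 75 points:
- from sklearn.datasets import make_moons
- # 150 points in total, i.e. 75 per half-moon; a little noise blurs the shapes
- X, y = make_moons(n_samples=150, noise=0.25, random_state=3)
- print(X.shape, (y == 0).sum(), (y == 1).sum())  # (150, 2) 75 75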
Learning a decision tree means learning a sequence of if/else questions that gets us to the true answer most quickly. In the machine learning setting, these questions are called tests (not to be confused with the test set, which is the data we use to check how well our model generalizes). Usually data does not come in the form of binary yes/no features as in the animal example, but is instead represented as continuous features, such as in the 2d dataset shown in the figure. The tests that are used on continuous data are of the form “is feature i larger than value a”.
To build a tree, the algorithm searches over all possible tests, and finds the one that is most informative about the target variable.
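As a rough sketch of what “searching over all possible tests” means, the snippet below scores every candidate threshold on every feature by the weighted impurity of the two resulting groups, here using Gini impurity (scikit-learn’s actual implementation is considerably more elaborate; the helper names are only illustrative):
- import numpy as np
- def gini(y):
-     # Gini impurity: 1 minus the sum of squared class proportions
-     _, counts = np.unique(y, return_counts=True)
-     p = counts / counts.sum()
-     return 1.0 - np.sum(p ** 2)
- def best_split(X, y):
-     # try the test "is feature i larger than value a" for every feature i and
-     # every midpoint a between consecutive sorted feature values
-     best_feature, best_threshold, best_score = None, None, np.inf
-     for i in range(X.shape[1]):
-         values = np.unique(X[:, i])
-         for a in (values[:-1] + values[1:]) / 2:
-             left, right = y[X[:, i] <= a], y[X[:, i] > a]
-             # weighted impurity of the two resulting groups; lower is better
-             score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
-             if score < best_score:
-                 best_feature, best_threshold, best_score = i, a, score
-     return best_feature, best_threshold
Applying the best test and then repeating the search inside each of the two resulting groups gives exactly the recursive partitioning described next.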
The second row in Figure tree_building shows the first test that is picked. Splitting the dataset along a horizontal line at x[1]=0.2372 yields the most information; it best separates the blue points from the red points. The top node, also called the root, represents the whole dataset, consisting of 75 red and 75 blue points. The split is done by testing whether x[1] <= 0.2372, indicated by a black line. If the test is true, a point is assigned to the left node, which contains 8 blue points and 58 red points. Otherwise the point is assigned to the right node, which contains 67 blue points and 17 red points. These two nodes correspond to the top and bottom regions shown in Figure tree_building.
Even though the first split did a good job of separating the blue and red points, the bottom region still contains blue points, and the top region still contains red points. We can build a more accurate model by repeating the process of looking for the best test in both regions. Figure tree_building shows that the most informative next splits for the left and the right regions are based on x[0]. This recursive process yields a binary tree of decisions, with each node containing a test.
Alternatively, you can think of each test as splitting the part of the data that is currently considered along one axis. This yields a view of the algorithm as building a hierarchical partition. As each test concerns only a single feature, the regions in the resulting partition always have axis-parallel boundaries.
The recursive partitioning of the data is usually repeated until each region in the partition (each leaf in the decision tree) only contains a single target value (a single class or a single regression value). A leaf of the tree containing only one target value is called pure.
A prediction on a new data point is made by checking which region of the partition of the feature space the point lies in, and then predicting the majority target (or the single target in the case of pure leaves) in that region. The region can be found by traversing the tree from the root and going left or right, depending on whether the test is fulfilled or not.
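As a sketch of this traversal, assume each internal node is stored as a dict holding the test (feature index and threshold) and its two children, and each leaf stores the majority target of its region; this is a hypothetical representation, not scikit-learn’s internal one:
- def predict_one(node, x):
-     # leaves store the majority (or single) target of their region
-     while "prediction" not in node:
-         # follow the test "is feature i <= threshold" left or right
-         if x[node["feature"]] <= node["threshold"]:
-             node = node["left"]
-         else:
-             node = node["right"]
-     return node["prediction"]
- # the first two_moons split from above, collapsed into two majority-vote leaves
- root = {"feature": 1, "threshold": 0.2372,
-         "left": {"prediction": "red"}, "right": {"prediction": "blue"}}
- print(predict_one(root, [0.5, -0.3]))  # -0.3 <= 0.2372, so the left (red) leaf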
Controlling complexity of Decision Trees
Typically, building a tree as described above and continuing until all leaves are pure leads to models that are very complex and highly overfit to the training data. The presence of pure leaves means that a tree is 100% accurate on the training set; each data point in the training set is in a leaf that has the correct majority class. The overfitting can be seen in the figure below. You can see regions determined to be red in the middle of all the blue points. On the other hand, there is a small strip of blue around the single blue point to the very right. This is not how one would imagine the decision boundary to look; it focuses heavily on single outlier points that are far away from the other points in their class.
There are two common strategies to prevent overfitting: stopping the creation of the tree early, also called pre-pruning, or building the tree but then removing or collapsing nodes that contain little information, also called post-pruning or just pruning. Possible criteria for pre-pruning include limiting the maximum depth of the tree, limiting the maximum number of leaves, or requiring a minimum number of points in a node to keep splitting it.
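In scikit-learn’s tree implementation, which the next paragraph introduces, each of these pre-pruning criteria corresponds to a constructor parameter; a minimal sketch with arbitrarily chosen values:
- from sklearn.tree import DecisionTreeClassifier
- # stop splitting at depth 4, allow at most 16 leaves, and require at least
- # 20 samples in a node before it may be split further
- pruned_tree = DecisionTreeClassifier(max_depth=4, max_leaf_nodes=16,
-                                      min_samples_split=20, random_state=0)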
Decision trees in scikit-learn are implemented in the DecisionTreeRegressor and DecisionTreeClassifier classes. Scikit-learn only implements pre-pruning, not post-pruning. Let’s look at the effect of pre-pruning in more detail on the breast cancer dataset. As always, we import the dataset and split it into a training and a test set. Then we build a model using the default setting of fully developing the tree (growing the tree until all leaves are pure). We fix the random_state in the tree, which is used for tie-breaking internally.
- ch2_t19.py
- import mglearn
- import matplotlib.pyplot as plt
- from sklearn.model_selection import train_test_split
- from sklearn.tree import DecisionTreeClassifier
- from sklearn.datasets import load_breast_cancer
- # load the breast cancer data and create a stratified train/test split
- cancer = load_breast_cancer()
- X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, stratify=cancer.target, random_state=42)
- # fully grown tree: no pruning, leaves are expanded until they are pure
- tree = DecisionTreeClassifier(random_state=0)
- tree.fit(X_train, y_train)
- print("Accuracy on training set: %f" % tree.score(X_train, y_train))
- print("Accuracy on test set: %f" % tree.score(X_test, y_test))
As expected, the accuracy on the training set is 100% as the leaves are pure. The test-set accuracy is slightly worse than the linear models above, which had around 93% accuracy. Now let’s apply pre-pruning to the tree, which will stop developing the tree before we perfectly fit to the training data.
One possible way is to stop building the tree after a certain depth has been reached. Here we set max_depth=4, meaning only four consecutive questions can be asked:
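Reusing the train/test split from ch2_t19.py above (the same construction appears again in ch2_t20.py below):
- tree = DecisionTreeClassifier(max_depth=4, random_state=0)
- tree.fit(X_train, y_train)
- print("Accuracy on training set: %f" % tree.score(X_train, y_train))
- print("Accuracy on test set: %f" % tree.score(X_test, y_test))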
Limiting the depth of the tree decreases overfitting. This leads to a lower accuracy on the training set, but an improvement on the test set.
Analyzing Decision Trees
We can visualize the tree using the export_graphviz function from the tree module. This writes a file in the dot file format, which is a text file format for storing graphs.
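For the pruned tree fit above, a call might look like the following; the class_names and feature_names arguments are optional and simply label the nodes (in the breast cancer dataset, target 0 is malignant and target 1 is benign):
- from sklearn.tree import export_graphviz
- # write the tree structure to mytree.dot, coloring nodes by majority class
- export_graphviz(tree, out_file="mytree.dot",
-                 class_names=["malignant", "benign"],
-                 feature_names=cancer.feature_names,
-                 impurity=False, filled=True)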
We can read this file and visualize it using the graphviz module (or you can use any program that can read dot files):
- import graphviz
- # read the .dot file written by export_graphviz
- with open("mytree.dot") as f:
-     dot_graph = f.read()
- g = graphviz.Source(dot_graph)
- # in a Jupyter notebook, displaying g draws the tree; g.render() writes it to a file
The visualization of the tree provides a great in-depth view of how the algorithm makes predictions, and is a good example of a machine learning algorithm that is easily explained to non-experts. However, even with a tree of depth four, as seen here, the tree can become a bit overwhelming. Deeper trees (depth ten is not uncommon) are even harder to grasp.
One method of inspecting the tree that may be helpful is to find out which path most of the data actually takes. The n_samples shown in each node in the figure gives the number of samples in that node, while value provides the number of samples per class. Following the branches to the right, we see that texture_error <= 0.4732 creates a node that contains only 8 benign but 134 malignant samples. The rest of this side of the tree then uses some finer distinctions to split off these 8 remaining benign samples. Of the 142 samples that went to the right in the initial split, nearly all of them (132) end up in the leaf to the very right. Taking a left at the root, for texture_error > 0.4732, we end up with 25 malignant and 259 benign samples. Nearly all of the benign samples end up in the second leaf from the left, with most of the other leaves containing only very few samples.
Feature Importance in trees
Instead of looking at the whole tree, which can be taxing, there are some useful properties that we can derive to summarize the workings of the tree. The most commonly used summary is feature importance, which rates how important each feature is for the decision a tree makes. It is a number between 0 and 1 for each feature, where 0 means “not used at all” and 1 means “perfectly predicts the target”.
The feature importances always sum to one.
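For the pruned tree from above, the raw values can simply be printed; there is one entry per feature, in the same order as cancer.feature_names:
- print("Feature importances:")
- print(tree.feature_importances_)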
We can visualize the feature importances in a way that is similar to the way we visualize the coefficients in the linear model:
- ch2_t20.py
- #!/usr/bin/env python
- import os
- import mglearn
- import matplotlib.pyplot as plt
- from sklearn.model_selection import train_test_split
- from sklearn.tree import DecisionTreeClassifier
- from sklearn.datasets import load_breast_cancer
- cancer = load_breast_cancer()
- X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, stratify=cancer.target, random_state=42)
- # pre-pruned tree: stop splitting after four levels of questions
- tree = DecisionTreeClassifier(random_state=0, max_depth=4)
- tree.fit(X_train, y_train)
- print("Accuracy on training set: %f" % tree.score(X_train, y_train))
- print("Accuracy on test set: %f" % tree.score(X_test, y_test))
- # only plot when a display is available
- dmode = os.environ.get('DISPLAY', '')
- if dmode:
-     # one importance value per feature, between 0 (unused) and 1
-     plt.plot(tree.feature_importances_, 'o')
-     plt.xticks(range(cancer.data.shape[1]), cancer.feature_names, rotation=90)
-     plt.ylim(0, 1)
-     plt.show()
Here, we see that the feature used at the top split (“worst radius”) is by far the most important feature. This confirms our observation in analyzing the tree, that the first level already separates the two classes fairly well. However, if a feature has a low feature_importance, it doesn’t mean that this feature is uninformative. It only means that this feature was not picked by the tree, likely because another feature encodes the same information.
In contrast to the coefficients in linear models, feature importances are always positive, and don’t encode which class a feature is indicative of. The feature importances tell us that worst radius is important, but it does not tell us whether a high radius is indicative of a sample being “benign” or “malignant”. In fact, there might not be such a simple relationship between features and class, as you can see in the example below:
- import matplotlib.pyplot as plt
- import mglearn
- # helper plot from the mglearn package accompanying the book
- tree = mglearn.plots.plot_tree_not_monotone()
- plt.suptitle("tree_not_monotone")
- plt.show()
The plot shows a dataset with two features and two classes. Here, all the information is contained in X[1], and X[0] is not used at all. But the relation between X[1] and the output class is not monotonic, meaning we cannot say “a high value of X[1] means class red, and a low value means class blue” or the other way around.
While we focused our discussion here on decision trees for classification, everything said above holds similarly for decision trees for regression, as implemented in DecisionTreeRegressor. Both the usage and the analysis of regression trees are very similar to classification trees, so we won’t go into any more detail here.
Strengths, weaknesses and parameters
Decision trees have two advantages over many of the algorithms we discussed so far: the resulting model can easily be visualized and understood by non-experts (at least for smaller trees), and the algorithm is completely invariant to scaling of the data. As each feature is processed separately, and the possible splits of the data don’t depend on scaling, no preprocessing like normalization or standardization of features is needed for decision tree algorithms. In particular, decision trees work well when you have features that are on completely different scales, or a mix of binary and continuous features.
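A quick way to convince yourself of this invariance is to rescale the features and compare predictions; a small sketch reusing the breast cancer split from ch2_t19.py above (the scaling factors are arbitrary, and identical results rely on the rescaling being monotonic):
- import numpy as np
- # multiply every feature by a large factor and add an offset; the split
- # thresholds change, but which points they separate does not
- tree_orig = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train, y_train)
- tree_scaled = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train * 1000 + 3, y_train)
- print(np.all(tree_orig.predict(X_test) == tree_scaled.predict(X_test * 1000 + 3)))  # expected: True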
The main downside of decision trees is that, even with the use of pre-pruning, they tend to overfit and provide poor generalization performance. Therefore, in most applications, the ensemble methods we discuss below are usually used in place of a single decision tree.
Supplement
* [ ML In Action ] Decision Tree Construction