程式扎記: [ ML In Action ] Tree-based regression : The CART algorithm (1)

Sunday, January 6, 2013

Preface:
In the previous chapter, Predicting numeric values : regression - Linear regression, we learned how to find an equation that predicts a numeric output from selected features and training data. In that process a single equation is used for the whole data set. But in more complex situations, where the predicted value follows different trends over different intervals, a single equation may not be able to capture all of the variation, and the tree-based regression covered here comes in handy. For example, suppose you want to model the data set distributed as in the figure below; it is clear that a single equation cannot describe data whose trend changes across two segments:
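To make this concrete, here is a minimal sketch (all data synthetic, all names hypothetical) comparing a single least-squares line against a piecewise model with one mean per segment, which is what a regression tree would produce:

```python
import numpy as np

# Hypothetical two-segment data: targets near 0 for x < 0.5, near 1 for x >= 0.5
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, 200)
y = np.where(x < 0.5, 0.0, 1.0) + rng.normal(0, 0.05, 200)

# One global line fitted with least squares
a, b = np.polyfit(x, y, 1)
err_line = np.sum((y - (a * x + b)) ** 2)

# A piecewise model: one mean per segment (a two-leaf regression tree)
left, right = y[x < 0.5], y[x >= 0.5]
err_piecewise = np.sum((left - left.mean()) ** 2) + np.sum((right - right.mean()) ** 2)

print(err_piecewise < err_line)  # the piecewise model fits far better
```

The single line is forced to compromise between the two segments, while the piecewise means track each segment's level, which is exactly the motivation for tree-based regression.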


All of the code used here can be downloaded here.

Locally modeling complex data:
Tree-based regression analyzes each feature and splits the data at appropriate points, producing a different prediction for each resulting segment (either a constant value or an equation). The algorithm used here is CART (Classification And Regression Trees). A brief summary of tree-based regression:
Pros: Fits complex, nonlinear data
Cons: Difficult to interpret results
Works with: Numeric values, nominal values

The CART algorithm can produce two kinds of trees: a regression tree, whose leaves are the predicted values themselves, and a model tree, whose leaves are prediction equations. While growing the tree, for each feature it considers every value that appears in the data, picks the value whose binary split yields the smallest error, and splits there. How the error is computed depends on the tree type (regErr or modelErr); this is explained in detail in the code later.
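The split-selection idea above can be sketched in a few lines (toy data, names hypothetical): score every observed value of a feature as a candidate binary split and keep the one with the lowest total squared error:

```python
import numpy as np

def total_error(y):
    # Total squared error = variance * number of samples (what regErr computes later)
    return np.var(y) * len(y)

# Toy single-feature data with two obvious clusters of target values
feature = np.array([0.1, 0.2, 0.3, 0.7, 0.8, 0.9])
target  = np.array([1.0, 1.1, 0.9, 5.0, 5.2, 4.8])

# Try every observed value as a binary split point; keep the lowest total error
best_val, best_err = None, np.inf
for v in np.unique(feature):
    left, right = target[feature <= v], target[feature > v]
    if len(left) == 0 or len(right) == 0:
        continue
    err = total_error(left) + total_error(right)
    if err < best_err:
        best_val, best_err = v, err

print(best_val)  # 0.3 - splitting between the two clusters minimizes the error
```

This is the same search that chooseBestSplit() performs later, minus the stopping conditions.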

Next, let's look at the attributes stored in each node of the tree (excluding leaves):
* Feature - A symbol representing the feature split on for this tree.
* Value - The value of the feature used to split.
* Right - The right subtree; this could also be a single value if the algorithm decides we don't need another split.
* Left - The left subtree, similar to the right subtree.

The pseudo code for building the whole tree is as follows:

    Find the best feature to split on:
        If we can't split the data, this node becomes a leaf node
        Make a binary split of the data
        Call createTree() on the right split of the data
        Call createTree() on the left split of the data

Before diving into the CART algorithm itself, the code below lets us load the training data set (loadDataSet) and implements the tree-splitting pseudo code above (createTree):
from numpy import *

def loadDataSet(fileName):
    """ Load training data.
    The file contains tab-separated columns; every column except the last
    holds a feature value, and the last column is the target value.
    """
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))  # map each field to float (list() needed under Python 3)
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    """ Split the data set on the given [feature] at the given [value].
    Returns the split result as a tuple: tuple[0] holds the rows whose
    [feature] is less than or equal to [value]; tuple[1] holds the rows
    whose [feature] is greater than [value].
    """
    mat0 = dataSet[nonzero(dataSet[:,feature] <= value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:,feature] > value)[0], :]
    return mat0, mat1  # left/right

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """ Create a tree.
    Depending on the given [leafType] and [errType], this API creates a
    regression tree or a model tree.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None: return val  # no split is worthwhile; return the leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = array(val).flatten().tolist()[0]
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
The code above still lacks the function chooseBestSplit; we will cover it next.
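As a quick sanity check of binSplitDataSet, here is a self-contained sketch (the function is re-declared so the snippet runs on its own; the data is made up):

```python
from numpy import mat, nonzero

def binSplitDataSet(dataSet, feature, value):
    # Rows with feature <= value go left; the rest go right
    mat0 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    return mat0, mat1

data = mat([[0.2, 1.0],
            [0.4, 1.2],
            [0.6, 3.0],
            [0.8, 3.1]])
left, right = binSplitDataSet(data, 0, 0.4)
print(left.shape[0], right.shape[0])  # 2 rows go left (<= 0.4), 2 rows go right
```

Note that nonzero() on the boolean matrix returns a tuple of index arrays; taking element [0] keeps only the row indices, which are then used to select rows.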

Using CART for regression:
In a regression tree each leaf is itself the prediction: by splitting repeatedly we end up with groups of samples whose target values are close together, and the mean of each group's targets serves as the predicted value. That is the central idea of regression trees. What we haven't explained yet is how to choose the best split point. Here we use the total squared error (the variance of the targets in a data set multiplied by the number of samples) and pick the split point with the smallest error. The function chooseBestSplit() does this for us. Its pseudo code:

    For every feature:
        For every unique value:
            Split the dataset in two
            Measure the error of the two splits
            If the error is less than bestError:
                set bestSplit to this split and update bestError
    Return bestSplit feature and threshold
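The "variance times sample count" error measure used above is just the sum of squared deviations from the mean, which a quick sketch confirms:

```python
import numpy as np

y = np.array([1.0, 2.0, 3.0, 4.0])

# regErr-style total error: variance times the number of samples ...
total_err = np.var(y) * len(y)

# ... equals the sum of squared deviations from the mean
sum_sq_dev = np.sum((y - y.mean()) ** 2)

print(total_err, sum_sq_dev)  # both are 5.0
```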


The implementation code follows:
def regLeaf(dataSet):
    """ Generate the model for a leaf node.
    When chooseBestSplit() decides the data should not be split further,
    it calls regLeaf() to get a model for the leaf. The model in a
    regression tree is the mean value of the target variable.
    """
    return array(mean(dataSet[:,-1])).flatten().tolist()[0]

def regErr(dataSet):
    """ Return the total squared error of the target variable in the given data set.
    This is the variance of the targets multiplied by the number of samples.
    For more about variance, please refer to: http://en.wikipedia.org/wiki/Variance
    """
    return var(dataSet[:,-1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType, errType, ops):
    """ Choose the best point at which to split the data set. Pseudo code:
    ------------------------------------------------------------------------
    For every feature:
        For every unique value:
            Split the dataset in two
            Measure the error of the two splits
            If the error is less than bestError:
                set bestSplit to this split and update bestError
    Return bestSplit feature and threshold
    ------------------------------------------------------------------------
    1.  chooseBestSplit() starts out by assigning the values of ops to tolS and tolN.
        These two user-defined settings tell the function when to quit creating new splits.
    2.  Next it checks the number of unique target values by creating a set from them.
        If this set has length 1, there is nothing to split and a leaf is returned.
    3.  It then measures the size of the dataset and the error of the existing dataset.
        This error is compared against the error after a split to see whether splitting helps.
    4.  A few variables used to find the best split are created and initialized; the function
        then iterates over all features and all values of those features to find the best split.
    5.  The best split is the one with the lowest total error after splitting. If splitting
        improves the error by only a small amount (less than tolS), or leaves a split smaller
        than tolN, no split is made and a leaf node is created instead.
    """
    tolS = ops[0]  # minimum error reduction required to keep splitting
    tolN = ops[1]  # minimum number of samples allowed in a split
    # Exit if all target values are equal
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):  # every value seen for this feature
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:  # Exit if the error reduction is too small
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # Exit if a split is too small
        return None, leafType(dataSet)
    return bestIndex, bestValue
We can test this with the code below; the generated plot shows the distribution of our training data (ex00.txt):
import matplotlib.pyplot as plt

myDat = loadDataSet('ex00.txt')
myMat = mat(myDat)
tree = createTree(myMat)
print(tree)
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(array(myMat[:,0]), array(myMat[:,1]))
plt.show()
On the console, the run prints the resulting tree:
{'spInd': 0, 'spVal': 0.48813, 'right': 1.018096767241379, 'left': -0.04465028571428573}

This means that the best split point in our training data set is at feature-0 = 0.48813; the data on the right side of the figure below gets the predicted value 1.018096767241379, and the left side gets -0.04465028571428573:
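Using the tree dict printed above, making a prediction is just a walk down the tree. A minimal sketch (the helper name treePredict is hypothetical; the book implements a fuller version later):

```python
# The tree printed above, as a plain dict
tree = {'spInd': 0, 'spVal': 0.48813,
        'right': 1.018096767241379, 'left': -0.04465028571428573}

def treePredict(tree, x):
    """Walk the tree: descend left when x[spInd] <= spVal, else right,
    until a leaf (a plain number) is reached."""
    if not isinstance(tree, dict):
        return tree
    if x[tree['spInd']] <= tree['spVal']:
        return treePredict(tree['left'], x)
    return treePredict(tree['right'], x)

print(treePredict(tree, [0.3]))  # -0.04465028571428573 (left segment)
print(treePredict(tree, [0.7]))  # 1.018096767241379 (right segment)
```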


This training data set doesn't really show off the power of a regression tree, so let's use slightly more complex data (ex0.txt):
myMat = mat(loadDataSet2('ex0.txt'))
tree = createTree(myMat)
print(tree)
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(array(myMat[:,0]), array(myMat[:,1]))
plt.show()
The run prints the resulting tree on the console:
{'spInd': 0, 'spVal': 0.39435, 'right': {'spInd': 0, 'spVal': 0.582002, 'right': {'spInd': 0, 'spVal': 0.797583, 'right': 3.9871632000000004, 'left': 2.9836209534883724}, 'left': 1.9800350714285717}, 'left': {'spInd': 0, 'spVal': 0.197834, 'right': 1.0289583666666664, 'left': -0.023838155555555553}}

The data distribution is shown below; as you can see, complex data yields a complex tree:


Tree pruning:
Sometimes an overly complex tree overfits the data, so we need to cross-validate the tree we build to make sure it keeps a high prediction accuracy on other test data. This process is called pruning. In fact we already applied a prepruning technique inside chooseBestSplit(), through the ops parameters:
tolS = ops[0]  # minimum error reduction required to keep splitting
tolN = ops[1]  # minimum number of samples allowed in a split
To some extent this already guards against overfitting. The problem is how to decide on the ops parameters!

There is therefore another class of pruning methods called postpruning. Here we first split the data into a training set (the larger part) and a test set (the smaller part). We build the tree with the training set and evaluate it with the test set. During evaluation, whenever both the left and right children of a node are leaves, we compute the sum of squared errors (predicted value versus actual value) with and without merging the two leaves. If merging gives a smaller sum of squared errors, we merge the left and right nodes. The function prune implements this; its pseudo code:
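To make the merge test concrete, here is a small sketch (all numbers hypothetical) of the comparison prune() performs between keeping two leaves and merging them into their mean:

```python
import numpy as np

# Two hypothetical leaves and the test-set targets routed to each side
left_leaf, right_leaf = 1.0, 1.2
left_targets  = np.array([1.05, 1.15])   # test rows that fall on the left of the split
right_targets = np.array([1.10, 1.25])   # test rows that fall on the right

# Error if the two leaves are kept separate
error_no_merge = np.sum((left_targets - left_leaf) ** 2) + \
                 np.sum((right_targets - right_leaf) ** 2)

# Error if the leaves are merged into their mean
tree_mean = (left_leaf + right_leaf) / 2.0   # 1.1
all_targets = np.concatenate([left_targets, right_targets])
error_merge = np.sum((all_targets - tree_mean) ** 2)

print(error_merge < error_no_merge)  # merging wins here, so prune() would merge
```

When the two leaf values are this close, the merged mean explains the test data at least as well, so the node collapses to a single leaf.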


The implementation code follows:
def isTree(obj):
    """ Test whether the given object is a tree.
    isTree() returns a Boolean. Use it to find out when you have reached
    a branch whose children are only leaf nodes.
    """
    return type(obj).__name__ == 'dict'

def getMean(tree):
    """ Collapse a tree into the mean value of its leaves.
    getMean() is a recursive function that descends a tree until it hits
    leaf nodes, then averages them on the way back up, collapsing the tree.
    """
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    """ Prune the tree based on the given test set. Pseudo code:
    ----------------------------------------------------------------------
    Split the test data for the given tree:
        If either split is a tree: call prune on that split
        Calculate the error associated with merging two leaf nodes
        Calculate the error without merging
        If merging results in lower error then merge the leaf nodes
    ----------------------------------------------------------------------
    1. First check whether the test data is empty.
    2. prune() is called recursively and splits the test data based on the tree.
    3. The tree was built from a different data set, so the test data may not
       cover the same value ranges as the original data. In that case we assume
       the tree is overfit and collapse it.
    4. Next, test whether either branch is a tree; if so, attempt to prune it by
       calling prune() on that branch. After pruning both branches, test whether
       they are still trees. If neither is, they are candidates for merging.
    5. Split the test data and measure the error. If merging the two branches gives
       a lower error than not merging, merge them; otherwise return the original tree.
    """
    # Collapse the tree if there is no test data
    if shape(testData)[0] == 0: return getMean(tree)
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # Both left and right children are leaves
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'], 2)) + \
                       sum(power(rSet[:,-1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:,-1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            return treeMean
        else: return tree
    else: return tree
The test code is below:
print("\t[Info] API:prune() testing:")
print("\tLoading 'ex2.txt'...")
myMat2 = mat(loadDataSet('ex2.txt'))
print("\tCreating tree with ops=(0,1)...")
myTree = createTree(myMat2, ops=(0,1))
print("\tOriginal tree:\n{0}".format(myTree))
print("\tLoading testing set='ex2test.txt'...")
myMat2Test = mat(loadDataSet('ex2test.txt'))
print("\tPruning tree...")
myTree = prune(myTree, myMat2Test)  # keep the return value; the root itself may be merged
print("\tResulting Tree:\n{0}".format(myTree))
From the output you can see that the pruned tree is smaller than the original one.

Supplement:
[ ML In Action ] Decision Tree Construction
Numpy > numpy.nonzero(a) : Return the indices of the elements that are non-zero.
Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The corresponding non-zero values can be obtained with a[nonzero(a)].
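A quick illustration of nonzero's tuple-of-arrays return value:

```python
import numpy as np

a = np.array([[1, 0, 0],
              [0, 2, 0],
              [1, 1, 0]])

rows, cols = np.nonzero(a)   # one index array per dimension
print(rows)                  # [0 1 2 2]
print(cols)                  # [0 1 0 1]
print(a[rows, cols])         # the non-zero values: [1 2 1 1]
```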

