程式扎記: [ ML In Action ] Decision Tree Construction

標籤

2012年7月16日 星期一

[ ML In Action ] Decision Tree Construction

Preface : 
在這裡我們將會介紹 decision tree-building 的演算法, 而使用的演算法將會利用 information theory 的原理來決定 decision tree 的分割以得到最佳的結果. 底下是使用 Decision tree 的概述 : 
Decision Trees
Pros: Computationally cheap to use, easy for humans to understand learned results, missing values OK, can deal with irrelevant features.
Cons: Prone to overfitting
Works with: Numeric values, nominal values

Tree Construction : 
在建立 Decision Tree 的過程中, 你將會不斷的對給定的 Data Set 進行 feature 的選用與切割. 假設你選定了某個 Feature, 接著這個 Feature 的 possible values 便成了 Decision Tree 的 Branches, 而整個 Decision Tree 的 nodes 會由 decision blocks (正方形) 與 terminating blocks (橢圓形) 組成. 而 terminating blocks 即為我們最終的分類 (Class). 底下為一 Decision Tree 範例 : 
 

而切割後的 Data Set, 原先被切割的 Feature 已經被固定成某個值, 而切割後的 Set 我們稱為 Sub Data Set, 可以對剩下的 Features 繼續進行切割, 處理程序跟切割前大致不變, 因此我們可以使用 Recursive way 來撰寫這樣的函數 createBranch(), 底下為此函數的 pseudo code : 
  1. Check if every item in the dataset is in the same class:  
  2.     If so return the class label  
  3.     Else  
  4.         find the "best feature" to split the data  
  5.         split the dataset  
  6.         create a branch node  
  7.             for each split  
  8.                 call createBranch and add the result to the branch node  
  9.         return branch node  
在我們開始刻代碼前, 先來看看接下來會做那些事 : 
General approach to decision trees
1. Collect: Any method.
2. Prepare: This tree-building algorithm works only on nominal values, so any continuous values will need to be quantized.
3. Analyze: Any method. You should visually inspect the tree after it is built.
4. Train: Construct a tree data structure.
5. Test: Calculate the error rate with the learned tree.
6. Use: This can be used in any supervised learning task.

這邊我們使用的演算法為 ID3, 透過它你將在每次切割 Feature 時選擇有較好的 Information gain 的 feature 進行, 一直到所有的 terminating node 都有一致的類別 ; 如果到最後還是沒有統一的類別, 就只能由多數的類別決定 terminating node 的類別. 

在開始處理演算法前, 照慣例先來看看我們處理的 Data set. 參考下表, 共有五筆記錄 ; 並提供兩個 features ; 而類別為 Fish or Not Fist : 
 

- Information gain 
在這邊我們使用的演算法會用到 Information theory 來決定每次切割前 Feature 的選定, 而選定的標準為選定的 Feature 有較高的 Information gain. 在 Information theory 中使用 Shannon entropy 或簡稱 entropy 來量測資料集上面帶有的 information. 對於entropy 的詳細定義這邊不會討論, 而只使用定義後的公式 : 
Shannon denoted the entropy H of a discrete random variable X with possible values {x1, ..., xn} and probability mass function p(X) as,

I(X) is itself a random variable, and I is the information content of X

由上面的公式可以推論如果 p(xi) = 1, 則 H(X)=0 ; p(xi) = 0, 則 H(X)=0. 而 entropy 被定義為 "The expected value of the information". 而當 p(xi)=0 或 p(xi)=1 你可以解釋成當你已經可以百分百確定某事, 那麼它將不會帶來任何 information, 故 H(X)=0. 因此在 Decision Tree 的每次 Feature 切割後, entropy 的值會不斷的變小 (因為你越來越接近 classified 的結果). 底下為實作計算 entropy 的函數 :
- Listing 3.1 Function to calculate the Shannon entropy of a dataset 
  1. def calcShannonEnt(dataSet):  
  2.         """ Function to calculate the Shannon entropy of a dataset"""  
  3.         numEntries = len(dataSet)  
  4.         labelCounts = {}  
  5.         for featVec in dataSet:  
  6.                 currentLabel = featVec[-1]  
  7.                 labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1  
  8.         shannonEnt = 0.0  
  9.         for key in labelCounts:  
  10.                 prob = float(labelCounts[key]) / numEntries  
  11.                 shannonEnt -= prob * log(prob, 2)  
  12.         return shannonEnt  
為了來測試一下函數 calcShannonEnt(), 我們寫了另一個函數來取得測試用的 data set : 
  1. def createTestDataSet():  
  2.         dataSet = [ [11'yes'],  
  3.                     [11'yes'],  
  4.                     [10'no'],  
  5.                     [01'no'],  
  6.                     [01'no']]  
  7.         labels = ['no surfacing''flippers']  
  8.         return dataSet, labels  
接著你便可以如下測試或計算 (上面的代碼撰寫於 tree.py 中.) : 
 

- Splitting the dataset 
在處理 dataset 過程中, 我們必需針對 Feature 對 dataset 進行切割 ; 而切割後 sub dataset 的該 Feature 那欄將被移除 : 
- Listing 3.2 
  1. def splitDataSet(dataSet, axis, value):  
  2.         """ Dataset splitting on a given feature  
  3.         - Argument  
  4.           * dataSet : 原始 data set  
  5.           * axis : 切割 Feature 的 column index.  
  6.           * value : 切割 Feature 的某個值."""  
  7.         retDataSet = []  
  8.         for featVec in dataSet:  
  9.                 if featVec[axis] == value:  
  10.                         reducedFeatVec = featVec[:axis]  
  11.                         reducedFeatVec.extend(featVec[axis+1:])  
  12.                         retDataSet.append(reducedFeatVec)  
  13.         return retDataSet  
接著你便能如下測試切割函數的代碼 : 
 

現在我們已經能切割 data set 與計算 entropy, 而每次要挑選 Feature 來進行切割的條件會選擇切割後有較大的 Information gain. 也就是maximum(IG(features)) ; IG(feature) = ig_orig - ig_new. (ig_orig=某 Feature 切割前的entropy; ig_new=某 Feature 切割後的entropy). 實作代碼如下 : 
- Listing 3.3 
  1. def chooseBestFeatureToSplit(dataSet):  
  2.         """ Choosing the best feature to split on  
  3.         - Argument  
  4.           * dataSet: 原始 data set  
  5.         - Return  
  6.           返回有最大 Information gain 的 Feature column."""  
  7.         numFeatures = len(dataSet[0]) - 1 # 最後一個欄位是 Label  
  8.         baseEntropy = calcShannonEnt(dataSet)  
  9.         bestInfoGain = 0.0 ; bestFeature = -1  
  10.         for i in range(numFeatures):  
  11.                 featList = [example[i] for example in dataSet]  
  12.                 uniqueVals = set(featList)  
  13.                 newEntropy = 0.0  
  14.                 for value in uniqueVals:  
  15.                         subDataSet = splitDataSet(dataSet, i, value)  
  16.                         prob = len(subDataSet) / float(len(dataSet))  
  17.                         newEntropy += prob * calcShannonEnt(subDataSet)  
  18.                 infoGain = baseEntropy - newEntropy  # Information gain: Split 會讓 data 更 organized, 所以 Entropy 會變小.  
  19.                 if infoGain > bestInfoGain:  
  20.                         bestInfoGain = infoGain  
  21.                         bestFeature = i  
  22.         return bestFeature  
接著你可以如下測試 : 
>>> reload(tree) # 重新載入 tree.py

>>> myDat, labels = tree.createTestDataSet() # 取回 test data set
>>> tree.chooseBestFeatureToSplit(myDat)
# 說明第一次切割選擇 Column0 的 Feature!

- Recursively building the tree 
如果切割後出現有一致的 Class (Fish or Not Fish), 則該 Branch 的切割算是完成 ; 如果所有的 Feature 都切割完了, 仍然沒有一致的類別, 則只好採多數決的方法. 通常切割到最後, data set 的欄位會只剩下類別標籤, 可以透過下面方法反回多數的類別 : 
  1. def majorityCnt(classList):  
  2.         """ Choose majority class and return.  
  3.         - Argument  
  4.           * classList: 類別的 List.  
  5.         - Return  
  6.           返回 List 中出現最多次的類別."""  
  7.         classCnt = {}  
  8.         for vote in classList:  
  9.                 classCnt[vote] = classCnt.get(vote, 0) + 1  
  10.         sortedClassCnt = sorted(classCnt.iteritems(),  
  11.                                 key = operator.itemgetter(1),  
  12.                                 reverse=True)  
  13.         return sortedClassCnt[0][0]  
到目前為止我們所需要的功能都已經具備, 接著我們要寫一個 Recursive 函數來產出如下的 Decision Tree : 
 

實作該 Recursive 函數代碼如下 : 
- Listing 3.4 
  1. def createTree(dataSet, labels):  
  2.         """ Tree-building code.  
  3.         - Argument  
  4.           * dataSet: 原始 data set  
  5.           * labels: 對應 class 標籤的文字說明.  
  6.         - Return  
  7.           返回建立的 Decision Tree."""  
  8.         classList = [example[-1for example in dataSet]  
  9.         if classList.count(classList[0]) == len(classList):     # 所有的 class 都是同一類.  
  10.                 return classList[0]  
  11.         if len(dataSet[0]) == 1:                                # 已經沒有 feature 可以 split 了.  
  12.                 return majorityCnt(classList)  
  13.   
  14.         bestFeat = chooseBestFeatureToSplit(dataSet)  
  15.         bestFeatLabel = labels[bestFeat]  
  16.         myTree = {bestFeatLabel:{}}  
  17.         del(labels[bestFeat])  
  18.         featValues = [example[bestFeat] for example in dataSet]  
  19.         uniqueVals = set(featValues)  
  20.         for value in uniqueVals:  
  21.                 subLabels = labels[:]  
  22.                 myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)  
  23.         return myTree  
可以如下測試代碼 : 
>>> reload(tree)

>>> myDat, labels = tree.createTestDataSet()
>>> myTree = tree.createTree(myDat, labels)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

Plotting trees in Python with Matplotlib annotations : 
雖然 Decision Tree 的資料以經可以成功 train 出來, 不過不夠直覺. 因此我們將透過 Matplotlib 套件幫我們將 Decision Tree 視覺化, 以提高 Human readability. 完整代碼如下 : 
- Listing 3.7 
  1. import matplotlib.pyplot as plt  
  2.   
  3. decisionNode = dict(boxstyle="sawtooth", fc="0.8")  
  4. leafNode = dict(boxstyle="round4", fc="0.8")  
  5. arrow_args = dict(arrowstyle="<-")  
  6.   
  7. def getNumLeafs(myTree):  
  8.     numLeafs = 0  
  9.     firstStr = myTree.keys()[0]  
  10.     secondDict = myTree[firstStr]  
  11.     for key in secondDict.keys():  
  12.         if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes  
  13.             numLeafs += getNumLeafs(secondDict[key])  
  14.         else:   numLeafs +=1  
  15.     return numLeafs  
  16.   
  17. def getTreeDepth(myTree):  
  18.     maxDepth = 0  
  19.     firstStr = myTree.keys()[0]  
  20.     secondDict = myTree[firstStr]  
  21.     for key in secondDict.keys():  
  22.         if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes  
  23.             thisDepth = 1 + getTreeDepth(secondDict[key])  
  24.         else:   thisDepth = 1  
  25.         if thisDepth > maxDepth: maxDepth = thisDepth  
  26.     return maxDepth  
  27.   
  28. def plotNode(nodeTxt, centerPt, parentPt, nodeType):  
  29.     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',  
  30.              xytext=centerPt, textcoords='axes fraction',  
  31.              va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )  
  32.       
  33. def plotMidText(cntrPt, parentPt, txtString):  
  34.     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  
  35.     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  
  36.     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)  
  37.   
  38. def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on  
  39.     numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree  
  40.     depth = getTreeDepth(myTree)  
  41.     firstStr = myTree.keys()[0]     #the text label for this node should be this  
  42.     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  
  43.     plotMidText(cntrPt, parentPt, nodeTxt)  
  44.     plotNode(firstStr, cntrPt, parentPt, decisionNode)  
  45.     secondDict = myTree[firstStr]  
  46.     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  
  47.     for key in secondDict.keys():  
  48.         if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes     
  49.             plotTree(secondDict[key],cntrPt,str(key))        #recursion  
  50.         else:   #it's a leaf node print the leaf node  
  51.             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  
  52.             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  
  53.             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  
  54.     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  
  55. #if you do get a dictonary you know it's a tree, and the first element will be another dict  
  56.   
  57. def createPlot(inTree):  
  58.     fig = plt.figure(1, facecolor='white')  
  59.     fig.clf()  
  60.     axprops = dict(xticks=[], yticks=[])  
  61.     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks  
  62.     #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses   
  63.     plotTree.totalW = float(getNumLeafs(inTree))  
  64.     plotTree.totalD = float(getTreeDepth(inTree))  
  65.     plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;  
  66.     plotTree(inTree, (0.5,1.0), '')  
  67.     plt.show()  
  68.   
  69. def retrieveTree(i):  
  70.     listOfTrees =[{'no surfacing': {0'no'1: {'flippers': {0'no'1'yes'}}}},  
  71.                   {'no surfacing': {0'no'1: {'flippers': {0: {'head': {0'no'1'yes'}}, 1'no'}}}}  
  72.                   ]  
  73.     return listOfTrees[i]  
接著你便可以使用函數 createPlot() 將我們 train 出來的 Decision Tree 畫出來. 測試如下 : 
>>> import tree, treePlotter
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.createPlot(myTree)

執行後產生下圖 : 
 

Testing and storing the classifier : 
目前 createTree() 已經能幫我們建立 train 完後的 Decision tree, 接著我們要撰寫可以透過傳入的 Decision tree 與 sample , 而返回 classified 後的結果. 代碼如下 : 
- Listing 3.8 
  1. def classify(inputTree, featLabels, testVec):  
  2.         """ Classification function for an existing decision tree"""  
  3.         firstStr = inputTree.keys()[0]                                          # 取出第一個 decision block  
  4.         secondDict = inputTree[firstStr]                                        # 取出對應 decision block 的 sub decision tree.  
  5.         featIndex = featLabels.index(firstStr)                                  # 取出該 decision block 對應 feature 的欄位.  
  6.         for key in secondDict.keys():  
  7.                 if testVec[featIndex] == key:                                   # 如果 testVec 在該 feature 的值等於 subdecision tree 的 key  
  8.                         if type(secondDict[key]).__name__ == 'dict':            # 如果該 feature value 的 branch 是 tree, 則繼續 classify.  
  9.                                 classLabel = classify(secondDict[key], featLabels, testVec)  
  10.                         elsereturn secondDict[key]                            # 如果該 feature value 的 branch 是 class, 則立即返回.  
  11.         return classLabel  
接著我們可以如下測試 : 
>>> import tree, treePlotter
>>> myDat, labels = tree.createTestDataSet()
>>> myTree = treePlotter.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> tree.classify(myTree, labels, [1, 0]) # surfacing & no flippers
'no'
>>> tree.classify(myTree, labels, [1, 1]) # surfacing & has flippers
'yes'

- Use: persisting the decision tree 
這邊我們希望已經建好的 Decision Tree 離開程式後, 下次能夠重覆使用而不用重新 train. 這個需求可以藉由 Python module pickle 達成. 因此下面我們使用函數 storeTree() 將 decision tree 存成檔案提供下次使用 ; 函數 grabTree() 則用來載入 storeTree() 輸出的檔案 : 
- Listing 3.9 
  1. def storeTree(inTree, filename):  
  2.         import pickle  
  3.         fw = open(filename, 'w')  
  4.         pickle.dump(inTree, fw)  
  5.         fw.close()  
  6.   
  7. def grabTree(filename):  
  8.         import pickle  
  9.         fr = open(filename)  
  10.         return pickle.load(fr)  
你可以如下測試 : 
>>> tree.storeTree(myTree, 'classifierStoreage.bin')
>>> tree.grabTree('classifierStoreage.bin')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

相關代碼可以在 這裡下載 (Ch03).

沒有留言:

張貼留言

網誌存檔

關於我自己

我的相片
Where there is a will, there is a way!