Here we introduce an algorithm for building decision trees. The algorithm uses ideas from information theory to decide how to split the data so as to get the best possible tree. Below is an overview of working with a decision tree:
Tree Construction :
While building a decision tree you repeatedly choose a feature from the given data set and split on it. Once a feature is chosen, its possible values become the branches of the tree. The nodes of a decision tree consist of decision blocks (squares) and terminating blocks (ovals), and the terminating blocks are the final classes. Below is an example decision tree:
In each data set produced by a split, the feature that was split on has been fixed to a single value; we call such a set a sub data set, and we can keep splitting it on the remaining features. Since the procedure on a sub data set is essentially the same as before the split, the function createBranch() can be written recursively. Its pseudocode is shown below:
- Check if every item in the dataset is in the same class:
  - If so, return the class label
  - Else
    - find the "best feature" to split the data
    - split the dataset
    - create a branch node
    - for each split
      - call createBranch() and add the result to the branch node
    - return branch node
The algorithm we use here is ID3. At every split it picks the feature with the best information gain, and it keeps splitting until every terminating node holds a single class; if the classes still cannot be made uniform in the end, the class of the terminating node is decided by the majority.
Before getting into the algorithm, let's look, as usual, at the data set we will work with. The table below has five records and two features, and the class is Fish or Not Fish:

    Can survive without surfacing?   Has flippers?   Fish?
1   Yes                              Yes             Yes
2   Yes                              Yes             Yes
3   Yes                              No              No
4   No                               Yes             No
5   No                               Yes             No
- Information gain
The algorithm used here relies on information theory to pick the feature for each split, and the criterion is to pick the feature with the higher information gain. Information theory measures the information carried by a data set with the Shannon entropy, or simply entropy. We won't go into the formal definition of entropy here and will only use the resulting formula: H(X) = -Σ p(xi) * log2(p(xi)), summed over all possible values xi, where p(xi) is the probability of value xi.
From the formula above you can see that a term with p(xi) = 1 contributes 0, and a term with p(xi) = 0 also contributes 0, so in both cases H(X) gains nothing. Entropy is defined as "the expected value of the information": when p(xi) = 0 or p(xi) = 1 you can interpret this as already being completely certain about an outcome, so it carries no information. Consequently, each feature split in the decision tree keeps lowering the entropy, because the data gets closer and closer to being fully classified. Listing 3.1 below implements the entropy calculation.
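As a concrete number to check against, the fish data set above has 2 'yes' and 3 'no' records out of 5, so its entropy is H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.529 + 0.442 ≈ 0.971 bits.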
- Listing 3.1 Function to calculate the Shannon entropy of a dataset
from math import log

def calcShannonEnt(dataSet):
    """ Function to calculate the Shannon entropy of a dataset"""
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:                 # Count how many records belong to each class label
        currentLabel = featVec[-1]          # The class label is the last column
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)   # Accumulate -p * log2(p) for each class
    return shannonEnt
def createTestDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
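Assuming the two functions above are saved in a module called trees.py (the module name is just an assumption for this example), a quick interactive check might look like this, and the result matches the hand calculation above:

>>> import trees
>>> myDat, labels = trees.createTestDataSet()
>>> trees.calcShannonEnt(myDat)
0.9709505944546686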
- Splitting the dataset
While processing the data set we need to split it on a chosen feature; in each resulting sub data set, the column of that feature is removed:
- Listing 3.2
def splitDataSet(dataSet, axis, value):
    """ Dataset splitting on a given feature
    - Arguments
      * dataSet : the original data set
      * axis    : column index of the feature to split on
      * value   : the feature value to keep"""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]          # Everything before the split feature...
            reducedFeatVec.extend(featVec[axis+1:])  # ...plus everything after it, dropping that column
            retDataSet.append(reducedFeatVec)
    return retDataSet
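For example, splitting the test data set from createTestDataSet() on feature 0 should behave roughly as follows; note that the feature-0 column is removed from every returned record:

>>> myDat, labels = createTestDataSet()
>>> splitDataSet(myDat, 0, 1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]
>>> splitDataSet(myDat, 0, 0)
[[1, 'no'], [1, 'no']]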
Now that we can split the data set and compute its entropy, each split should use the feature that yields the largest information gain; that is, we maximize InfoGain(feature) = H_before - H_after over all features, where H_before is the entropy before the split and H_after is the weighted average entropy of the subsets produced by splitting on that feature. The implementation is as follows:
- Listing 3.3
def chooseBestFeatureToSplit(dataSet):
    """ Choosing the best feature to split on
    - Argument
      * dataSet: the original data set
    - Return
      The column index of the feature with the largest information gain."""
    numFeatures = len(dataSet[0]) - 1                # The last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)   # Weighted average entropy after the split
        infoGain = baseEntropy - newEntropy          # Information gain: splitting makes the data more organized, so entropy drops
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
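On the test data set this should pick feature 0 ('no surfacing'): its information gain is about 0.971 - 0.551 ≈ 0.420, versus roughly 0.171 for 'flippers':

>>> myDat, labels = createTestDataSet()
>>> chooseBestFeatureToSplit(myDat)
0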
- Recursively building the tree
If a split produces a subset with a single class (Fish or Not Fish), that branch is finished. If all features have been used up and the classes are still not uniform, we fall back to a majority vote. By that point the data set usually has only the class-label column left, and the following function returns the majority class:
import operator

def majorityCnt(classList):
    """ Choose the majority class and return it.
    - Argument
      * classList: a list of class labels.
    - Return
      The class that appears most often in the list."""
    classCnt = {}
    for vote in classList:
        classCnt[vote] = classCnt.get(vote, 0) + 1
    sortedClassCnt = sorted(classCnt.items(),
                            key=operator.itemgetter(1),
                            reverse=True)            # Sort classes by count, most frequent first
    return sortedClassCnt[0][0]
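A tiny sanity check:

>>> majorityCnt(['yes', 'no', 'no', 'no', 'yes'])
'no'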
The recursive tree-building function is implemented as follows:
- Listing 3.4
def createTree(dataSet, labels):
    """ Tree-building code.
    - Arguments
      * dataSet: the original data set
      * labels : text names of the features, used as keys in the tree.
    - Return
      The decision tree, built as nested dictionaries."""
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):   # Every record has the same class: stop here
        return classList[0]
    if len(dataSet[0]) == 1:                               # No features left to split on: take a majority vote
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])                                  # This feature is consumed; remove its name
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]                              # Copy so recursive calls don't clobber the list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
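Building the tree from the fish data set should produce the nested dictionary below. Note that createTree() deletes entries from the labels list it is given, so keep a copy if you need the original feature names later:

>>> myDat, labels = createTestDataSet()
>>> myTree = createTree(myDat, labels)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}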
Plotting trees in Python with Matplotlib annotations :
Although the decision tree can now be trained successfully, the nested-dictionary form is not very intuitive. We will therefore use the Matplotlib package to visualize the decision tree and improve its human readability. The complete code is shown below:
- Listing 3.7
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # A dictionary means a subtree; otherwise it is a leaf node
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # A dictionary means a subtree; otherwise it is a leaf node
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):           # The first key tells you which feature was split on
    numLeafs = getNumLeafs(myTree)                 # This determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]              # The text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':      # A dictionary means another subtree
            plotTree(secondDict[key], cntrPt, str(key))   # Recursion
        else:                                             # A leaf node: plot it directly
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)   # No ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False)             # Ticks, for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

def retrieveTree(i):
    """ Return one of two hard-coded test trees, so you can plot without re-training"""
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]
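A minimal way to try the plotting code, using the hard-coded tree from retrieveTree() instead of re-training:

>>> myTree = retrieveTree(0)
>>> getNumLeafs(myTree)
3
>>> getTreeDepth(myTree)
2
>>> createPlot(myTree)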
Running createPlot() produces the figure below:
Testing and storing the classifier :
createTree() can now build the trained decision tree for us. Next we write a function that, given a decision tree and a sample, returns the classification result. The code is as follows:
- Listing 3.8
def classify(inputTree, featLabels, testVec):
    """ Classification function for an existing decision tree"""
    firstStr = list(inputTree.keys())[0]            # The first decision block
    secondDict = inputTree[firstStr]                # The subtree hanging off that decision block
    featIndex = featLabels.index(firstStr)          # Column index of the feature used by this decision block
    for key in secondDict.keys():
        if testVec[featIndex] == key:               # testVec matches this branch's feature value
            if type(secondDict[key]).__name__ == 'dict':   # The branch is a subtree: keep classifying
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:                                   # The branch is a class label: return it immediately
                return secondDict[key]
    return classLabel
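A short example, using the hard-coded tree from retrieveTree() and a fresh labels list (remember that createTree() may have modified the one you passed to it):

>>> myDat, labels = createTestDataSet()
>>> myTree = retrieveTree(0)
>>> classify(myTree, labels, [1, 0])
'no'
>>> classify(myTree, labels, [1, 1])
'yes'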
- Use: persisting the decision tree
Here we want the decision tree we have already built to survive after the program exits, so it can be reused next time without re-training. This can be achieved with the Python pickle module. Below, storeTree() serializes a decision tree to a file for later use, and grabTree() loads a file written by storeTree():
- Listing 3.9
def storeTree(inTree, filename):
    import pickle
    fw = open(filename, 'wb')       # Binary mode, as required by pickle
    pickle.dump(inTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
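Usage might look like the following; the file name 'classifierStorage.txt' is just an example:

>>> storeTree(myTree, 'classifierStorage.txt')
>>> grabTree('classifierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}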
The full source code can be downloaded here (Ch03).