程式扎記 (Programming Notes): [ MLP ] Machine learning tools and techniques : Ch04 Algorithms ( The Basic Methods ) - Part4


Monday, June 18, 2012



4.7 Instance-based learning
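This section discusses instance-based learning and introduces two data structures, the kD-tree and the ball tree, that let us quickly find which training-set point is nearest to a target instance; the class of that nearest point then decides the target's class. If each instance has n attributes a1, a2, ..., an, the simplest distance function between two instances a^(1) and a^(2) is the Euclidean distance:

$$ d\left(a^{(1)}, a^{(2)}\right) = \sqrt{\sum_{i=1}^{n} \left(a_i^{(1)} - a_i^{(2)}\right)^2} $$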
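A minimal sketch of the Euclidean distance for n numeric attributes, assuming each instance is stored as a `double[]` of attribute values (`EuclideanDistance` is a hypothetical class name used only here, not part of the post's code):

```java
// Hypothetical helper: Euclidean distance between two instances whose
// n numeric attribute values are stored in double arrays of equal length.
class EuclideanDistance {
    static double distance(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("instances must have the same number of attributes");
        }
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i]; // difference on attribute i
            sum += d * d;           // accumulate the squared difference
        }
        return Math.sqrt(sum);
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3};
        double[] y = {4, 6, 3};
        System.out.println(distance(x, y)); // prints 5.0
    }
}
```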


- Finding nearest neighbors efficiently
Once the training instance set is ready, the most obvious way to answer a query for the point nearest to a target instance is to compute the distance from the target to every training instance and keep the closest one. That costs O(N) distance computations per query (where N is the size of the training set), so once the training set grows large enough, finding nearest neighbors becomes a bottleneck. We therefore need a more efficient way to find a target instance's nearest neighbors.
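The brute-force approach can be sketched as follows. This is a hypothetical example (the `LinearScanNN` and `LabeledPoint` names are assumptions, not the post's code): it labels the target with the class of its single nearest neighbor and counts distance computations, which shows that each query costs N of them.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch: 1-nearest-neighbor by linear scan over the training set.
class LinearScanNN {
    // A training instance with numeric attributes and a class label.
    static final class LabeledPoint {
        final double[] attrs;
        final String label;
        LabeledPoint(String label, double... attrs) { this.label = label; this.attrs = attrs; }
    }

    static int distanceCalls = 0; // how many distances have been computed

    static double distance(double[] a, double[] b) {
        distanceCalls++;
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Compares the target with all N training instances and returns the closest.
    static LabeledPoint nearest(List<LabeledPoint> training, double[] target) {
        LabeledPoint best = null;
        double bestDist = Double.POSITIVE_INFINITY;
        for (LabeledPoint p : training) {
            double d = distance(p.attrs, target);
            if (d < bestDist) {
                bestDist = d;
                best = p;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<LabeledPoint> training = Arrays.asList(
                new LabeledPoint("a", 1, 1),
                new LabeledPoint("a", 2, 2),
                new LabeledPoint("b", -4, -1),
                new LabeledPoint("b", -5, -2));
        LabeledPoint nn = nearest(training, new double[] {-4.5, -1});
        // The target takes the class of its nearest neighbor.
        System.out.println("class=" + nn.label + ", distance computations=" + distanceCalls);
    }
}
```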

- kD-tree
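One such data structure is the kD-tree, where k is the number of attributes of an instance (see page 161 of the book for details). The basic idea is to split the instance space with axis-parallel cuts and store the instances in a binary tree. Below is a simple example with k = 2 attributes: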

Figure 4.12 A kD-tree for four training instances: (a) the tree and (b) instances and splits.
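A minimal sketch of kD-tree construction, assuming the simplest splitting rule: the split attribute cycles through the k attributes by tree depth, and the median point on that attribute becomes the node. Other rules (for example, splitting on the attribute with the greatest spread) are also common; the class `KdTreeBuild` and its sample points are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical sketch: build a kD-tree by splitting on attribute (depth % k) at the median.
class KdTreeBuild {
    static final class Node {
        final double[] point; // the training instance stored at this node
        final int axis;       // attribute index used for the split
        Node left, right;     // left: value <= split value on axis; right: value >= split value
        Node(double[] point, int axis) { this.point = point; this.axis = axis; }
    }

    // Builds a subtree from pts[from, to); reorders that range of the array.
    static Node build(double[][] pts, int from, int to, int depth) {
        if (from >= to) return null;
        int axis = depth % pts[from].length;
        // Sort the range on the current attribute and use the median as the split point.
        Arrays.sort(pts, from, to, Comparator.comparingDouble(p -> p[axis]));
        int mid = (from + to) / 2;
        Node node = new Node(pts[mid], axis);
        node.left = build(pts, from, mid, depth + 1);
        node.right = build(pts, mid + 1, to, depth + 1);
        return node;
    }

    static void print(Node n, String indent) {
        if (n == null) return;
        System.out.println(indent + Arrays.toString(n.point) + " split on a" + (n.axis + 1));
        print(n.left, indent + "  ");
        print(n.right, indent + "  ");
    }

    public static void main(String[] args) {
        double[][] pts = {{3, 6}, {8, 1}, {5, 4}, {1, 2}};
        print(build(pts, 0, pts.length, 0), "");
    }
}
```

Taking the median keeps the tree balanced, so its depth grows only logarithmically with the number of training instances.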

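To find the nearest neighbor of the target instance (the star), we first locate the region it falls into and draw a circle around it whose radius is the distance to the training point in that region (by construction each region holds at most one point; here it is the filled black point). Any closer point must lie inside this circle, so we only need to compute distances to points in the regions the circle intersects, rather than to every point in the training set: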

Figure 4.13 Using a kD-tree to find the nearest neighbor of the star.
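The search can be sketched as below, assuming the same axis-cycling, median-split construction (the `KdTreeSearch` class is hypothetical). It descends first into the side of each split that contains the target, then visits the other side only if the circle whose radius is the best distance found so far crosses the splitting line.

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical sketch: kD-tree nearest-neighbor search with backtracking.
class KdTreeSearch {
    static final class Node {
        final double[] point; final int axis; Node left, right;
        Node(double[] point, int axis) { this.point = point; this.axis = axis; }
    }

    // Same median-split construction as a plain kD-tree; reorders pts[from, to).
    static Node build(double[][] pts, int from, int to, int depth) {
        if (from >= to) return null;
        int axis = depth % pts[from].length;
        Arrays.sort(pts, from, to, Comparator.comparingDouble(p -> p[axis]));
        int mid = (from + to) / 2;
        Node n = new Node(pts[mid], axis);
        n.left = build(pts, from, mid, depth + 1);
        n.right = build(pts, mid + 1, to, depth + 1);
        return n;
    }

    static double dist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) { double d = a[i] - b[i]; s += d * d; }
        return Math.sqrt(s);
    }

    static double[] best;    // nearest point found so far
    static double bestDist;  // its distance: the radius of the search circle
    static int visited;      // nodes whose point was compared with the target

    static void search(Node n, double[] target) {
        if (n == null) return;
        visited++;
        double d = dist(n.point, target);
        if (d < bestDist) { bestDist = d; best = n.point; }
        double diff = target[n.axis] - n.point[n.axis];
        Node near = diff < 0 ? n.left : n.right;
        Node far = diff < 0 ? n.right : n.left;
        search(near, target); // the region containing the target first
        // Points across the split are at least |diff| away, so skip them
        // unless the circle of radius bestDist crosses the splitting line.
        if (Math.abs(diff) < bestDist) search(far, target);
    }

    static double[] nearest(Node root, double[] target) {
        best = null;
        bestDist = Double.POSITIVE_INFINITY;
        visited = 0;
        search(root, target);
        return best;
    }

    public static void main(String[] args) {
        double[][] pts = {{3, 6}, {8, 1}, {5, 4}, {1, 2}, {7, 7}, {2, 9}};
        Node root = build(pts, 0, pts.length, 0);
        double[] nn = nearest(root, new double[] {7, 2});
        System.out.println("nearest=" + Arrays.toString(nn) + ", nodes visited=" + visited + " of " + pts.length);
    }
}
```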

- ball tree
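Although the kD-tree is well defined, it runs into a number of problems in practice. Another data structure, the ball tree, also partitions the training instances and stores them in a binary tree, but each node covers a hypersphere (a "ball") rather than an axis-aligned region. Here is the book's description of how a ball tree is built: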
Choose the point in the ball that is farthest from its center, and then a second point that is farthest from the first one. Assign all data points in the ball to the closest one of these two cluster centers, then compute the centroid of each cluster and the minimum radius required for it to enclose all the data points it represents. This method has the merit that the cost of splitting a ball containing n points is only linear in n. There are more elaborate algorithms that produce tighter balls, but they require more computation.

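Hard to follow? Here is a simpler summary:
1. Compute the centroid of the (training instance) set and find the point in the set farthest from that centroid; call it p1.
2. Find the point in the set farthest from p1; call it p2.
3. Cluster the remaining points around p1 and p2: a point joins p1's set if it is closer to p1, and p2's set otherwise.
4. Repeat steps 1 to 3 on each set produced by step 3, stopping when a set contains only one point.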
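A minimal sketch of the four steps above, assuming points are `double[]` arrays and that each ball's radius is the largest distance from its centroid to any of its points, as the book's "minimum radius required to enclose all the data points" suggests. `BallTreeBuild` is a hypothetical name; the post's own implementation is the `BallTree` class further down.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of ball-tree construction following steps 1-4.
class BallTreeBuild {
    static final class Ball {
        final List<double[]> points; // points enclosed by this ball
        final double[] center;       // centroid of the points
        final double radius;         // smallest radius around the centroid enclosing all points
        Ball left, right;            // children; both null for a leaf
        Ball(List<double[]> pts) {
            points = pts;
            center = centroid(pts);
            double r = 0;
            for (double[] p : pts) r = Math.max(r, dist(p, center));
            radius = r;
        }
    }

    static double dist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) { double d = a[i] - b[i]; s += d * d; }
        return Math.sqrt(s);
    }

    static double[] centroid(List<double[]> pts) {
        double[] c = new double[pts.get(0).length];
        for (double[] p : pts) for (int i = 0; i < c.length; i++) c[i] += p[i];
        for (int i = 0; i < c.length; i++) c[i] /= pts.size();
        return c;
    }

    // Returns the point in pts farthest from 'from' (the first one on ties).
    static double[] farthest(List<double[]> pts, double[] from) {
        double[] best = pts.get(0);
        for (double[] p : pts) if (dist(p, from) > dist(best, from)) best = p;
        return best;
    }

    static Ball build(List<double[]> pts) {
        Ball ball = new Ball(pts);
        if (pts.size() <= 1) return ball;           // step 4: stop at a single point
        double[] p1 = farthest(pts, ball.center);   // step 1: farthest from the centroid
        double[] p2 = farthest(pts, p1);            // step 2: farthest from p1
        if (dist(p1, p2) == 0) return ball;         // all points coincide: cannot split
        List<double[]> s1 = new ArrayList<>();
        List<double[]> s2 = new ArrayList<>();
        for (double[] p : pts) {                    // step 3: join the closer of p1 / p2
            if (dist(p, p1) <= dist(p, p2)) s1.add(p); else s2.add(p);
        }
        ball.left = build(s1);                      // step 4: recurse on both sets
        ball.right = build(s2);
        return ball;
    }

    static int leaves(Ball b) {
        return b.left == null ? 1 : leaves(b.left) + leaves(b.right);
    }

    // Prints each ball (assumes 2-D points) with its children indented.
    static void print(Ball b, String indent) {
        System.out.printf("%scenter=(%.2f, %.2f) radius=%.2f size=%d%n",
                indent, b.center[0], b.center[1], b.radius, b.points.size());
        if (b.left != null) { print(b.left, indent + "  "); print(b.right, indent + "  "); }
    }

    public static void main(String[] args) {
        double[][] raw = {{1, 1}, {2, 2}, {2, 3}, {3, 0}, {2, -1}, {-3, 1}, {-2, 0}, {-4, -1}, {-5, -2}};
        List<double[]> pts = new ArrayList<>();
        for (double[] p : raw) pts.add(p);
        print(build(pts), "");
    }
}
```

Each split costs a constant number of passes over the ball's points, which matches the book's remark that splitting a ball of n points is only linear in n.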

Repeating these steps produces a binary tree; an example follows:


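To search, we find which ball the target instance falls into and draw a circle around the target whose radius is the distance to the nearest point found so far in that ball. We then only need to compute distances to points in balls that this circle intersects; a ball that lies entirely outside the circle, like the gray one in the figure, can be ruled out without examining any of its points: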

Figure 4.15 Ruling out an entire ball (gray) based on a target point (star) and its current nearest neighbor.
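A sketch of ball-tree search with this pruning rule, assuming the construction from the four steps above (the `BallTreeSearch` class is hypothetical and is not the post's `BallTree` search). By the triangle inequality, no point inside a ball with center c and radius r can be closer to the target q than dist(q, c) - r, so the whole ball is skipped whenever that bound is at least the best distance found so far.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical sketch: ball-tree nearest-neighbor search that rules out whole balls.
class BallTreeSearch {
    static final class Ball {
        final List<double[]> points; // points enclosed by this ball
        final double[] center;       // centroid of the points
        final double radius;         // max distance from the centroid to a point
        Ball left, right;            // both null for a leaf
        Ball(List<double[]> pts) {
            points = pts;
            center = new double[pts.get(0).length];
            for (double[] p : pts) for (int i = 0; i < center.length; i++) center[i] += p[i] / pts.size();
            double r = 0;
            for (double[] p : pts) r = Math.max(r, dist(p, center));
            radius = r;
        }
    }

    static double dist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) { double d = a[i] - b[i]; s += d * d; }
        return Math.sqrt(s);
    }

    static double[] farthest(List<double[]> pts, double[] from) {
        double[] best = pts.get(0);
        for (double[] p : pts) if (dist(p, from) > dist(best, from)) best = p;
        return best;
    }

    static Ball build(List<double[]> pts) {
        Ball b = new Ball(pts);
        if (pts.size() <= 1) return b;
        double[] p1 = farthest(pts, b.center), p2 = farthest(pts, p1);
        if (dist(p1, p2) == 0) return b; // coincident points stay in one leaf
        List<double[]> s1 = new ArrayList<>(), s2 = new ArrayList<>();
        for (double[] p : pts) { if (dist(p, p1) <= dist(p, p2)) s1.add(p); else s2.add(p); }
        b.left = build(s1);
        b.right = build(s2);
        return b;
    }

    static double[] best;   // nearest point found so far
    static double bestDist; // its distance to the target
    static int distCalcs;   // distances computed against training points in leaves

    static void search(Ball b, double[] q) {
        // Lower bound on the distance from q to anything inside b.
        if (dist(q, b.center) - b.radius >= bestDist) return; // rule out the whole ball
        if (b.left == null) {                                  // leaf: check its points
            for (double[] p : b.points) {
                distCalcs++;
                double d = dist(p, q);
                if (d < bestDist) { bestDist = d; best = p; }
            }
            return;
        }
        // Visit the child whose center is closer first so bestDist shrinks early.
        boolean leftFirst = dist(q, b.left.center) <= dist(q, b.right.center);
        search(leftFirst ? b.left : b.right, q);
        search(leftFirst ? b.right : b.left, q);
    }

    static double[] nearest(Ball root, double[] q) {
        best = null;
        bestDist = Double.POSITIVE_INFINITY;
        distCalcs = 0;
        search(root, q);
        return best;
    }

    public static void main(String[] args) {
        Random rnd = new Random(1);
        List<double[]> pts = new ArrayList<>();
        for (int i = 0; i < 1000; i++) pts.add(new double[] {rnd.nextDouble(), rnd.nextDouble()});
        Ball root = build(pts);
        double[] nn = nearest(root, new double[] {0.5, 0.5});
        System.out.printf("nearest=(%.3f, %.3f), leaf distance computations=%d of %d%n",
                nn[0], nn[1], distCalcs, pts.size());
    }
}
```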

For the sample code, we first define the Instance interface:
- Instance.java
  package john.ibl.inst;

  import java.util.Set;

  public interface Instance {
      // Distance between this instance and inst.
      public double distance(Instance inst);

      // Centroid of all instances in the set.
      public Instance centroid(Set<Instance> insts);

      // Centroid of all instances passed in.
      public Instance centroid(Instance... insts);
  }
Next, we implement SimpleInst, an Instance class with only two attributes:
- SimpleInst.java :
  package john.ibl.inst;

  import java.util.HashSet;
  import java.util.Set;

  public class SimpleInst implements Instance {
      public double attr1 = 0;
      public double attr2 = 0;

      public SimpleInst() {}
      public SimpleInst(double at1, double at2) { this.attr1 = at1; this.attr2 = at2; }

      @Override
      public double distance(Instance inst) {
          if (inst instanceof SimpleInst)
          {
              SimpleInst sinst = (SimpleInst) inst;
              // Euclidean distance over the two attributes
              double d1 = this.attr1 - sinst.attr1;
              double d2 = this.attr2 - sinst.attr2;
              return Math.sqrt(d1 * d1 + d2 * d2);
          }
          return -1;  // distance to other Instance types is undefined
      }

      @Override
      public Instance centroid(Instance... insts) {
          // Centroid of this instance together with all instances passed in
          Set<Instance> set = new HashSet<Instance>();
          for (Instance inst : insts) set.add(inst);
          set.add(this);
          return this.centroid(set);
      }

      @Override
      public Instance centroid(Set<Instance> insts) {
          double attr1Sum = 0, attr2Sum = 0;
          int cnt = 0;
          for (Instance inst : insts)
          {
              if (inst instanceof SimpleInst)
              {
                  SimpleInst sinst = (SimpleInst) inst;
                  attr1Sum += sinst.attr1; attr2Sum += sinst.attr2;
                  cnt++;
              }
          }
          return new SimpleInst(attr1Sum / cnt, attr2Sum / cnt);
      }

      @Override
      public String toString() {
          return String.format("Inst{attr1=%.2g ; attr2=%.2g}", attr1, attr2);
      }

      public static void main(String args[])
      {
          Instance inst1 = new SimpleInst(1, 2);
          Instance inst2 = new SimpleInst(4, 6);
          Instance inst3 = new SimpleInst(4, 4);
          System.out.printf("\t[Info] Distance between inst1 & inst2 = %g!\n", inst1.distance(inst2));
          System.out.printf("\t[Info] Centroid of inst1, inst2, inst3=%s\n", inst1.centroid(inst2, inst3));
      }
  }
Finally, the BallTree class implements the ball tree algorithm described above:
- BallTree.java 
  package john.ibl;

  import java.util.HashSet;
  import java.util.Set;

  import john.ibl.inst.Instance;
  import john.ibl.inst.SimpleInst;

  public class BallTree {
      private Node root = null;             /* root of binary tree */
      public int calcnt = 0;                /* counts how many distance calculations are conducted */
      private boolean bSearchDone = false;  /* when true, nearest neighbor searching is done */
      public BallTree(){}

      public class Node{
          public Node                 ln=null;        /* left child */
          public Node                 rn=null;        /* right child */
          public Set<Instance>        instances;      /* instances inside this ball */
          public Instance             bundy = null;   /* seed point (p1 or p2) this ball was grown from */
          public Instance             cent = null;    /* cached centroid */

          public Node(Instance inst){instances = new HashSet<Instance>(); instances.add(inst);}
          public Node(Set<Instance> insts){this.instances = insts;}
          public Node(Instance bdy, Set<Instance> insts){this.bundy = bdy; this.instances = insts;}

          @Override
          public String toString()
          {
              return String.format("Cent=%s(%d) %s", centroid(), instances.size(), bundy==null?"":String.format("with b=%s", bundy));
          }

          public Instance centroid()
          {
              if(cent!=null) return cent;
              if(instances.size()>1)
              {
                  Instance c = instances.iterator().next().centroid(instances);
                  cent = c;
                  return cent;
              }
              else if(instances.size()==1)
              {
                  Instance c = instances.iterator().next();
                  cent = c;
                  return cent;
              }
              else
              {
                  return null;
              }
          }
      }

      public Node _build(Node node, Instance cent)
      {
          // 0.) If cent is null or the node holds at most one instance, do nothing!
          if(cent==null || node.instances.size()<=1) return node;

          // 1.) Find the instance farthest from the centroid (p1)
          Instance ldInst = null;
          double dist=-1, tmpDist=-1;
          for(Instance inst:node.instances)
          {
              if(ldInst==null) {
                  ldInst = inst;
                  dist = inst.distance(cent);
                  continue;
              } else {
                  tmpDist = inst.distance(cent);
                  if(tmpDist>dist)
                  {
                      ldInst = inst;
                      dist = tmpDist;
                  }
              }
          }

          // 2.) Find the instance farthest from p1 (p2)
          Instance sldInst = null;
          dist = -1; tmpDist=-1;
          for(Instance inst:node.instances)
          {
              if(inst==ldInst) continue;
              else if(sldInst==null){
                  sldInst = inst;
                  dist = inst.distance(ldInst);
              } else {
                  tmpDist = inst.distance(ldInst);
                  if(tmpDist>dist)
                  {
                      sldInst = inst;
                      dist = tmpDist;
                  }
              }
          }

          // 3.) Assign every other instance to whichever of p1 / p2 is closer.
          Set<Instance> ldSet = new HashSet<Instance>(); ldSet.add(ldInst);
          Set<Instance> sldSet = new HashSet<Instance>(); sldSet.add(sldInst);
          dist = -1; tmpDist=-1;
          for(Instance inst:node.instances)
          {
              if(inst==ldInst || inst==sldInst) continue;
              dist = inst.distance(ldInst); tmpDist = inst.distance(sldInst);
              if(dist>=tmpDist){
                  sldSet.add(inst);
              } else {
                  ldSet.add(inst);
              }
          }

          // 4.) Recurse on both subsets.
          Node rn = new Node(ldInst, ldSet); rn = _build(rn, rn.centroid());
          Node ln = new Node(sldInst, sldSet); ln = _build(ln, ln.centroid());
          node.rn = rn; node.ln = ln;

          return node;
      }

      public boolean build(Set<Instance> insts)
      {
          if(insts.size()==0) return false;
          else if(insts.size()==1)
          {
              root = new Node(insts.iterator().next());
              return true;
          }
          else
          {
              root = new Node(insts);
              _build(root, root.centroid());
              return true;
          }
      }

      public void _showTree(Node node)
      {
          System.out.printf("\t[Info] Node=%s(%d):\n", node, node.instances.size());
          System.out.printf("\t\tLeft node : %s\n", node.ln);
          if(node.ln!=null) for(Instance inst:node.ln.instances) System.out.printf("\t\t\t%s\n", inst);
          System.out.printf("\t\tRight node : %s\n", node.rn);
          if(node.rn!=null) for(Instance inst:node.rn.instances) System.out.printf("\t\t\t%s\n", inst);
          if(node.ln!=null) _showTree(node.ln);
          if(node.rn!=null) _showTree(node.rn);
      }

      public void showTree()
      {
          _showTree(root);
      }

      // Note: a child ball's radius is taken as the distance from its centroid to its
      // seed point (bundy); this is not guaranteed to enclose every point of that ball.
      public Instance _getNearestInst(Node node, Instance inst, Instance cur, double pccd)
      {
          if(bSearchDone) return cur;
          if(node.ln!=null && node.rn!=null)
          {
              double ccd = 0;
              Instance cc = node.centroid();
              if(pccd>0)
              {
                  ccd = pccd; calcnt+=2;
              }
              else
              {
                  ccd = cc.distance(inst);
                  calcnt+=3;
              }

              Instance lc = node.ln.centroid(); double lcd = lc.distance(inst);
              Instance rc = node.rn.centroid(); double rcd = rc.distance(inst);
              if(ccd <= lcd && ccd <=rcd)
              {
                  for(Instance tinst:node.instances)
                  {
                      if(cur==null) {cur = tinst; continue;}
                      calcnt++;
                      if(tinst.distance(inst) < cur.distance(inst)) cur = tinst;
                  }
                  bSearchDone = true;
                  return cur;
              }
              else if(lcd < rcd)
              {
                  double radli = node.ln.centroid().distance(node.ln.bundy);
                  double radri = node.rn.centroid().distance(node.rn.bundy);
                  if(lcd<=radli && rcd > radri) {
                      cur =  _getNearestInst(node.ln, inst, cur, lcd);
                      bSearchDone = true;
                      return cur;
                  } else {
                      for(Instance tinst:node.instances)
                      {
                          if(cur==null) {cur = tinst; continue;}
                          calcnt++;
                          if(tinst.distance(inst) < cur.distance(inst)) cur = tinst;
                      }
                      bSearchDone = true;
                      return cur;
                  }
              }
              else
              {
                  double radli = node.ln.centroid().distance(node.ln.bundy);
                  double radri = node.rn.centroid().distance(node.rn.bundy);
                  if(rcd<=radri && lcd > radli) {
                      cur =  _getNearestInst(node.rn, inst, cur, rcd);
                      bSearchDone = true;
                      return cur;
                  } else {
                      for(Instance tinst:node.instances)
                      {
                          if(cur==null) {cur = tinst; continue;}
                          calcnt++;
                          if(tinst.distance(inst) < cur.distance(inst)) cur = tinst;
                      }
                      bSearchDone = true;
                      return cur;
                  }
              }
          }
          else if(node.ln!=null)
          {
              return _getNearestInst(node.ln, inst, cur, -1);
          }
          else if(node.rn!=null)
          {
              return _getNearestInst(node.rn, inst, cur, -1);
          }
          else
          {
              double dist = (cur==null?0:inst.distance(cur));
              for(Instance ti:node.instances)
              {
                  calcnt++;  // Take one calc cnt
                  if(cur==null)
                  {
                      cur = ti;
                      dist = inst.distance(cur);
                  }
                  else if(inst.distance(ti) < dist)
                  {
                      cur = ti;
                      dist = inst.distance(cur);
                  }
              }
              return cur;
          }
      }

      public Instance getNearestInst(Instance inst)
      {
          bSearchDone = false;
          return _getNearestInst(root, inst, null, -1);
      }

      public static void main(String args[])
      {
          Set<Instance> instSet = new HashSet<Instance>();
          instSet.add(new SimpleInst(1, 1));
          instSet.add(new SimpleInst(2, 2));
          instSet.add(new SimpleInst(2, 3));
          instSet.add(new SimpleInst(3, 0));
          instSet.add(new SimpleInst(2, -1));
          instSet.add(new SimpleInst(-3, 1));
          instSet.add(new SimpleInst(-2, 0));
          instSet.add(new SimpleInst(-4, -1));
          instSet.add(new SimpleInst(-5, -2));

          System.out.printf("\t[Info] Total %d points!\n", instSet.size());
          BallTree ballTree = new BallTree();
          ballTree.build(instSet);
          ballTree.showTree();
          Instance tart = new SimpleInst(-4.5, -1);
          System.out.printf("\t[Info] Nearest to target=%s is %s!\n", tart, ballTree.getNearestInst(tart));
          System.out.printf("\t[Info] Total distance cal cnt=%d!\n", ballTree.calcnt);
      }
  }
Below is a diagram of the example computed in BallTree's main method:
This message was edited 16 times. Last update was at 19/06/2012 10:52:53

4 comments:

  1. SimpleInst.java --> public Instance centroid(Set insts) --> for(Instance inst: insts)
    cannot convert to Set, why??

    1. Author's reply: Because of the HTML formatting, some of the Java generics syntax was stripped out (it uses angle brackets). My complete code can be downloaded from the link below:
      https://drive.google.com/file/d/0B3JEkc9JW7BOS2pHRUZ0ZUcyN0U/view?usp=sharing
