4.7 Instance-based learning
這一節在討論 Instance-based learning, 並提出了 k-D tree, Ball tree algorithm 讓我們可以快速計算出 target 與 training set 中最近的距離是哪點, 並透過該點決定 target 的 class. 通常如果我們定義的 Instance 有多個 attributes : a1, a2... an. 則我們的 distance function 最簡單的定義會是 Euclidean distance :
- Finding nearest neighbors efficiently
既然我們的 training instance set 已經 ready, 當有 target instance 進來要求與 training instance set 中最近的一點, 最直覺的方法就是將 training instance set 中的每一點透過 distance function 計算與 target instance 的距離並找出最近的一點. 但是這樣的時間複雜度會是 N (假設 training instance set 的大小為 N). 因此當 training instance set 的大小大達一個程度後, finding nearest neighbors 的計算將會成為一個 bottle neck! 因此我們需要一個有效的方法來找出 target instance 的 nearest neighbors.
- kD-tree
有一種資料結構叫做 kD-tree, 其中 k 指的是 instance 的 attribute 個數. 詳細的部分可以參考書上 161 頁的部分. 基本原理是將 instance 透過 axis 的切割將之以 binary tree 的方式儲存, 底下為一簡單範例 (attribute 個數=k=2):
Figure 4.12 A kD-tree for four training instances: (a) the tree and (b) instances and splits.
這樣我們只要找出 target instance (星狀點) 落於哪個平面, 再以它與該平面的最近一點為半徑畫圓 (根據定義, 一個平面頂多只有一點! 這裡是實心點), 這樣我們只要計算在圓內每一點的 distance 即可找出 target instance 與 training instance set 最近的距離而不用與 training instance set 的每一點都計算距離 :
Figure 4.13 Using a kD-tree to find the nearest neighbor of the star.
- ball tree
雖然 kD-tree 的定義很明確, 但是在實作上會有很多問題. 因此有另一個資料結構叫 ball tree, 它一樣是透過切割 training instances 並將之以 binary tree 的方式儲存起來. 底下是書上對如何形成 ball tree 的 algorithm 描述 :
有看沒有懂? 我簡單敘述如下 :
透過上面的計算我們可以形成一個 binary tree, 底下為一個範例 :
如果我們便可以找出 target instance 落於哪個圓, 並將 target instance 與圓內最遠一點為半徑畫出一個圓. 並且只需要計算與該圓內的每一點的距離即可 :
Figure 4.15 Ruling out an entire ball (gray) based on a target point (star) and its current nearest neighbor.
這邊的範例代碼我們首先定義 Instance 的介面 :
- Instance.java
接著我們實做一個只有兩個 attributes 的 Instance 類別 SimpleInst :
- SimpleInst.java :
最後就是實作剛剛說明 ball tree algorithm 的類別 BallTree :
- BallTree.java
底下為 BallTree 的 main 函數中計算的範例示意圖 :
這一節在討論 Instance-based learning, 並提出了 k-D tree, Ball tree algorithm 讓我們可以快速計算出 target 與 training set 中最近的距離是哪點, 並透過該點決定 target 的 class. 通常如果我們定義的 Instance 有多個 attributes : a1, a2... an. 則我們的 distance function 最簡單的定義會是 Euclidean distance :
- Finding nearest neighbors efficiently
既然我們的 training instance set 已經 ready, 當有 target instance 進來要求與 training instance set 中最近的一點, 最直覺的方法就是將 training instance set 中的每一點透過 distance function 計算與 target instance 的距離並找出最近的一點. 但是這樣的時間複雜度會是 N (假設 training instance set 的大小為 N). 因此當 training instance set 的大小大達一個程度後, finding nearest neighbors 的計算將會成為一個 bottle neck! 因此我們需要一個有效的方法來找出 target instance 的 nearest neighbors.
- kD-tree
有一種資料結構叫做 kD-tree, 其中 k 指的是 instance 的 attribute 個數. 詳細的部分可以參考書上 161 頁的部分. 基本原理是將 instance 透過 axis 的切割將之以 binary tree 的方式儲存, 底下為一簡單範例 (attribute 個數=k=2):
Figure 4.12 A kD-tree for four training instances: (a) the tree and (b) instances and splits.
這樣我們只要找出 target instance (星狀點) 落於哪個平面, 再以它與該平面的最近一點為半徑畫圓 (根據定義, 一個平面頂多只有一點! 這裡是實心點), 這樣我們只要計算在圓內每一點的 distance 即可找出 target instance 與 training instance set 最近的距離而不用與 training instance set 的每一點都計算距離 :
Figure 4.13 Using a kD-tree to find the nearest neighbor of the star.
- ball tree
雖然 kD-tree 的定義很明確, 但是在實作上會有很多問題. 因此有另一個資料結構叫 ball tree, 它一樣是透過切割 training instances 並將之以 binary tree 的方式儲存起來. 底下是書上對如何形成 ball tree 的 algorithm 描述 :
有看沒有懂? 我簡單敘述如下 :
透過上面的計算我們可以形成一個 binary tree, 底下為一個範例 :
如果我們便可以找出 target instance 落於哪個圓, 並將 target instance 與圓內最遠一點為半徑畫出一個圓. 並且只需要計算與該圓內的每一點的距離即可 :
Figure 4.15 Ruling out an entire ball (gray) based on a target point (star) and its current nearest neighbor.
這邊的範例代碼我們首先定義 Instance 的介面 :
- Instance.java
- package john.ibl.inst;
- import java.util.Set;
- public interface Instance {
- // 計算與 inst 的 distance.
- public double distance(Instance inst);
- // 計算 set 中所有 instances 形成的 centroid
- public Instance centroid(Set
insts); - // 計算所有傳入 instances 形成的 centroid
- public Instance centroid(Instance ...insts);
- }
- SimpleInst.java :
- package john.ibl.inst;
- import java.util.HashSet;
- import java.util.Set;
- public class SimpleInst implements Instance{
- public double attr1 = 0;
- public double attr2 = 0;
- public SimpleInst(){}
- public SimpleInst(double at1, double at2){this.attr1 = at1; this.attr2 = at2;}
- @Override
- public double distance(Instance inst) {
- if(inst instanceof SimpleInst)
- {
- SimpleInst sinst = (SimpleInst)inst;
- return Math.sqrt(Math.pow(Math.abs(this.attr1-sinst.attr1), 2)+
- Math.pow(Math.abs(this.attr2-sinst.attr2), 2));
- }
- return -1;
- }
- @Override
- public Instance centroid(Instance... insts) {
- Set
set = new HashSet (); - for(Instance inst:insts) set.add(inst);
- set.add(this);
- return this.centroid(set);
- }
- @Override
- public Instance centroid(Set
insts) { - double attr1Sum = 0; double attr2Sum = 0; int cnt = 0;
- for(Instance inst:insts)
- {
- if(inst instanceof SimpleInst)
- {
- SimpleInst sinst = (SimpleInst)inst;
- attr1Sum+=sinst.attr1; attr2Sum+=sinst.attr2;
- cnt++;
- }
- }
- return new SimpleInst(attr1Sum/cnt, attr2Sum/cnt);
- }
- @Override
- public String toString(){
- return String.format("Inst{attr1=%.2g ; attr2=%.2g}", attr1, attr2);
- }
- public static void main(String args[])
- {
- Instance inst1 = new SimpleInst(1, 2);
- Instance inst2 = new SimpleInst(4, 6);
- Instance inst3 = new SimpleInst(4, 4);
- System.out.printf("\t[Info] Distance between inst1 & inst2 = %g!\n", inst1.distance(inst2));
- System.out.printf("\t[Info] Centroid of inst1, inst2, inst3=%s\n", inst1.centroid(inst2, inst3));
- }
- }
- BallTree.java
- package john.ibl;
- import java.util.HashSet;
- import java.util.Set;
- import john.ibl.inst.Instance;
- import john.ibl.inst.SimpleInst;
- public class BallTree {
- private Node root = null; /*root of binary tree*/
- public int calcnt = 0; /*count how many distance calculation being conducted.*/
- private boolean bSearchDone = false; /*When true, nearest neighbor searching is done.*/
- public BallTree(){}
- public class Node{
- public Node ln=null;
- public Node rn=null;
- public Set
instances; - public Instance bundy = null;
- public Instance cent = null;
- public Node(Instance inst){instances = new HashSet
(); instances.add(inst);} - public Node(Set
insts){ this.instances = insts;} - public Node(Instance bdy, Set
insts){ this.bundy = bdy; this.instances = insts;} - @Override
- public String toString()
- {
- return String.format("Cent=%s(%d) %s", centroid(), instances.size(), bundy==null?"":String.format("with b=%s", bundy));
- }
- public Instance centroid()
- {
- if(cent!=null) return cent;
- if(instances.size()>1)
- {
- Instance c = instances.iterator().next().centroid(instances);
- cent = c;
- return cent;
- }
- else if(instances.size()==1)
- {
- Instance c = instances.iterator().next();
- cent = c;
- return cent;
- }
- else
- {
- return null;
- }
- }
- }
- public Node _build(Node node, Instance cent)
- {
- // 0.) If cent is null, do nothing!
- if(cent==null || node.instances.size()<=1) return node;
- // 1.) Find longest distance instance
- Instance ldInst = null;
- double dist=-1, tmpDist=-1;
- for(Instance inst:node.instances)
- {
- if(ldInst==null) {
- ldInst = inst;
- dist = inst.distance(cent);
- continue;
- } else {
- tmpDist = inst.distance(cent);
- if(tmpDist>dist)
- {
- ldInst = inst;
- dist = tmpDist;
- }
- }
- }
- // 2.) Find secondly longest distance instance
- Instance sldInst = null;
- dist = -1; tmpDist=-1;
- for(Instance inst:node.instances)
- {
- if(inst==ldInst) continue;
- else if(sldInst==null){
- sldInst = inst;
- dist = inst.distance(ldInst);
- } else {
- tmpDist = inst.distance(ldInst);
- if(tmpDist>dist)
- {
- sldInst = inst;
- dist = tmpDist;
- }
- }
- }
- // 3.) Distribute all others point to upper two instances.
- Set
ldSet = new HashSet (); ldSet.add(ldInst); - Set
sldSet = new HashSet (); sldSet.add(sldInst); - dist = -1; tmpDist=-1;
- for(Instance inst:node.instances)
- {
- if(inst==ldInst || inst==sldInst) continue;
- dist = inst.distance(ldInst); tmpDist = inst.distance(sldInst);
- if(dist>=tmpDist){
- sldSet.add(inst);
- } else {
- ldSet.add(inst);
- }
- }
- Node rn = new Node(ldInst, ldSet); rn = _build(rn, rn.centroid());
- Node ln = new Node(sldInst, sldSet); ln = _build(ln, ln.centroid());
- node.rn = rn; node.ln = ln;
- return node;
- }
- public boolean build(Set
insts) - {
- if(insts.size()==0) return false;
- else if(insts.size()==1)
- {
- root = new Node(insts.iterator().next());
- return true;
- }
- else
- {
- root = new Node(insts);
- _build(root, root.centroid());
- return true;
- }
- }
- public void _showTree(Node node)
- {
- System.out.printf("\t[Info] Node=%s(%d):\n", node, node.instances.size());
- System.out.printf("\t\tLeft node : %s\n", node.ln);
- if(node.ln!=null) for(Instance inst:node.ln.instances) System.out.printf("\t\t\t%s\n", inst);
- System.out.printf("\t\tRight node : %s\n", node.rn);
- if(node.rn!=null) for(Instance inst:node.rn.instances) System.out.printf("\t\t\t%s\n", inst);
- if(node.ln!=null) _showTree(node.ln);
- if(node.rn!=null) _showTree(node.rn);
- }
- public void showTree()
- {
- _showTree(root);
- }
- public Instance _getNearestInst(Node node, Instance inst, Instance cur, double pccd)
- {
- if(bSearchDone) return cur;
- if(node.ln!=null && node.rn!=null)
- {
- double ccd = 0;
- Instance cc = node.centroid();
- if(pccd>0)
- {
- ccd = pccd; calcnt+=2;
- }
- else
- {
- ccd = cc.distance(inst);
- calcnt+=3;
- }
- Instance lc = node.ln.centroid(); double lcd = lc.distance(inst);
- Instance rc = node.rn.centroid(); double rcd = rc.distance(inst);
- if(ccd <= lcd && ccd <=rcd)
- {
- for(Instance tinst:node.instances)
- {
- if(cur==null) {cur = tinst; continue;}
- calcnt++;
- if(tinst.distance(inst) < cur.distance(inst)) cur = tinst;
- }
- bSearchDone = true;
- return cur;
- }
- else if(lcd < rcd)
- {
- double radli = node.ln.centroid().distance(node.ln.bundy);
- double radri = node.rn.centroid().distance(node.rn.bundy);
- if(lcd<=radli && rcd > radri) {
- cur = _getNearestInst(node.ln, inst, cur, lcd);
- bSearchDone = true;
- return cur;
- } else {
- for(Instance tinst:node.instances)
- {
- if(cur==null) {cur = tinst; continue;}
- calcnt++;
- if(tinst.distance(inst) < cur.distance(inst)) cur = tinst;
- }
- bSearchDone = true;
- return cur;
- }
- }
- else
- {
- double radli = node.ln.centroid().distance(node.ln.bundy);
- double radri = node.rn.centroid().distance(node.rn.bundy);
- if(rcd<=radri && lcd > radli) {
- //System.out.printf("\t[Test] Bingo!\n");
- cur = _getNearestInst(node.rn, inst, cur, rcd);
- bSearchDone = true;
- return cur;
- } else {
- for(Instance tinst:node.instances)
- {
- if(cur==null) {cur = tinst; continue;}
- calcnt++;
- if(tinst.distance(inst) < cur.distance(inst)) cur = tinst;
- }
- bSearchDone = true;
- return cur;
- }
- }
- }
- else if(node.ln!=null)
- {
- return _getNearestInst(node.ln, inst, cur, -1);
- }
- else if(node.rn!=null)
- {
- return _getNearestInst(node.rn, inst, cur, -1);
- }
- else
- {
- double dist = (cur==null?0:inst.distance(cur));
- for(Instance ti:node.instances)
- {
- calcnt++; // Take one calc cnt
- if(cur==null)
- {
- cur = ti;
- dist = inst.distance(cur);
- }
- else if(inst.distance(ti) < dist)
- {
- cur = ti;
- dist = inst.distance(cur);
- }
- }
- return cur;
- }
- }
- public Instance getNearestInst(Instance inst)
- {
- bSearchDone = false;
- return _getNearestInst(root, inst, null, -1);
- }
- public static void main(String args[])
- {
- Set
instSet = new HashSet (); - instSet.add(new SimpleInst(1, 1));
- instSet.add(new SimpleInst(2, 2));
- instSet.add(new SimpleInst(2, 3));
- instSet.add(new SimpleInst(3, 0));
- instSet.add(new SimpleInst(2, -1));
- instSet.add(new SimpleInst(-3, 1));
- instSet.add(new SimpleInst(-2, 0));
- instSet.add(new SimpleInst(-4, -1));
- instSet.add(new SimpleInst(-5, -2));
- System.out.printf("\t[Info] Total %d points!\n", instSet.size());
- BallTree ballTree = new BallTree();
- ballTree.build(instSet);
- ballTree.showTree();
- Instance tart = new SimpleInst(-4.5, -1);
- System.out.printf("\t[Info] Nearest to target=%s is %s!\n", tart, ballTree.getNearestInst(tart));
- System.out.printf("\t[Info] Total distance cal cnt=%d!\n", ballTree.calcnt);
- }
- }
This message was edited 16 times. Last update was at 19/06/2012 10:52:53
SimpleInst.java -->public Instance centroid(Set insts) -->for(Instance inst: insts )
回覆刪除can not convert to set , why??
作者已經移除這則留言。
刪除因為 HTML 語法的關係, 某些 Java Generic 的語法被拿掉 (因為角括號), 我完整代碼可以在下面連結下載:
刪除https://drive.google.com/file/d/0B3JEkc9JW7BOS2pHRUZ0ZUcyN0U/view?usp=sharing
thanks so much~
刪除