## 2012年6月18日 星期一

### [ MLP ] Machine learning tools and techniques : Ch04 Algorithms ( The Basic Methods ) - Part4

4.7 Instance-based learning

- Finding nearest neighbors efficiently

- kD-tree

Figure 4.12 A kD-tree for four training instances: (a) the tree and (b) instances and splits.

Figure 4.13 Using a kD-tree to find the nearest neighbor of the star.

- ball tree

Choose the point in the ball that is farthest from its center, and then a second point that is farthest from the first one. Assign all data points in the ball to the closest one of these two cluster centers, then compute the centroid of each cluster and the minimum radius required for it to enclose all the data points it represents. This method has the merit that the cost of splitting a ball containing n points is only linear in n. There are more elaborate algorithms that produce tighter balls, but they require more computation.

1. 計算 (training instance) set 的 centroid 並找出 set 中離該 centroid 最遠的一點. 稱為 p1
2. 找出 set 中離第一步計算出來的點最遠的點. 稱為 p2
3. 將 set 其它點與 step1, step2 計算出來的兩點進行 clustering (計算每一點離p1, p2 哪點近, 如果離 p1 近則將該點與 p1 形成一個新的 set, 否則與 p2 形成一個 set)
4. 將第三步計算出來的 set 再重複 step1,2,3 直到該 set 只剩下一個點即結束.

Figure 4.15 Ruling out an entire ball (gray) based on a target point (star) and its current nearest neighbor.

- Instance.java
package john.ibl.inst;

import java.util.Set;

/**
 * An instance (training example or query point) that the instance-based learning code in this
 * post can measure distances between and average into centroids.
 */
public interface Instance {
    /**
     * Computes the distance between this instance and {@code inst}.
     *
     * <p>NOTE(review): SimpleInst returns -1 when {@code inst} is not a SimpleInst, so callers
     * should not assume the result is always non-negative.
     *
     * @param inst the other instance
     * @return the distance between the two instances
     */
    public double distance(Instance inst);

    /**
     * Computes the centroid of all instances in {@code insts}.
     *
     * @param insts the instances to average
     * @return the centroid; the result for an empty set is implementation-defined
     */
    public Instance centroid(Set<Instance> insts);

    /**
     * Computes the centroid of the given instances.
     *
     * <p>Whether this instance itself is included in the average is implementation-defined.
     *
     * @param insts the instances to average
     * @return the centroid
     */
    public Instance centroid(Instance... insts);
}

- SimpleInst.java :
1. package john.ibl.inst;
2.
3. import java.util.HashSet;
4. import java.util.Set;
5.
6. public class SimpleInst implements Instance{
7.     public double attr1 = 0;
8.     public double attr2 = 0;
9.
10.     public SimpleInst(){}
11.     public SimpleInst(double at1, double at2){this.attr1 = at1; this.attr2 = at2;}
12.
13.     @Override
14.     public double distance(Instance inst) {
15.         if(inst instanceof SimpleInst)
16.         {
17.             SimpleInst sinst = (SimpleInst)inst;
18.
19.             return Math.sqrt(Math.pow(Math.abs(this.attr1-sinst.attr1), 2)+
20.                              Math.pow(Math.abs(this.attr2-sinst.attr2), 2));
21.         }
22.         return -1;
23.     }
24.
25.     @Override
26.     public Instance centroid(Instance... insts) {
27.         Set set = new HashSet();
30.         return this.centroid(set);
31.     }
32.
33.     @Override
34.     public Instance centroid(Set insts) {
35.         double attr1Sum = 0double attr2Sum = 0int cnt = 0;
36.         for(Instance inst:insts)
37.         {
38.             if(inst instanceof SimpleInst)
39.             {
40.                 SimpleInst sinst = (SimpleInst)inst;
41.                 attr1Sum+=sinst.attr1; attr2Sum+=sinst.attr2;
42.                 cnt++;
43.             }
44.         }
45.         return new SimpleInst(attr1Sum/cnt, attr2Sum/cnt);
46.     }
47.
48.     @Override
49.     public String toString(){
50.         return String.format("Inst{attr1=%.2g ; attr2=%.2g}", attr1, attr2);
51.     }
52.
53.     public static void main(String args[])
54.     {
55.         Instance inst1 = new SimpleInst(12);
56.         Instance inst2 = new SimpleInst(46);
57.         Instance inst3 = new SimpleInst(44);
58.         System.out.printf("\t[Info] Distance between inst1 & inst2 = %g!\n", inst1.distance(inst2));
59.         System.out.printf("\t[Info] Centroid of inst1, inst2, inst3=%s\n", inst1.centroid(inst2, inst3));
60.     }
61. }

- BallTree.java
1. package john.ibl;
2.
3. import java.util.HashSet;
4. import java.util.Set;
5.
6. import john.ibl.inst.Instance;
7. import john.ibl.inst.SimpleInst;
8.
9. public class BallTree {
10.     private Node root = null;  /*root of binary tree*/
11.     public int calcnt = 0;  /*count how many distance calculation being conducted.*/
12.     private boolean bSearchDone = false/*When true, nearest neighbor searching is done.*/
13.     public BallTree(){}
14.
15.     public class Node{
16.         public Node                 ln=null;
17.         public Node                 rn=null;
18.         public Set      instances;
19.         public Instance             bundy = null;
20.         public Instance             cent = null;
21.
22.         public Node(Instance inst){instances = new HashSet(); instances.add(inst);}
23.         public Node(Set insts){this.instances = insts;}
24.         public Node(Instance bdy, Set insts){this.bundy = bdy; this.instances = insts;}
25.
26.         @Override
27.         public String toString()
28.         {
29.             return String.format("Cent=%s(%d) %s", centroid(), instances.size(), bundy==null?"":String.format("with b=%s", bundy));
30.         }
31.
32.         public Instance centroid()
33.         {
34.             if(cent!=nullreturn cent;
35.             if(instances.size()>1)
36.             {
37.                 Instance c =  instances.iterator().next().centroid(instances);
38.                 cent = c;
39.                 return cent;
40.             }
41.             else if(instances.size()==1)
42.             {
43.                 Instance c = instances.iterator().next();
44.                 cent = c;
45.                 return cent;
46.             }
47.             else
48.             {
49.                 return null;
50.             }
51.         }
52.     }
53.
54.     public Node _build(Node node, Instance cent)
55.     {
56.         // 0.) If cent is null, do nothing!
57.         if(cent==null || node.instances.size()<=1return node;
58.
59.         // 1.) Find longest distance instance
60.         Instance ldInst = null;
61.         double dist=-1, tmpDist=-1;
62.         for(Instance inst:node.instances)
63.         {
64.             if(ldInst==null) {
65.                 ldInst = inst;
66.                 dist = inst.distance(cent);
67.                 continue;
68.             } else {
69.                 tmpDist = inst.distance(cent);
70.                 if(tmpDist>dist)
71.                 {
72.                     ldInst = inst;
73.                     dist = tmpDist;
74.                 }
75.             }
76.         }
77.
78.         // 2.) Find secondly longest distance instance
79.         Instance sldInst = null;
80.         dist = -1; tmpDist=-1;
81.         for(Instance inst:node.instances)
82.         {
83.             if(inst==ldInst) continue;
84.             else if(sldInst==null){
85.                 sldInst = inst;
86.                 dist = inst.distance(ldInst);
87.             } else {
88.                 tmpDist = inst.distance(ldInst);
89.                 if(tmpDist>dist)
90.                 {
91.                     sldInst = inst;
92.                     dist = tmpDist;
93.                 }
94.             }
95.         }
96.
97.         // 3.) Distribute all others point to upper two instances.
98.         Set ldSet = new HashSet(); ldSet.add(ldInst);
99.         Set sldSet = new HashSet(); sldSet.add(sldInst);
100.         dist = -1; tmpDist=-1;
101.         for(Instance inst:node.instances)
102.         {
103.             if(inst==ldInst || inst==sldInst) continue;
104.             dist = inst.distance(ldInst); tmpDist = inst.distance(sldInst);
105.             if(dist>=tmpDist){
107.             } else {
109.             }
110.         }
111.
112.         Node rn = new Node(ldInst, ldSet); rn = _build(rn, rn.centroid());
113.         Node ln = new Node(sldInst, sldSet); ln = _build(ln, ln.centroid());
114.         node.rn = rn; node.ln = ln;
115.
116.         return node;
117.     }
118.
119.     public boolean build(Set insts)
120.     {
121.         if(insts.size()==0return false;
122.         else if(insts.size()==1)
123.         {
124.             root = new Node(insts.iterator().next());
125.             return true;
126.         }
127.         else
128.         {
129.             root = new Node(insts);
130.             _build(root, root.centroid());
131.             return true;
132.         }
133.     }
134.
135.     public void _showTree(Node node)
136.     {
137.         System.out.printf("\t[Info] Node=%s(%d):\n", node, node.instances.size());
138.         System.out.printf("\t\tLeft node : %s\n", node.ln);
139.         if(node.ln!=nullfor(Instance inst:node.ln.instances) System.out.printf("\t\t\t%s\n", inst);
140.         System.out.printf("\t\tRight node : %s\n", node.rn);
141.         if(node.rn!=nullfor(Instance inst:node.rn.instances) System.out.printf("\t\t\t%s\n", inst);
142.         if(node.ln!=null) _showTree(node.ln);
143.         if(node.rn!=null) _showTree(node.rn);
144.     }
145.
146.     public void showTree()
147.     {
148.         _showTree(root);
149.     }
150.
151.     public Instance _getNearestInst(Node node, Instance inst, Instance cur, double pccd)
152.     {
153.         if(bSearchDone) return cur;
154.         if(node.ln!=null && node.rn!=null)
155.         {
156.             double ccd = 0;
157.             Instance cc = node.centroid();
158.             if(pccd>0)
159.             {
160.                 ccd = pccd; calcnt+=2;
161.             }
162.             else
163.             {
164.                 ccd = cc.distance(inst);
165.                 calcnt+=3;
166.             }
167.
168.             Instance lc = node.ln.centroid(); double lcd = lc.distance(inst);
169.             Instance rc = node.rn.centroid(); double rcd = rc.distance(inst);
170.             if(ccd <= lcd && ccd <=rcd)
171.             {
172.                 for(Instance tinst:node.instances)
173.                 {
174.                     if(cur==null) {cur = tinst; continue;}
175.                     calcnt++;
176.                     if(tinst.distance(inst) < cur.distance(inst)) cur = tinst;
177.                 }
178.                 bSearchDone = true;
179.                 return cur;
180.             }
181.             else if(lcd < rcd)
182.             {
186.                     cur =  _getNearestInst(node.ln, inst, cur, lcd);
187.                     bSearchDone = true;
188.                     return cur;
189.                 } else {
190.                     for(Instance tinst:node.instances)
191.                     {
192.                         if(cur==null) {cur = tinst; continue;}
193.                         calcnt++;
194.                         if(tinst.distance(inst) < cur.distance(inst)) cur = tinst;
195.                     }
196.                     bSearchDone = true;
197.                     return cur;
198.                 }
199.             }
200.             else
201.             {
205.                     //System.out.printf("\t[Test] Bingo!\n");
206.                     cur =  _getNearestInst(node.rn, inst, cur, rcd);
207.                     bSearchDone = true;
208.                     return cur;
209.                 } else {
210.                     for(Instance tinst:node.instances)
211.                     {
212.                         if(cur==null) {cur = tinst; continue;}
213.                         calcnt++;
214.                         if(tinst.distance(inst) < cur.distance(inst)) cur = tinst;
215.                     }
216.                     bSearchDone = true;
217.                     return cur;
218.                 }
219.             }
220.         }
221.         else if(node.ln!=null)
222.         {
223.             return _getNearestInst(node.ln, inst, cur, -1);
224.         }
225.         else if(node.rn!=null)
226.         {
227.             return _getNearestInst(node.rn, inst, cur, -1);
228.         }
229.         else
230.         {
231.             double dist = (cur==null?0:inst.distance(cur));
232.             for(Instance ti:node.instances)
233.             {
234.                 calcnt++;  // Take one calc cnt
235.                 if(cur==null)
236.                 {
237.                     cur = ti;
238.                     dist = inst.distance(cur);
239.                 }
240.                 else if(inst.distance(ti) < dist)
241.                 {
242.                     cur = ti;
243.                     dist = inst.distance(cur);
244.                 }
245.             }
246.             return cur;
247.         }
248.     }
249.
250.     public Instance getNearestInst(Instance inst)
251.     {
252.         bSearchDone = false;
253.         return _getNearestInst(root, inst, null, -1);
254.     }
255.
256.     public static void main(String args[])
257.     {
258.         Set instSet = new HashSet();
268.
269.         System.out.printf("\t[Info] Total %d points!\n", instSet.size());
270.         BallTree ballTree = new BallTree();
271.         ballTree.build(instSet);
272.         ballTree.showTree();
273.         Instance tart = new SimpleInst(-4.5, -1);
274.         System.out.printf("\t[Info] Nearest to target=%s is %s!\n", tart, ballTree.getNearestInst(tart));
275.         System.out.printf("\t[Info] Total distance cal cnt=%d!\n", ballTree.calcnt);
276.     }
277. }

This message was edited 16 times. Last update was at 19/06/2012 10:52:53

#### 4 則留言:

1. SimpleInst.java -->public Instance centroid(Set insts) -->for(Instance inst: insts )
can not convert to set , why??

1. 作者已經移除這則留言。

2. 因為 HTML 語法的關係, 某些 Java Generic 的語法被拿掉 (因為角括號), 我完整代碼可以在下面連結下載: