程式扎記: [ MLP ] Ch04 Algorithms - Clustering : K-Means


Friday, June 22, 2012


Introduction to Clustering :
Section 4.8 of the book introduces clustering, an unsupervised learning technique that partitions a data set into subsets such that, by definition, the data within one subset are more similar to each other than to the data in other subsets. The definition of "similar" therefore determines the clustering algorithm. Another issue is that the number of clusters usually has to be decided in advance, and that choice often affects the quality of the clustering. First, the definition from Wikipedia :
Cluster analysis is the assignment of a set of observations into subsets (called clusters) so that observations in the same cluster are similar in some sense, while observations in different clusters are dissimilar. The variety of clustering techniques make different assumptions on the structure of the data, often defined by some similarity metric and evaluated for example by internal compactness (similarity between members of the same cluster) and separation between different clusters. Other methods are based on estimated density and graph connectivity. Clustering is a method of unsupervised learning, and a common technique for statistical data analysis.

More on clustering analysis can be found in the Wikipedia article on Cluster analysis.

K-Means Clustering : 
For more advanced clustering analysis you can skip ahead to Chapter 6 of the book :
Chapter 6 we examine newer clustering methods that perform incremental and probabilistic clustering.

Introduced here is a classic algorithm that has been in use since 1967 - k-means clustering :
In data mining, k-means clustering is a method of cluster analysis which aims to partition n observations into k clusters in which each observation belongs to the cluster with the nearest mean. This results in a partitioning of the data space into Voronoi cells.

This algorithm requires choosing the number of clusters (k) in advance; the procedure is illustrated by the figure below :
(figure: k-means iteration diagram)

The steps of the algorithm above are, briefly :
1. Choose the number of clusters and set the corresponding number of centroids.
2. For each data point in the dataset, use the defined similarity/distance function to find the most similar/nearest cluster and add the point to that group.
3. Once every data point in the dataset has been assigned to a cluster, recompute each cluster's centroid from the data in its group.
4. Repeat steps 2-3 until no cluster's centroid (or group membership) changes any more, then stop.
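Before the fuller implementation below, here is a minimal, self-contained sketch of these four steps, using one-dimensional data and plain arrays (names such as KMeans1D are my own, not from the book):

```java
import java.util.Arrays;

public class KMeans1D {
    // Runs steps 2-4 on 1-D data until the centroids stop moving; returns the final centroids.
    public static double[] kmeans(double[] data, double[] centroids) {
        centroids = centroids.clone();
        while (true) {
            // Step 2: assign each point to its nearest centroid (accumulate sums per group).
            double[] sum = new double[centroids.length];
            int[] count = new int[centroids.length];
            for (double x : data) {
                int nearest = 0;
                for (int j = 1; j < centroids.length; j++)
                    if (Math.abs(x - centroids[j]) < Math.abs(x - centroids[nearest])) nearest = j;
                sum[nearest] += x;
                count[nearest]++;
            }
            // Step 3: recompute each centroid as the mean of its group.
            boolean moved = false;
            for (int j = 0; j < centroids.length; j++) {
                if (count[j] == 0) continue;  // keep an empty cluster's centroid as-is
                double mean = sum[j] / count[j];
                if (mean != centroids[j]) { centroids[j] = mean; moved = true; }
            }
            // Step 4: stop once no centroid moved.
            if (!moved) return centroids;
        }
    }

    public static void main(String[] args) {
        double[] data = {1.0, 2.0, 9.0, 10.0};
        // Step 1: k = 2, with initial centroids 1.0 and 10.0.
        double[] result = kmeans(data, new double[]{1.0, 10.0});
        System.out.println(Arrays.toString(result));  // prints [1.5, 9.5]
    }
}
```

With initial centroids 1.0 and 10.0, the first pass groups {1, 2} and {9, 10}, moving the centroids to 1.5 and 9.5; the second pass produces the same grouping, so the centroids are stable and the loop terminates.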

Simple K-Means Implementation : 
For the K-Means described above, below is a simple Java implementation whose execution shows how K-Means actually works. The interface Instance represents our training data, and the class SimpleInst, which implements that interface, is the data we feed into K-Means :
- Instance.java : 
package john.ibl.inst;

import java.util.Set;

public interface Instance {
    // Distance between this instance and inst.
    public double distance(Instance inst);

    // Centroid formed by all instances in the given set.
    public Instance centroid(Set<Instance> insts);

    // Centroid formed by this instance together with the given instances.
    public Instance centroid(Instance... insts);

    public boolean isSamePoint(Instance inst);
}
- SimpleInst.java : has two attributes attr1, attr2 (two-dimensional) :
package john.ibl.inst;

import java.util.HashSet;
import java.util.Set;

public class SimpleInst implements Instance {
    public double attr1 = 0;
    public double attr2 = 0;

    public SimpleInst(){}
    public SimpleInst(double at1, double at2){this.attr1 = at1; this.attr2 = at2;}

    @Override
    public double distance(Instance inst) {
        if(inst instanceof SimpleInst)
        {
            SimpleInst sinst = (SimpleInst)inst;
            // Euclidean distance in two dimensions.
            return Math.sqrt(Math.pow(this.attr1 - sinst.attr1, 2) +
                             Math.pow(this.attr2 - sinst.attr2, 2));
        }
        return -1;
    }

    @Override
    public Instance centroid(Instance... insts) {
        Set<Instance> set = new HashSet<Instance>();
        for(Instance inst : insts) set.add(inst);
        set.add(this);
        return this.centroid(set);
    }

    @Override
    public Instance centroid(Set<Instance> insts) {
        double attr1Sum = 0, attr2Sum = 0;
        int cnt = 0;
        for(Instance inst : insts)
        {
            if(inst instanceof SimpleInst)
            {
                SimpleInst sinst = (SimpleInst)inst;
                attr1Sum += sinst.attr1; attr2Sum += sinst.attr2;
                cnt++;
            }
        }
        if(cnt == 0) return this;  // Empty set: keep the current point to avoid division by zero.
        return new SimpleInst(attr1Sum / cnt, attr2Sum / cnt);
    }

    @Override
    public String toString(){
        return String.format("Inst{attr1=%.2g ; attr2=%.2g}", attr1, attr2);
    }

    @Override
    public boolean isSamePoint(Instance inst) {
        if(inst instanceof SimpleInst)
        {
            SimpleInst tInst = (SimpleInst)inst;
            if(tInst.attr1 == this.attr1 && tInst.attr2 == this.attr2) return true;
        }
        return false;
    }

    public static void main(String args[])
    {
        Instance inst1 = new SimpleInst(1, 2);
        Instance inst2 = new SimpleInst(4, 6);
        Instance inst3 = new SimpleInst(4, 4);
        System.out.printf("\t[Info] Distance between inst1 & inst2 = %g!\n", inst1.distance(inst2));
        System.out.printf("\t[Info] Centroid of inst1, inst2, inst3=%s\n", inst1.centroid(inst2, inst3));
    }
}
Next is the class KMeans, which implements the K-Means algorithm. In the constructor you choose the number of clusters and pass in their initial centroids. Training is then done by calling the following method :
public void processUntilCentrStable(int max_loop, Set<Instance> instGrp)
Arguments :
- int max_loop : maximum number of loops to execute.
- Set<Instance> instGrp : the training dataset.

The complete code is as follows :
- KMeans.java : 
package john.cluster.hierarchical;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import john.ibl.inst.Instance;
import john.ibl.inst.SimpleInst;

public class KMeans {
    public int k = 2;
    public List<Cluster> clusters = null;

    public KMeans(int k, Instance... insts){
        if(k != insts.length)
        {
            System.out.printf("\t[Error] K(%d) doesn't match input instances(%d)!\n", k, insts.length);
            return;
        }
        this.k = k;
        clusters = new LinkedList<Cluster>();
        for(int i = 0; i < k; i++)
        {
            clusters.add(new Cluster(insts[i]));
        }
    }

    public void processUntilCentrStable(int max_loop, Set<Instance> instGrp)
    {
        int loop = 0;
        while(true)
        {
            loop++;
            // Step 2: assign each instance to the nearest cluster.
            for(Instance inst : instGrp)
            {
                int cp = 0; double dist = -1;
                for(int j = 0; j < clusters.size(); j++)
                {
                    Cluster c = clusters.get(j);
                    double tdist = c.dist(inst);
                    if(dist == -1 || tdist < dist)
                    {
                        cp = j;
                        dist = tdist;
                    }
                }
                clusters.get(cp).addInst(inst);
            }
            // Step 3: recompute centroids; stable only if none of them moves.
            boolean isStable = true;
            int cnt = 0;
            for(int i = 0; i < clusters.size(); i++)
            {
                Cluster c = clusters.get(i);
                System.out.printf("\t[Info] Centr=%s :\n", c.centr);
                for(Instance inst : c.groups) System.out.printf("\t\t%s\n", inst);
                if(!c.recalcCentr(false))
                {
                    isStable = false;
                    cnt = i;
                    break;
                }
            }
            if(isStable) break;
            else
            {
                for(int i = 0; i < clusters.size(); i++)
                {
                    if(i <= cnt) clusters.get(i).resetGroup();
                    else
                    {
                        Cluster c = clusters.get(i);
                        System.out.printf("\t[Info] Centr=%s :\n", c.centr);
                        for(Instance inst : c.groups) System.out.printf("\t\t%s\n", inst);
                        c.recalcCentr(true);
                    }
                }
            }
            System.out.printf("===========================================================\n");
            if(loop == max_loop) break;
        }
        System.out.printf("\t[Info] Done! Total loop=%d\n", loop);
    }

    public void process(int loop, Set<Instance> instGrp)
    {
        for(int i = 1; i <= loop; i++)
        {
            for(Instance inst : instGrp)
            {
                int cp = 0; double dist = -1;
                for(int j = 0; j < clusters.size(); j++)
                {
                    Cluster c = clusters.get(j);
                    double tdist = c.dist(inst);
                    if(dist == -1 || tdist < dist)
                    {
                        cp = j;
                        dist = tdist;
                    }
                }
                clusters.get(cp).addInst(inst);
            }

            for(Cluster c : clusters)
            {
                System.out.printf("\t[Info] Centr=%s :\n", c.centr);
                for(Instance inst : c.groups) System.out.printf("\t\t%s\n", inst);
                if(i != loop) c.recalcCentr(true);  // Keep the final grouping on the last loop.
            }
            if(i != loop) System.out.printf("===========================================================\n");
        }
    }

    public class Cluster
    {
        public Instance centr = null;
        public Set<Instance> groups = null;
        public Cluster(Instance initCentroid)
        {
            this.centr = initCentroid;
            groups = new HashSet<Instance>();
        }

        public double dist(Instance inst){
            return centr.distance(inst);
        }

        public void resetGroup(){groups.clear();}

        // Recompute the centroid from the group; returns true if it did not move.
        public boolean recalcCentr(boolean isReset)
        {
            Instance uCentr = centr.centroid(groups);
            boolean rst = uCentr.isSamePoint(centr);
            centr = uCentr;
            if(isReset) groups.clear();
            return rst;
        }

        public void addInst(Instance inst){groups.add(inst);}
    }

    public static void main(String args[])
    {
        Set<Instance> instGrp = new HashSet<Instance>();
        Instance c1 = new SimpleInst(-2, 0.5);
        Instance c2 = new SimpleInst(-4, -2.5);
        Instance c3 = new SimpleInst(-0.5, -1);

        instGrp.add(new SimpleInst(2, -1.5));
        instGrp.add(new SimpleInst(-3, -2));
        instGrp.add(new SimpleInst(-3, 1));
        instGrp.add(new SimpleInst(-2, -1.5));
        instGrp.add(new SimpleInst(-4, -1));

        instGrp.add(new SimpleInst(2.5, 0));
        instGrp.add(new SimpleInst(1, 0.5));
        instGrp.add(new SimpleInst(3, -3));
        instGrp.add(new SimpleInst(-4, 1));
        instGrp.add(new SimpleInst(-3, 2));
        instGrp.add(new SimpleInst(-3.5, -3));

        KMeans km = new KMeans(3, c1, c2, c3);
        km.processUntilCentrStable(10, instGrp);  /* 1) Process until the centroids are stable */
    }
}
Then, when you run KMeans's main method, it produces the 3 clusters and their corresponding groups, as in the diagram below :
(figure: resulting clusters and their groups)

Additional notes :
The implementation above is very inefficient: every data point has to compute its distance/similarity against every cluster, so with |Dataset| = n and |Clusters| = m, one loop costs m*n distance computations! There is in fact a more efficient approach, which is to convert the dataset into a ball tree data structure; for details, see here.
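The ball tree itself is beyond this post, but the idea it exploits - replacing the linear scan over all m centroids with a spatial index - can be illustrated in one dimension: if the centroids are kept sorted, a binary search finds each point's nearest centroid in O(log m) instead of O(m). This sketch (class name NearestCentroid1D is my own) only illustrates the pruning idea, not the ball tree:

```java
import java.util.Arrays;

public class NearestCentroid1D {
    // Returns the index of the centroid nearest to x.
    // sortedCentroids must be sorted in ascending order.
    public static int nearest(double[] sortedCentroids, double x) {
        int idx = Arrays.binarySearch(sortedCentroids, x);
        if (idx >= 0) return idx;                         // exact hit
        int ip = -idx - 1;                                // insertion point
        if (ip == 0) return 0;                            // x is left of all centroids
        if (ip == sortedCentroids.length) return ip - 1;  // x is right of all centroids
        // Only the two neighbors of the insertion point can be nearest.
        return (x - sortedCentroids[ip - 1] <= sortedCentroids[ip] - x) ? ip - 1 : ip;
    }

    public static void main(String[] args) {
        double[] centroids = {1.0, 5.0, 9.0};
        System.out.println(nearest(centroids, 2.0));  // prints 0
        System.out.println(nearest(centroids, 6.0));  // prints 1
        System.out.println(nearest(centroids, 7.5));  // prints 2
    }
}
```

A ball tree generalizes this pruning to higher dimensions by bounding groups of points inside nested balls, so most centroid comparisons can be skipped entirely.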
