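Section 4.8 of the book introduces Clustering, an unsupervised learning technique that partitions a dataset into groups (clusters). By definition, data in the same cluster are more similar to each other than to data in different clusters, so the definition of "similarity" determines the clustering algorithm. Another issue is that the number of clusters usually has to be decided in advance, and this choice often affects the quality of the clustering. Let's first look at the Wikipedia definition :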
For more on clustering analysis, see the Wikipedia article Cluster analysis.
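To make the point that the definition of "similarity" drives the result concrete, here is a small hypothetical illustration (the class and method names are made up for this example). The same query point is compared with two candidates under Euclidean and under Manhattan distance, and the "nearest" candidate changes with the metric:

```java
// Hypothetical illustration: the same points compared under two
// different distance (dissimilarity) definitions.
public class SimilarityDemo {
    // Euclidean (L2) distance between two 2-D points.
    static double euclidean(double[] p, double[] q) {
        double dx = p[0] - q[0], dy = p[1] - q[1];
        return Math.sqrt(dx * dx + dy * dy);
    }

    // Manhattan (L1) distance between two 2-D points.
    static double manhattan(double[] p, double[] q) {
        return Math.abs(p[0] - q[0]) + Math.abs(p[1] - q[1]);
    }

    // Returns "a" or "b": whichever candidate is closer to p under the chosen metric.
    static String nearest(double[] p, double[] a, double[] b, boolean useEuclidean) {
        double da = useEuclidean ? euclidean(p, a) : manhattan(p, a);
        double db = useEuclidean ? euclidean(p, b) : manhattan(p, b);
        return da <= db ? "a" : "b";
    }

    public static void main(String[] args) {
        double[] p = {0, 0}, a = {3, 0}, b = {2, 2};
        // Euclidean: |pa| = 3, |pb| = sqrt(8) ~ 2.83, so b is closer.
        System.out.println("Euclidean nearest: " + nearest(p, a, b, true));   // prints "Euclidean nearest: b"
        // Manhattan: |pa| = 3, |pb| = 4, so a is closer.
        System.out.println("Manhattan nearest: " + nearest(p, a, b, false));  // prints "Manhattan nearest: a"
    }
}
```

A clustering algorithm built on one metric can therefore group the same data differently from one built on the other.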
K-Means Clustering :
For a deeper treatment of clustering analysis, you can jump to Chapter 6 of the book :
The algorithm introduced here is a classic one that has been in use since 1967, k-means clustering :
This clustering algorithm requires the number of clusters (k) to be chosen in advance. The algorithm is illustrated in the figure below :
The steps of the algorithm above are briefly summarized as follows :
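Since the step-by-step figure is not reproduced here, the following is a minimal sketch of the standard k-means loop (Lloyd's algorithm) on 2-D points, written independently of the implementation below. It assumes the caller supplies the k initial centroids, uses squared Euclidean distance for the assignment, and lets an empty cluster keep its previous centroid; the class and method names are hypothetical.

```java
import java.util.Arrays;

// Minimal sketch of the standard k-means loop (Lloyd's algorithm) on 2-D points.
// Assumptions: initial centroids are supplied by the caller, distance is Euclidean,
// and an empty cluster simply keeps its previous centroid.
public class KMeansSketch {

    // Index of the centroid closest to point p (squared Euclidean distance).
    static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int j = 0; j < centroids.length; j++) {
            double dx = p[0] - centroids[j][0], dy = p[1] - centroids[j][1];
            double d = dx * dx + dy * dy;
            if (d < bestDist) { bestDist = d; best = j; }
        }
        return best;
    }

    // Runs up to maxIter iterations and returns the cluster index of each point.
    // The centroids array is updated in place.
    static int[] cluster(double[][] points, double[][] centroids, int maxIter) {
        int k = centroids.length;
        int[] assign = new int[points.length];
        Arrays.fill(assign, -1);
        for (int iter = 0; iter < maxIter; iter++) {
            // Step 1: assign every point to its nearest centroid.
            boolean changed = false;
            for (int i = 0; i < points.length; i++) {
                int c = nearest(points[i], centroids);
                if (c != assign[i]) { assign[i] = c; changed = true; }
            }
            // Stop once no point changes cluster (the centroids are then stable too).
            if (!changed) break;
            // Step 2: move each centroid to the mean of its assigned points.
            double[][] sum = new double[k][2];
            int[] count = new int[k];
            for (int i = 0; i < points.length; i++) {
                sum[assign[i]][0] += points[i][0];
                sum[assign[i]][1] += points[i][1];
                count[assign[i]]++;
            }
            for (int j = 0; j < k; j++) {
                if (count[j] > 0) {
                    centroids[j][0] = sum[j][0] / count[j];
                    centroids[j][1] = sum[j][1] / count[j];
                }
            }
        }
        return assign;
    }

    public static void main(String[] args) {
        double[][] points = {{1, 1}, {1.5, 2}, {8, 8}, {9, 8.5}};
        double[][] centroids = {{1, 1}, {8, 8}};
        int[] assign = cluster(points, centroids, 10);
        System.out.println(Arrays.toString(assign));         // prints [0, 0, 1, 1]
        System.out.println(Arrays.deepToString(centroids));  // prints [[1.25, 1.5], [8.5, 8.25]]
    }
}
```

The loop alternates the two steps until no point changes cluster or the iteration cap is reached; the cap plays the same role as `max_loop` in the implementation below.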
Simple K-Means Implementation :
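Based on the K-Means described above, below is a simple Java implementation; running it and following the output shows how K-Means actually works. The interface Instance defines the interface of our training data, and the class SimpleInst, which implements it, represents the data we feed into K-Means :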
- Instance.java :
```java
package john.ibl.inst;

import java.util.Set;

public interface Instance {
    // Compute the distance to inst.
    public double distance(Instance inst);

    // Compute the centroid formed by all instances in the set.
    public Instance centroid(Set<Instance> insts);

    // Compute the centroid formed by all the instances passed in.
    public Instance centroid(Instance... insts);

    // Whether inst represents the same point as this instance.
    public boolean isSamePoint(Instance inst);
}
```
- SimpleInst.java :
```java
package john.ibl.inst;

import java.util.HashSet;
import java.util.Set;

public class SimpleInst implements Instance {
    public double attr1 = 0;
    public double attr2 = 0;

    public SimpleInst() {}
    public SimpleInst(double at1, double at2) { this.attr1 = at1; this.attr2 = at2; }

    // Euclidean distance; returns -1 if inst is not a SimpleInst.
    @Override
    public double distance(Instance inst) {
        if (inst instanceof SimpleInst) {
            SimpleInst sinst = (SimpleInst) inst;
            return Math.sqrt(Math.pow(this.attr1 - sinst.attr1, 2) +
                             Math.pow(this.attr2 - sinst.attr2, 2));
        }
        return -1;
    }

    // Centroid of this instance together with all instances passed in.
    @Override
    public Instance centroid(Instance... insts) {
        Set<Instance> set = new HashSet<Instance>();
        for (Instance inst : insts) set.add(inst);
        set.add(this);
        return this.centroid(set);
    }

    // Centroid (component-wise mean) of the SimpleInst members of insts.
    // Note: an empty set yields NaN coordinates (0/0).
    @Override
    public Instance centroid(Set<Instance> insts) {
        double attr1Sum = 0; double attr2Sum = 0; int cnt = 0;
        for (Instance inst : insts) {
            if (inst instanceof SimpleInst) {
                SimpleInst sinst = (SimpleInst) inst;
                attr1Sum += sinst.attr1; attr2Sum += sinst.attr2;
                cnt++;
            }
        }
        return new SimpleInst(attr1Sum / cnt, attr2Sum / cnt);
    }

    @Override
    public String toString() {
        return String.format("Inst{attr1=%.2g ; attr2=%.2g}", attr1, attr2);
    }

    public static void main(String args[]) {
        Instance inst1 = new SimpleInst(1, 2);
        Instance inst2 = new SimpleInst(4, 6);
        Instance inst3 = new SimpleInst(4, 4);
        System.out.printf("\t[Info] Distance between inst1 & inst2 = %g!\n", inst1.distance(inst2));
        System.out.printf("\t[Info] Centroid of inst1, inst2, inst3=%s\n", inst1.centroid(inst2, inst3));
    }

    // Exact coordinate comparison; used to detect that a centroid stopped moving.
    @Override
    public boolean isSamePoint(Instance inst) {
        if (inst instanceof SimpleInst) {
            SimpleInst tInst = (SimpleInst) inst;
            if (tInst.attr1 == this.attr1 && tInst.attr2 == this.attr2) return true;
        }
        return false;
    }
}
```
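As a quick check on SimpleInst.main above: distance() is the Euclidean distance and centroid(Set) is the component-wise arithmetic mean, so for inst1 = (1, 2), inst2 = (4, 6) and inst3 = (4, 4) the two printed values work out as follows (with the %g and %.2g formats used in the code, they print as 5.00000 and Inst{attr1=3.0 ; attr2=4.0}):

$$
d(\mathrm{inst1}, \mathrm{inst2}) = \sqrt{(1-4)^2 + (2-6)^2} = \sqrt{9 + 16} = 5
$$

$$
\mu(S) = \frac{1}{|S|}\sum_{x \in S} x = \left(\frac{1+4+4}{3},\ \frac{2+6+4}{3}\right) = (3,\ 4)
$$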
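* public void processUntilCentrStable(int max_loop, Set<Instance> instGrp)
Arguments :
  - max_loop : the maximum number of loops; processing stops after this many rounds even if the centroids are not yet stable.
  - instGrp : the set of instances to be clustered.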
The complete code is as follows :
- KMeans.java :
```java
package john.cluster.hierarchical;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import john.ibl.inst.Instance;
import john.ibl.inst.SimpleInst;

public class KMeans {
    public int k = 2;
    public List<Cluster> clusters = null;

    // Each of the k given instances becomes the initial centroid of one cluster.
    public KMeans(int k, Instance... insts) {
        if (k != insts.length) {
            System.out.printf("\t[Error] K(%s) doesn't match input instances(%d)!\n", k, insts.length);
            return;
        }
        this.k = k;
        clusters = new LinkedList<Cluster>();
        for (int i = 0; i < insts.length; i++) {
            clusters.add(new Cluster(insts[i]));
        }
    }

    // Repeat assign/update until no centroid moves, or until max_loop rounds have run.
    public void processUntilCentrStable(int max_loop, Set<Instance> instGrp) {
        int loop = 0;
        while (true) {
            loop++;
            // Assignment step: put each instance into the cluster with the nearest centroid.
            for (Instance inst : instGrp) {
                int cp = 0; double dist = -1;
                for (int j = 0; j < clusters.size(); j++) {
                    Cluster c = clusters.get(j);
                    double tdist = c.dist(inst);
                    if (dist == -1 || tdist < dist) {
                        cp = j;
                        dist = tdist;
                    }
                }
                clusters.get(cp).addInst(inst);
            }
            // Update step: recompute centroids, stopping the scan at the first one that moved.
            boolean isStable = true;
            int cnt = 0;
            for (int i = 0; i < clusters.size(); i++) {
                Cluster c = clusters.get(i);
                System.out.printf("\t[Info] Centr=%s :\n", c.centr);
                for (Instance inst : c.groups) System.out.printf("\t\t%s\n", inst);
                if (!c.recalcCentr(false)) {
                    isStable = false;
                    cnt = i;
                    break;
                }
            }
            if (isStable) break;
            else {
                // Clusters 0..cnt already have new centroids: just clear their members.
                // The remaining clusters still need their centroids recomputed (then cleared).
                for (int i = 0; i < clusters.size(); i++) {
                    if (i <= cnt) clusters.get(i).resetGroup();
                    else {
                        Cluster c = clusters.get(i);
                        System.out.printf("\t[Info] Centr=%s :\n", c.centr);
                        for (Instance inst : c.groups) System.out.printf("\t\t%s\n", inst);
                        c.recalcCentr(true);
                    }
                }
            }
            System.out.printf("===========================================================\n");
            if (loop == max_loop) break;
        }
        System.out.printf("\t[Info] Done! Total loop=%d\n", loop);
    }

    // Alternative: run exactly `loop` rounds without checking for convergence.
    public void process(int loop, Set<Instance> instGrp) {
        for (int i = 1; i <= loop; i++) {
            for (Instance inst : instGrp) {
                int cp = 0; double dist = -1;
                for (int j = 0; j < clusters.size(); j++) {
                    Cluster c = clusters.get(j);
                    double tdist = c.dist(inst);
                    if (dist == -1 || tdist < dist) {
                        cp = j;
                        dist = tdist;
                    }
                }
                clusters.get(cp).addInst(inst);
            }
            if (i != loop) {
                for (Cluster c : clusters) {
                    System.out.printf("\t[Info] Centr=%s :\n", c.centr);
                    for (Instance inst : c.groups) System.out.printf("\t\t%s\n", inst);
                    c.recalcCentr(true);
                }
                System.out.printf("===========================================================\n");
            } else {
                // Last round: print the final clusters and keep their members.
                for (Cluster c : clusters) {
                    System.out.printf("\t[Info] Centr=%s :\n", c.centr);
                    for (Instance inst : c.groups) System.out.printf("\t\t%s\n", inst);
                }
            }
        }
    }

    public class Cluster {
        public Instance centr = null;
        public Set<Instance> groups = null;

        public Cluster(Instance initCentroid) {
            this.centr = initCentroid;
            groups = new HashSet<Instance>();
        }

        public double dist(Instance inst) {
            return centr.distance(inst);
        }

        public void resetGroup() { groups.clear(); }

        // Move the centroid to the mean of the current members; returns true if it did not move.
        // If isReset is true, the member set is cleared afterwards.
        public boolean recalcCentr(boolean isReset) {
            // Keep the old centroid when no instance was assigned (the mean would be NaN).
            if (groups.isEmpty()) return true;
            Instance uCentr = centr.centroid(groups);
            boolean rst = uCentr.isSamePoint(centr);
            centr = uCentr;
            if (isReset) groups.clear();
            return rst;
        }

        public void addInst(Instance inst) { groups.add(inst); }
        //public void addInst(int attr1, int attr2){groups.add(new SimpleInst(attr1, attr2));}
    }

    public static void main(String args[]) {
        Set<Instance> instGrp = new HashSet<Instance>();
        // Initial centroids for k = 3.
        Instance c1 = new SimpleInst(-2, 0.5);
        Instance c2 = new SimpleInst(-4, -2.5);
        Instance c3 = new SimpleInst(-0.5, -1);
        instGrp.add(new SimpleInst(2, -1.5));
        instGrp.add(new SimpleInst(-3, -2));
        instGrp.add(new SimpleInst(-3, 1));
        instGrp.add(new SimpleInst(-2, -1.5));
        instGrp.add(new SimpleInst(-4, -1));
        instGrp.add(new SimpleInst(2.5, 0));
        instGrp.add(new SimpleInst(1, 0.5));
        instGrp.add(new SimpleInst(3, -3));
        instGrp.add(new SimpleInst(-4, 1));
        instGrp.add(new SimpleInst(-3, 2));
        instGrp.add(new SimpleInst(-3.5, -3));
        //instGrp.add(new SimpleInst(-2, -1.5));
        KMeans km = new KMeans(3, c1, c2, c3);
        km.processUntilCentrStable(10, instGrp); // Loop until the centroids are stable (at most 10 loops).
    }
}
```
Notes :
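The implementation above is quite inefficient: every data point has to compute its distance (similarity) to every cluster, so with |Dataset| = n and |Clusters| = m, a single loop requires m*n distance computations! More efficient approaches exist; one is to organize the dataset into a Ball tree data structure, which is explained here.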
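To make the m*n cost concrete, the hypothetical helper below counts the distance computations performed by one brute-force assignment pass over the 11 instances and 3 initial centroids used in KMeans.main. It does not implement a Ball tree; it only shows the baseline cost that a Ball tree is meant to reduce.

```java
// Hypothetical helper: counts the point-to-centroid distance computations
// performed by one brute-force assignment pass.
public class AssignmentCost {
    static long distanceCalls = 0;

    // Euclidean distance that also increments the call counter.
    static double distance(double[] p, double[] q) {
        distanceCalls++;
        double dx = p[0] - q[0], dy = p[1] - q[1];
        return Math.sqrt(dx * dx + dy * dy);
    }

    // One assignment pass: compare every point with every centroid.
    static int[] assignOnce(double[][] points, double[][] centroids) {
        int[] assign = new int[points.length];
        for (int i = 0; i < points.length; i++) {
            double best = Double.MAX_VALUE;
            for (int j = 0; j < centroids.length; j++) {
                double d = distance(points[i], centroids[j]);
                if (d < best) { best = d; assign[i] = j; }
            }
        }
        return assign;
    }

    public static void main(String[] args) {
        // The 11 instances and 3 initial centroids from KMeans.main.
        double[][] points = {{2, -1.5}, {-3, -2}, {-3, 1}, {-2, -1.5}, {-4, -1},
                {2.5, 0}, {1, 0.5}, {3, -3}, {-4, 1}, {-3, 2}, {-3.5, -3}};
        double[][] centroids = {{-2, 0.5}, {-4, -2.5}, {-0.5, -1}};
        assignOnce(points, centroids);
        // n = 11 points, m = 3 clusters, so 33 distance computations per pass.
        System.out.println("distance computations: " + distanceCalls);  // prints "distance computations: 33"
    }
}
```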