程式扎記: [ ML Primer ] Linear classification using the perceptron


Monday, September 17, 2012


Linear classification using the perceptron : 
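The perceptron uses a hyperplane to separate instances that have already been classified. If the training data is linearly separable, the algorithm finds the weights of the linear function

f = w0 * x0 + w1 * x1 + ... + wk * xk   (x0 is a constant, fixed at 1)

and uses it to classify a given input (x0, x1, ..., xk). Below we implement the simplest version of the perceptron to see how the algorithm works. Let's first look at the algorithm's pseudo code: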
    Set all weights to zero
    Until all instances in the training data are classified correctly
        For each instance I in the training data
            If I is classified incorrectly by the perceptron
                If I belongs to the first class, add it to the weight vector
                Else subtract it from the weight vector
Here, a1, a2, . . ., ak are the attribute values, and w0, w1, . . ., wk are the weights that define the hyperplane. We will assume that each training instance i1, i2, . . . is extended by an additional attribute a0 that always has the value 1. This extension, which is called the bias, just means that we don’t have to include an additional constant element in the sum. If the sum is greater than zero, we will predict the first class; otherwise, we will predict the second class. We want to find values for the weights so that the training data is correctly classified by the hyperplane.
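To make the decision rule concrete, here is a minimal sketch in Java. It assumes plain `double[]` arrays instead of the classes used later in this post; the names `DecisionRuleSketch`, `extend`, and `predict`, as well as the weights in `main`, are hypothetical and chosen only for illustration.

```java
// A sketch of the decision rule only; names and example weights are hypothetical.
class DecisionRuleSketch {
    /** Prepends the constant bias attribute a0 = 1 to the raw attribute values. */
    static double[] extend(double... attrs) {
        double[] a = new double[attrs.length + 1];
        a[0] = 1.0;
        System.arraycopy(attrs, 0, a, 1, attrs.length);
        return a;
    }

    /** Returns 1 (first class) if w0*a0 + ... + wk*ak > 0, otherwise 0 (second class). */
    static int predict(double[] w, double[] a) {
        if (w.length != a.length) {
            throw new IllegalArgumentException("weights and attributes differ in size");
        }
        double sum = 0.0;
        for (int i = 0; i < w.length; i++) {
            sum += w[i] * a[i];
        }
        return sum > 0 ? 1 : 0;
    }

    public static void main(String[] args) {
        double[] w = {-1, 1, 1};                         // made-up weights: w0 = -1, w1 = w2 = 1
        System.out.println(predict(w, extend(2, 0)));    // sum = -1 + 2 + 0 = 1    -> prints 1
        System.out.println(predict(w, extend(0, 0.5)));  // sum = -1 + 0 + 0.5 < 0  -> prints 0
    }
}
```

Because a0 is always 1, the weight w0 acts as the constant offset of the hyperplane, so no separate bias term is needed in the sum.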

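To read the pseudo code above: whenever an instance is misclassified, the values of the instance itself cannot be changed, so the only thing we can move is the hyperplane formed by the weighting vector. On every misclassification we therefore adjust the weighting vector to move the hyperplane (the original post illustrates this with a diagram).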
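The update described above can be sketched compactly with plain arrays. This is a minimal sketch under stated assumptions, not the implementation used in this post: `PerceptronTrainSketch`, `dot`, and `train` are hypothetical names, and the data in `main` is the set of six training instances that appears in the run later in the post (each one already starts with the constant 1).

```java
import java.util.Arrays;

// A sketch of the additive perceptron update; class and method names are hypothetical.
class PerceptronTrainSketch {
    static double dot(double[] w, double[] x) {
        double sum = 0.0;
        for (int i = 0; i < w.length; i++) {
            sum += w[i] * x[i];
        }
        return sum;
    }

    /** labels[i] is 1 for the first class and 0 for the second; returns the learned weights. */
    static double[] train(double[][] data, int[] labels, int maxRounds) {
        double[] w = new double[data[0].length];          // all weights start at zero
        for (int round = 0; round < maxRounds; round++) {
            boolean changed = false;
            for (int i = 0; i < data.length; i++) {
                int predicted = dot(w, data[i]) > 0 ? 1 : 0;
                if (predicted == labels[i]) {
                    continue;                             // classified correctly, leave w alone
                }
                // Misclassified: add the instance for class 1, subtract it for class 0.
                double sign = labels[i] == 1 ? 1.0 : -1.0;
                for (int j = 0; j < w.length; j++) {
                    w[j] += sign * data[i][j];
                }
                changed = true;
            }
            if (!changed) {
                break;                                    // a full pass without any update
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[][] data = {
            {1, 2, 1, 0, -1}, {1, 3, 5, 0, -4}, {1, 0, -2, 9, 4},
            {1, 0, -5, 4, 8}, {1, 0, 0, 1, 4}, {1, 1, 3, 0, 0}
        };
        int[] labels = {0, 0, 1, 1, 1, 0};
        System.out.println(Arrays.toString(train(data, labels, 20)));
        // prints [1.0, 0.0, -2.0, 9.0, 4.0]
    }
}
```

With all weights at zero every sum is 0, so every instance is first predicted as the second class. The only update in this run therefore happens at the first instance that belongs to the first class.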
 

Simplest implementation of Perceptron : 
Below is the implementation, written in Java. There is nothing special about it; it follows the pseudo code directly. Here we define the class SimpleInst to hold the training data:
package ml.supervised.perceptron;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

/** A training instance: a class label plus attribute values indexed from 0. */
public class SimpleInst implements Inst{
    // Number of attributes per instance; static, so it is shared by every instance.
    public static int MAX_FEATURE_SIZE=-1;
    public int cls = -1;          // class label: 1 = class1, 0 = class2
    public double lastSum = -1;   // weighted sum from the most recent classification
    public HashMap<Integer,Double> values = new HashMap<Integer,Double>();

    public SimpleInst(int c, Double...values){
        MAX_FEATURE_SIZE=values.length;
        this.setValues(c, values);
    }

    // The first argument is the class label; the remaining ones are the attribute values.
    public SimpleInst(int...values){
        MAX_FEATURE_SIZE=values.length-1;
        cls = values[0];
        for(int i=1; i<values.length; i++) this.values.put(i-1, (double)values[i]);
    }

    @Override
    public String toString()
    {
        StringBuffer strBuf = new StringBuffer("");
        Iterator<Entry<Integer,Double>> iter = values.entrySet().iterator();
        strBuf.append("Inst(");
        boolean flag = true;
        Entry<Integer,Double> ety;
        while(iter.hasNext())
        {
            ety = iter.next();
            if(flag==true)
            {
                // First value: no leading separator.
                flag = false;
                strBuf.append(ety.getValue());
            }
            else
            {
                strBuf.append(String.format(", %.01f", ety.getValue()));
            }
        }
        strBuf.append(")");
        return strBuf.toString();
    }

    @Override
    public int classify(IWeight weight) {
        Object sum = classifyInReal(weight);
        if(sum!=null)
        {
            if((Double)sum>0) return 1; // Class1
            else return 0; // Class2
        }
        return -1; // weight vector and instance differ in size
    }

    @Override
    public int size() {
        return MAX_FEATURE_SIZE;
    }

    public double[] values()
    {
        double vals[] = new double[MAX_FEATURE_SIZE];
        for(int i=0; i<MAX_FEATURE_SIZE; i++){
            Double d  = (Double)valueAt(i);
            if(d!=null)
            {
                vals[i] = d;
            }
            else
            {
                System.out.printf("\t[Error] idx=%d has no value!\n", i);
            }
        }
        return vals;
    }

    @Override
    public Object valueAt(int idx) {
        return values.get(idx);
    }

    @Override
    public boolean setValues(int c, Object... inValues) {
        if(inValues.length == MAX_FEATURE_SIZE)
        {
            for(int i=0; i<inValues.length; i++) values.put(i, (Double)inValues[i]);
            this.cls = c;
            return true;
        }
        return false;
    }

    // isCC: "is classified correctly" under the given weights.
    @Override
    public boolean isCC(IWeight weight) {
        return classify(weight)==cls;
    }

    @Override
    public int cls() {
        return cls;
    }

    // Returns the weighted sum as a Double, or null when the sizes do not match.
    @Override
    public Object classifyInReal(IWeight weight) {
        if(weight.size()==this.size())
        {
            double sum = 0;
            for(int i=0; i<size(); i++)
            {
                sum+=((Double)valueAt(i))*weight.w(i);
                //System.out.printf("w%d=%.0f ; x%d=%.0f -> sum=%.0f", i, weight.w(i), i, (Double)valueAt(i), sum);
            }
            lastSum = sum;
            return sum;
        }
        else
        {
            System.out.printf("\t[Error] Weighting vector(%d) has different size with input data(%d)!\n", weight.size(), size());
        }
        return null;
    }
}
For the weighting vector, we define the class SimpleWht:
package ml.supervised.perceptron;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

/** The weighting vector w0..wk, stored by index. */
public class SimpleWht implements IWeight{
    public static int MAX_FEATURE_SIZE=-1;
    public HashMap<Integer,Double> weights = new HashMap<Integer,Double>();

    // Index/value pair ordered by index.
    public static class Bean implements Comparable<Bean>
    {
        int key;
        double value;
        public Bean(int k, double v){key = k; value = v;}
        @Override
        public int compareTo(Bean o) {
            if(key > o.key) return 1;
            else if(key < o.key) return -1;
            return 0;
        }
    }

    public SimpleWht(int size){
        //System.out.printf("\t[Test] %d\n", weights.size());
        MAX_FEATURE_SIZE=size;
        zero();
        //setW(0, 1); // w0=1
    }

    @Override
    public double w(int idx) {
        return weights.get(idx);
    }

    // Adds val to weight idx; HashMap.put returns the previous value, so that is what comes back.
    @Override
    public double plus(int idx, double val) {
        return weights.put(idx, weights.get(idx)+val);
    }

    // Subtracts val from weight idx; returns the previous value, as in plus().
    @Override
    public double minus(int idx, double val) {
        return weights.put(idx, weights.get(idx)-val);
    }

    @Override
    public double setW(int idx, double val) {
        weights.put(idx, Double.valueOf(val));
        return 0;
    }

    // Resets every weight to zero, the starting point of the algorithm.
    @Override
    public void zero() {
        for(int i=0; i<MAX_FEATURE_SIZE; i++) weights.put(i, 0.0);
    }

    @Override
    public int size() {
        return MAX_FEATURE_SIZE;
    }

    // Adds a whole instance to the weighting vector, element by element.
    @Override
    public void plus(double... vals) {
        for(int i=0; i<vals.length; i++) plus(i, vals[i]);
    }

    // Subtracts a whole instance from the weighting vector, element by element.
    @Override
    public void minus(double... vals) {
        for(int i=0; i<vals.length; i++) minus(i, vals[i]);
    }

    @Override
    public String toString()
    {
        StringBuffer strBuf = new StringBuffer();
        Iterator<Entry<Integer,Double>> iter = weights.entrySet().iterator();
        while(iter.hasNext())
        {
            Entry<Integer,Double> ety = iter.next();
            strBuf.append(String.format("w%d=%.01f ", ety.getKey(), ety.getValue()));
        }
        return strBuf.toString();
    }
}
Next comes the implementation of the pseudo code above:
// Needs: import java.util.LinkedList; import java.util.List;
public static int MAX_LOOP = 20;   // upper bound on the number of training rounds

/**
* BD : Uses the linear function f = a0 * w0 + a1 * w1 + a2 * w2 + a3 * w3 + a4 * w4
*      to decide whether an instance belongs to class1(1) or class2(0).
*      If f>0 the instance is class1, otherwise class2.
* @param args
*/
public static void main(String[] args) throws Exception{
    IWeight     ws = new SimpleWht(5);

    /*Define the training instance list; the first argument is the class label*/
    List<Inst> instList = new LinkedList<Inst>();
    instList.add(new SimpleInst(0, 1, 2, 1, 0, -1)); // class2(0): (1, 2, 1, 0, -1)
    instList.add(new SimpleInst(0, 1, 3, 5, 0, -4));
    instList.add(new SimpleInst(1, 1, 0, -2, 9, 4)); // class1(1): (1, 0, -2, 9, 4)
    instList.add(new SimpleInst(1, 1, 0, -5, 4, 8));
    instList.add(new SimpleInst(1, 1, 0, 0, 1, 4));
    instList.add(new SimpleInst(0, 1, 1, 3, 0, 0));

    /*Perceptron algorithm*/
    int cnt=0;
    boolean isDone;
    while(cnt<MAX_LOOP)
    {
        isDone = true;
        for(Inst ist:instList)
        {
            if(!ist.isCC(ws)) // Not classified correctly
            {
                if(ist.cls()>0)
                {
                    /*The instance is class1(1) but was judged class2(0): add its attributes to the weighting vector.*/
                    // w.x <= 0 -> class2 -> wrong
                    System.out.printf("\t[Info] %s is misclassified as class2(%.2f)!\n", ist, ((SimpleInst)ist).lastSum);
                    ws.plus(((SimpleInst)ist).values());
                    isDone = false;
                }
                else
                {
                    /*The instance is class2(0) but was judged class1(1): subtract its attributes from the weighting vector.*/
                    // w.x > 0 -> class1 -> wrong
                    System.out.printf("\t[Info] %s is misclassified as class1(%.2f)!\n", ist, ((SimpleInst)ist).lastSum);
                    //System.out.printf("\t[Info] %s is classified correctly!\n", ist);
                    ws.minus(((SimpleInst)ist).values());
                    isDone = false;
                }
                System.out.printf("\t[Info] Modify ws -> %s\n", ws);
            }
            else
            {
                System.out.printf("\t[Info] %s is classified correctly!\n", ist);
            }
        }
        System.out.printf("===== Round %d done =====\n", cnt+1);
        if(isDone) break; // a full pass without any misclassification
        cnt++;
    }

    /*Print classified result*/
    System.out.printf("\t[Info] Weighting vector:\n%s\n", ws);

    for(int i=0; i<instList.size(); i++)
    {
        int clsInPredict = instList.get(i).classify(ws);
        int groundTruth = instList.get(i).cls();
        System.out.printf("\t[Info] Inst(%d) is %s as '%s' (Ground trouth is '%s').\n", i,
                          clsInPredict==groundTruth?"classified":"misclassified",
                          cis(clsInPredict),
                          cis(groundTruth));
    }
}

// Maps a class label to its display name.
public static String cis(int cls)
{
    if(cls > 0) return "class1";
    else return "class2";
}
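The execution result is shown below. We are lucky here: the single update made in Round 1 already gives a weighting vector that classifies all of the training data correctly, and Round 2 confirms it: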
[Info] Inst(1.0, 2.0, 1.0, 0.0, -1.0) is classified correctly!
[Info] Inst(1.0, 3.0, 5.0, 0.0, -4.0) is classified correctly!
[Info] Inst(1.0, 0.0, -2.0, 9.0, 4.0) is misclassified as class2(0.00)!
[Info] Modify ws -> w0=1.0 w1=0.0 w2=-2.0 w3=9.0 w4=4.0
[Info] Inst(1.0, 0.0, -5.0, 4.0, 8.0) is classified correctly!
[Info] Inst(1.0, 0.0, 0.0, 1.0, 4.0) is classified correctly!
[Info] Inst(1.0, 1.0, 3.0, 0.0, 0.0) is classified correctly!
===== Round 1 done =====
[Info] Inst(1.0, 2.0, 1.0, 0.0, -1.0) is classified correctly!
[Info] Inst(1.0, 3.0, 5.0, 0.0, -4.0) is classified correctly!
[Info] Inst(1.0, 0.0, -2.0, 9.0, 4.0) is classified correctly!
[Info] Inst(1.0, 0.0, -5.0, 4.0, 8.0) is classified correctly!
[Info] Inst(1.0, 0.0, 0.0, 1.0, 4.0) is classified correctly!
[Info] Inst(1.0, 1.0, 3.0, 0.0, 0.0) is classified correctly!
===== Round 2 done =====
[Info] Weighting vector:
w0=1.0 w1=0.0 w2=-2.0 w3=9.0 w4=4.0
[Info] Inst(0) is classified as 'class2' (Ground trouth is 'class2').
[Info] Inst(1) is classified as 'class2' (Ground trouth is 'class2').
[Info] Inst(2) is classified as 'class1' (Ground trouth is 'class1').
[Info] Inst(3) is classified as 'class1' (Ground trouth is 'class1').
[Info] Inst(4) is classified as 'class1' (Ground trouth is 'class1').
[Info] Inst(5) is classified as 'class2' (Ground trouth is 'class2').
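To see why the final weighting vector works, it can help to redo the arithmetic by hand. The sketch below recomputes the weighted sum of every training instance under the learned weights from the run above. The class name `LearnedWeightsCheck` and its members are hypothetical and not part of the code in this post.

```java
// Recomputes w.x for each training instance; names are hypothetical.
class LearnedWeightsCheck {
    // Weighting vector printed at the end of the run above.
    static final double[] W = {1, 0, -2, 9, 4};

    // The six training instances, in the order they appear in the output.
    static final double[][] INSTS = {
        {1, 2, 1, 0, -1}, {1, 3, 5, 0, -4}, {1, 0, -2, 9, 4},
        {1, 0, -5, 4, 8}, {1, 0, 0, 1, 4}, {1, 1, 3, 0, 0}
    };

    /** Returns the weighted sum w.x for each instance. */
    static double[] sums(double[] w, double[][] insts) {
        double[] out = new double[insts.length];
        for (int i = 0; i < insts.length; i++) {
            double s = 0.0;
            for (int j = 0; j < w.length; j++) {
                s += w[j] * insts[i][j];
            }
            out[i] = s;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] s = sums(W, INSTS);
        for (int i = 0; i < s.length; i++) {
            // A sum above zero means class1, anything else means class2.
            System.out.printf("Inst(%d): sum=%.1f -> %s%n", i, s[i], s[i] > 0 ? "class1" : "class2");
        }
    }
}
```

The sums come out as -5, -25, 102, 79, 26, and -5, so the three class1 instances land on the positive side of the hyperplane and the three class2 instances on the negative side.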

Supplement : 
Wiki : Winnow (algorithm) 
The winnow algorithm is a technique from machine learning for learning a linear classifier from labeled examples. It is very similar to the perceptron algorithm. However, the perceptron algorithm uses an additive weight-update scheme, while Winnow uses a multiplicative scheme that allows it to perform much better when many dimensions are irrelevant (hence its name).
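To contrast the two schemes, here is a minimal sketch of one multiplicative update step. It assumes a Winnow2-style variant over 0/1 features (promote by multiplying, demote by dividing), with promotion factor alpha = 2 and threshold theta = n/2 chosen for the example; the class name `WinnowUpdateSketch` and the method `update` are hypothetical.

```java
import java.util.Arrays;

// A sketch of a multiplicative (Winnow2-style) update; names and parameters are assumptions.
class WinnowUpdateSketch {
    /** Performs one update on a 0/1 feature vector; returns true if the weights changed. */
    static boolean update(double[] w, int[] x, int label, double alpha, double theta) {
        double sum = 0.0;
        for (int i = 0; i < w.length; i++) {
            sum += w[i] * x[i];
        }
        int predicted = sum > theta ? 1 : 0;
        if (predicted == label) {
            return false;                       // correct prediction, nothing to do
        }
        for (int i = 0; i < w.length; i++) {
            if (x[i] == 1) {
                // Only weights of active features change: scaled up when the true
                // label is 1, scaled down when it is 0.
                w[i] = (label == 1) ? w[i] * alpha : w[i] / alpha;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[] w = {1, 1, 1, 1};              // Winnow starts from all-ones weights
        update(w, new int[]{1, 0, 1, 0}, 1, 2.0, 2.0);
        System.out.println(Arrays.toString(w)); // prints [2.0, 1.0, 2.0, 1.0]
    }
}
```

In the example the sum is 2, which does not exceed the threshold of 2, so the instance is predicted as 0 although its label is 1, and the two active weights are doubled. The perceptron would instead add the instance's values to the weights.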

[ ML In Action ] Predicting numeric values : regression - Linear regression (1) 
[ ML In Action ] Logistic Regression
