Linear classification using the perceptron

Linear classification using the perceptron:
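Perceptron 透過 hyperplane 對已經 classified 的 instances 進行切割; 如果 training data 是 linearly separable 的話, 就可以透過這個演算法找出滿足下列公式的 hyperplane:

$$w_0 a_0 + w_1 a_1 + w_2 a_2 + \cdots + w_k a_k = 0$$

(其中 $a_0 = 1$ 為常數, 即 bias 項; $w_0$ 則與其他權重一樣由演算法學習而得)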

Here, a1, a2, . . ., ak are the attribute values, and w0, w1, . . ., wk are the weights that define the hyperplane. We will assume that each training instance i1, i2, . . . is extended by an additional attribute a0 that always has the value 1. This extension, which is called the bias, just means that we don’t have to include an additional constant element in the sum. If the sum is greater than zero, we will predict the first class; otherwise, we will predict the second class. We want to find values for the weights so that the training data is correctly classified by the hyperplane.

A simple implementation of the perceptron:

1. package ml.supervised.perceptron;
2.
3. import java.util.HashMap;
4. import java.util.Iterator;
5. import java.util.Map.Entry;
6.
7. public class SimpleInst implements Inst{
8.     public static int MAX_FEATURE_SIZE=-1;
9.     public int cls = -1;
10.     public double lastSum = -1;
11.     public HashMap values = new HashMap();
12.
13.     public SimpleInst(int c, Double...values){
14.         this.MAX_FEATURE_SIZE=values.length;
15.         this.setValues(c, values);
16.     }
17.
18.     public SimpleInst(int...values){
19.         this.MAX_FEATURE_SIZE=values.length-1;
20.         cls = values[0];
21.         for(int i=1; ithis.values.put(i-1, (double)values[i]);
22.     }
23.
24.     @Override
25.     public String toString()
26.     {
27.         StringBuffer strBuf = new StringBuffer("");
28.         Iterator> iter = values.entrySet().iterator();
29.         strBuf.append("Inst(");
30.         boolean flag = true;
31.         Entry ety;
32.         while(iter.hasNext())
33.         {
34.             ety = iter.next();
35.             if(flag==true)
36.             {
37.                 flag = false;
38.                 strBuf.append(ety.getValue());
39.             }
40.             else
41.             {
42.                 strBuf.append(String.format(", %.01f", ety.getValue()));
43.             }
44.         }
45.         strBuf.append(")");
46.         return strBuf.toString();
47.     }
48.
49.     @Override
50.     public int classify(IWeight weight) {
51.         Object sum = classifyInReal(weight);
52.         if(sum!=null)
53.         {
54.             if((Double)sum>0return 1// Class1
55.             else return 0// Class2
56.         }
57.         return -1;
58.     }
59.
60.     @Override
61.     public int size() {
62.         return MAX_FEATURE_SIZE;
63.     }
64.
65.     public double[] values()
66.     {
67.         double vals[] = new double[MAX_FEATURE_SIZE];
68.         for(int i=0; i
69.             Double d  = (Double)valueAt(i);
70.             if(d!=null)
71.             {
72.                 vals[i] = d;
73.             }
74.             else
75.             {
76.                 System.out.printf("\t[Error] idx=%d has no value!\n", i);
77.             }
78.         }
79.         return vals;
80.     }
81.
82.     @Override
83.     public Object valueAt(int idx) {
84.         return values.get(idx);
85.     }
86.
87.     @Override
88.     public boolean setValues(int c, Object... inValues) {
89.         if(inValues.length == MAX_FEATURE_SIZE)
90.         {
91.             for(int i=0; i
92.             this.cls = c;
93.             return true;
94.         }
95.         return false;
96.     }
97.
98.     @Override
99.     public boolean isCC(IWeight weight) {
100.         return classify(weight)==cls;
101.     }
102.
103.     @Override
104.     public int cls() {
105.         return cls;
106.     }
107.
108.     @Override
109.     public Object classifyInReal(IWeight weight) {
110.         if(weight.size()==this.size())
111.         {
112.             double sum = 0;
113.             for(int i=0; i
114.             {
115.                 sum+=((Double)valueAt(i))*weight.w(i);
116.                 //System.out.printf("w%d=%.0f ; x%d=%.0f -> sum=%.0f", i, weight.w(i), i, (Double)valueAt(i), sum);
117.             }
118.             lastSum = sum;
119.             return sum;
120.         }
121.         else
122.         {
123.             System.out.printf("\t[Error] Weighting vector(%d) has different size with input data(%d)!\n", weight.size(), size());
124.         }
125.         return null;
126.     }
127. }

1. package ml.supervised.perceptron;
2.
3. import java.util.Comparator;
4. import java.util.HashMap;
5. import java.util.Iterator;
6. import java.util.Map.Entry;
7. import java.util.PriorityQueue;
8.
9. public class SimpleWht implements IWeight{
10.     public static int MAX_FEATURE_SIZE=-1;
11.     public HashMap weights = new HashMap();
12.
13.     public static class Bean implements Comparable
14.     {
15.         int key;
16.         double value;
17.         public Bean(int k, double v){key = k; value = v;}
18.         @Override
19.         public int compareTo(Bean o) {
20.             if(key > o.key) return 1;
21.             else if(key < o.key) return -1;
22.             return 0;
23.         }
24.     }
25.
26.     public SimpleWht(int size){
27.         //System.out.printf("\t[Test] %d\n", weights.size());
28.         this.MAX_FEATURE_SIZE=size;
29.         zero();
30.         //setW(0, 1); // w0=1
31.     }
32.
33.     @Override
34.     public double w(int idx) {
35.         return weights.get(idx);
36.     }
37.
38.     @Override
39.     public double plus(int idx, double val) {
40.         return weights.put(idx, weights.get(idx)+val);
41.     }
42.
43.     @Override
44.     public double minus(int idx, double val) {
45.         return weights.put(idx, weights.get(idx)-val);
46.     }
47.
48.     @Override
49.     public double setW(int idx, double val) {
50.         weights.put(idx, Double.valueOf(val));
51.         return 0;
52.     }
53.
54.     @Override
55.     public void zero() {
56.         for(int i=0; i0.0);
57.     }
58.
59.     @Override
60.     public int size() {
61.         return MAX_FEATURE_SIZE;
62.     }
63.
64.     @Override
65.     public void plus(double... vals) {
66.         for(int i=0; i
67.     }
68.
69.     @Override
70.     public void minus(double... vals) {
71.         for(int i=0; i
72.     }
73.
74.     @Override
75.     public String toString()
76.     {
77.         StringBuffer strBuf = new StringBuffer();
78.         Iterator> iter = weights.entrySet().iterator();
79.         while(iter.hasNext())
80.         {
81.             Entry ety = iter.next();
82.             strBuf.append(String.format("w%d=%.01f ", ety.getKey(), ety.getValue()));
83.         }
84.         return strBuf.toString();
85.     }
86. }

1. public static int MAX_LOOP = 20;
2.
3. /**
4. * BD : 根據  Linear function f = a0 * w0 + a1 * w1 + a2 * w2 + a3 * w3 + a4 * w4
5. *      決定  instance 屬於類別 class1(1) 或  class2(0). 如果  f>0  則為類別 class1, 否則為 class2.
6. * @param args
7. */
8. public static void main(String[] args) throws Exception{
9.     IWeight     ws = new SimpleWht(5);
10.
11.     /*定義 training instance list*/
12.     List instList = new LinkedList();
13.     instList.add(new SimpleInst(01210, -1)); // Class0: (1, 2, 1, 0, -1)
15.     instList.add(new SimpleInst(110, -294)); // Class1: (1, 0, -2, 9, 4)
19.
20.     /*Perceptron algorithm*/
21.     int cnt=0;
22.     boolean isDone;
23.     while(cnt
24.     {
25.         isDone = true;
26.         for(Inst ist:instList)
27.         {
28.             if(!ist.isCC(ws)) // Not classified correctly
29.             {
30.                 if(ist.cls()>0)
31.                 {
32.                     /*Instance 為  class1(1) 但被判為 class2(0) 時, 將 instance 的 attribute 加到  weighting.*/
33.                     // w.x < 0 -> class2 -> wrong
34.                     System.out.printf("\t[Info] %s is misclassified as class2(%.2f)!\n", ist, ((SimpleInst)ist).lastSum);
35.                     ws.plus(((SimpleInst)ist).values());
36.                     isDone = false;
37.                 }
38.                 else
39.                 {
40.                     /*Instance 為  class2(0) 但被判為 class1(1) 時, 將 instance 的 attribute 從  weighting 減掉.*/
41.                     // w.x > 0 -> class1 -> wrong
42.                     System.out.printf("\t[Info] %s is misclassified as class1(%.2f)!\n", ist, ((SimpleInst)ist).lastSum);
43.                     //System.out.printf("\t[Info] %s is classified correctly!\n", ist);
44.                     ws.minus(((SimpleInst)ist).values());
45.                     isDone = false;
46.                 }
47.                 System.out.printf("\t[Info] Modify ws -> %s\n", ws);
48.             }
49.             else
50.             {
51.                 System.out.printf("\t[Info] %s is classified correctly!\n", ist);
52.             }
53.         }
54.         System.out.printf("===== Round %d done =====\n", cnt+1);
55.         if(isDone) break;
56.         cnt++;
57.     }
58.
59.     /*Print classified result*/
60.     System.out.printf("\t[Info] Weighting vector:\n%s\n", ws);
61.
62.     for(int i=0; i
63.     {
64.         int clsInPredict = instList.get(i).classify(ws);
65.         int groundTruth = instList.get(i).cls();
66.         System.out.printf("\t[Info] Inst(%d) is %s as '%s' (Ground trouth is '%s').\n", i,
67.                           clsInPredict==groundTruth?"classified":"misclassified",
68.                           cis(clsInPredict),
69.                           cis(groundTruth));
70.     }
71. }
72.
73. public static String cis(int cls)
74. {
75.     if(cls > 0return "class1";
76.     else return "class2";
77. }

[Info] Inst(1.0, 2.0, 1.0, 0.0, -1.0) is classified correctly!
[Info] Inst(1.0, 3.0, 5.0, 0.0, -4.0) is classified correctly!
[Info] Inst(1.0, 0.0, -2.0, 9.0, 4.0) is misclassified as class2(0.00)!
[Info] Modify ws -> w0=1.0 w1=0.0 w2=-2.0 w3=9.0 w4=4.0
[Info] Inst(1.0, 0.0, -5.0, 4.0, 8.0) is classified correctly!
[Info] Inst(1.0, 0.0, 0.0, 1.0, 4.0) is classified correctly!
[Info] Inst(1.0, 1.0, 3.0, 0.0, 0.0) is classified correctly!
===== Round 1 done =====
[Info] Inst(1.0, 2.0, 1.0, 0.0, -1.0) is classified correctly!
[Info] Inst(1.0, 3.0, 5.0, 0.0, -4.0) is classified correctly!
[Info] Inst(1.0, 0.0, -2.0, 9.0, 4.0) is classified correctly!
[Info] Inst(1.0, 0.0, -5.0, 4.0, 8.0) is classified correctly!
[Info] Inst(1.0, 0.0, 0.0, 1.0, 4.0) is classified correctly!
[Info] Inst(1.0, 1.0, 3.0, 0.0, 0.0) is classified correctly!
===== Round 2 done =====
[Info] Weighting vector:
w0=1.0 w1=0.0 w2=-2.0 w3=9.0 w4=4.0
[Info] Inst(0) is classified as 'class2' (Ground trouth is 'class2').
[Info] Inst(1) is classified as 'class2' (Ground trouth is 'class2').
[Info] Inst(2) is classified as 'class1' (Ground trouth is 'class1').
[Info] Inst(3) is classified as 'class1' (Ground trouth is 'class1').
[Info] Inst(4) is classified as 'class1' (Ground trouth is 'class1').
[Info] Inst(5) is classified as 'class2' (Ground trouth is 'class2').

Supplement :
