2012年9月7日 星期五

[ klibsvm ] libsvm 的 Java wrapper

Preface : 
最近因為研究需要, 打算使用 林智仁 老師 開發的 libsvm 來訓練自己的 corpus 抽出取來的 features, 作為實驗數據的 baseline. 從網站上面可以下載 zip 檔後解壓縮, 可以發現有個 Java 目錄. 耶! 對於 Java 比較熟的我是一大福音. 有興趣可以去讀裡面元代碼的 "svm_train.java", "svm_predict.java" 與 "svm.java" 就可以知道其命令列的使用方法與工作流程. 

基本上使用 "svm_train.java", "svm_predict.java" 並可以進行 SVM 的 training 與 predict 這兩個動作便可以滿足大多數的需求, 不過因為那是命令列的用法因此如果你希望把它當作 library 去使用, 可以需要做些封裝. 這邊我根據自己的使用經驗根據 libsvm.jar 的類別設計, 自己封裝了 "training" 與 "predict" 的用法, 並另外導出一個 klibsvm.jar 方便自己後續寫程式來 Integrate libsvm.jar 來做 automation. 

Prepare training data: 
SVM 是一個可以支援很高維的 machine learning, 而在 training data 你必須事先定義好 features (feature 的大小決定 training 的維度) ; 而每個 feature 會有自己用來 training 的值, 另外 SVM 是用來解 Classification 的問題, 因此你還需要提供一個 Label 的值 (Classification 的類別). 接著這些資料將會被存成 Vector 送進去的 SVM 進行 training. 在 libsvm 的 README 中的 "Installation and Data Format" 有定義要為進去的資料格式如下 : 
<label> <index1>:<value1> <index2>:<value2> ...

其中 <label>  就是類別的種類, 使用數字 1, 2 etc 代表; 而 <index1> 則是說明是哪一個 feature, 而 <value1> 則是對應該 feature 的值. 

以等下我要說明的範例, 假設我的 training data 是一個 X/Y 二維平面的座標, 我定義了一個方程式 0.7*X^2 - 10 = Y ; 如果我有一堆座標(x,y), 我定義 x 帶入先前方程式得到的值如果大於等於座標的y 值我定義為類別1, 反之為類別2. 因此這邊我可以假設我有兩個 features1->x, feature2->y ; 接著考慮有座標 (-10,48), 因為 0.7*(-10)^2 - 10 = 60 < 48 -> 得到類別2, 因此我有一筆 Training 的 record 為 : 
2 1:-10.000 2:48.00

當然你的 training data 不可能只有一筆 record. 接著我們將之 mapping 到 klibsvm ; 我們的一筆 training 紀錄會使用類別 ksvm.data.Record 來代表, 而該類別上的屬性 label 則代表個該筆紀錄預測的類別. 而每一個 training 紀錄可能有多個 feature, 這邊使用類別 ksvm.data.TData 來代表一個 feature. 你可以使用 Record.addFeature() 來添加 feature 到你的 training 紀錄中. 而TData 上面的屬性 index 代表著第幾個 feature ; value 則是該 feature 的值. 因此如果我們要將剛剛的 training 座標 (-10, 48) 傳換成代碼, 可以參考下面 : 
  1. Record record = new Record();  
  2. record.addFeature(1, -10); // 添加 feature x=-10  
  3. record.addFeature(248); // 添加 feature y=48  
  4. record.label=2// 設定預測類別2  
但這樣一筆筆寫到代碼我可能會瘋掉 ><". 因此通常我們會利用外部的檔案來存放 training 的 data, 再用程式來載入. 在 klibsvm 中 Training 吃的類別必須實作介面 ksvm.data.IRecordIter : 
  1. package ksvm.data;  
  2.   
  3. import java.util.Iterator;  
  4.   
  5. public interface IRecordIter extends Iterator{}  
而如果外部檔案的格式滿足 libsvm 的格式定義, 則可以使用類別 ksvm.data.BasicRDIter 來載入外部 training data, 它實作了介面 ksvm.data.IRecordIter : 
  1. package ksvm.data;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.util.StringTokenizer;  
  8.   
  9. public class BasicRDIter implements IRecordIter{  
  10.     private BufferedReader  br = null;  
  11.     private String          nextProcline = null;      
  12.   
  13.     public BasicRDIter(File trainFile)throws IOException{this(new BufferedReader(new FileReader(trainFile)));}  
  14.     public BasicRDIter(BufferedReader br){this.br = br; retriveProcline();}  
  15.       
  16.     protected void retriveProcline()  
  17.     {  
  18.         try  
  19.         {  
  20.             do  
  21.             {  
  22.                 nextProcline = br.readLine();  
  23.                 if(nextProcline!=null &&   
  24.                    (nextProcline.isEmpty() || nextProcline.startsWith("#"))) continue;  
  25.                 else break;  
  26.             }while(nextProcline!=null);  
  27.             if(nextProcline==null) br.close();  
  28.         }  
  29.         catch(IOException e)  
  30.         {  
  31.             e.printStackTrace();  
  32.             nextProcline = null;  
  33.         }  
  34.     }  
  35.       
  36.     @Override  
  37.     public boolean hasNext() {  
  38.         return (nextProcline!=null);  
  39.     }  
  40.   
  41.     @Override  
  42.     public Record next() {  
  43.         if(nextProcline!=null)  
  44.         {  
  45.             //System.out.printf("\t[Test] line=%s\n", nextProcline);  
  46.             StringTokenizer st = new StringTokenizer(nextProcline," \t\n\r\f:");  
  47.             Record rd = new Record(st);  
  48.             retriveProcline();  
  49.             return rd;  
  50.         }  
  51.         return null;  
  52.     }  
  53.   
  54.     @Override  
  55.     public void remove() {  
  56.         throw new java.lang.UnsupportedOperationException("Not support");         
  57.     }  
  58. }  
因此考慮我們 training data 外部檔案 "scatters_train.tf", 可以如下載入提供後續 training model 使用 : 
  1. File trainFile = new File("scatters_train.tf");  
  2. BasicRDIter basicRDIter = new BasicRDIter(trainFile);   // 1) Prepare training input data iter  
Training process : 
現在知道什麼是 training 紀錄 ; 什麼是 features 與 如何從外部載入 training data. 剩下的就是 training 與 predict. 在 training 的部分簡單到不行, 建立類別 ksvm.run.SVMTrain 物件後再將剛剛載入的 training data 傳入其建構子便完成 training 的準備階段. 完整 Training process 代碼如下 : 
  1. public static void main(String[] args) throws IOException{  
  2.     File trainFile = new File("scatters_train.tf");  
  3.     File modelFile = new File("scatters.model");  
  4.     BasicRDIter basicRDIter = new BasicRDIter(trainFile);   // 1) Prepare training input data iter  
  5.     SVMTrain train = new SVMTrain(basicRDIter);             // 2) Prepare SVMTrain object  
  6.     //train.param.C = 10;  
  7.     if(train.start())                                       // 3) Start training  
  8.     {  
  9.         System.out.printf("\t[Info] Training is done!\n");  
  10.         train.saveModel(modelFile);                         //4) Output training model to external file  
  11.     }  
  12.     else  
  13.     {  
  14.         System.out.printf("\t[Info] Something wrong while training:\n");  
  15.         for(String em:train.errMsg)  
  16.         {  
  17.             System.out.printf("\t%s\n", em);  
  18.         }  
  19.         return;  
  20.     }  
  21.         
  22. }  
執行後會出現如下 training 訊息並導出 training model 到外部檔案 "scatters.model" : 
..*
optimization finished, #iter = 777
nu = 0.52894127152566
obj = -105.0870659825975, rho = -0.3120789555948129
nSV = 340, nBSV = 45
Total nSV = 340
[Info] Training time=0 sec
[Info] Training is done!

Predict process : 
有了 Training model, 後續的應用便是利用它來對給定的 feature set 進行預測並推論每個紀錄應該是屬於哪一個類別. 接著我們可以透過類別 ksvm.demo.ui.ScatterPlotDemo 將我們剛剛的 training data 用視覺化的效果標示於座標軸上 : 
  1. ScatterPlotDemo demo = new ScatterPlotDemo(new File("scatters_train.tf")); // 載入 training data 並繪於座標上  
  2. demo.pack();  
  3. demo.setVisible(true);  
執行後會出現下圖, 紅點即是類別2 ; 藍點是類別1 ; 綠點則是由剛剛我們定義的方程式繪出 : 
 

接著有了剛剛我們訓練出來的 model (scatters.model) , 接著我們使用類別 ksvm.run.SVMPredict 對測試 data (scatters_test.tf) 進行 predict. 參考範例代碼如下 : 
  1. File testFile = new File("scatters_test.tf");  
  2. File modelFile = new File("scatters.model"); // Training model file  
  3. File resultFile = new File("scatters_test.pid"); // Output predict result file  
  4. SVMPredict svmPredict = new SVMPredict(modelFile);  
  5. svmPredict.start(new BasicRDIter(testFile), resultFile);  // Start predicting  
執行後產生訊息如下, 並導出 predict 結果到 "scatters_test.pid". 由訊息可以知道準確率約 92% : 
[Info] Accuracy = 91.56%(347/379) (classification)

你也可以透過下面的代碼, 將測試的結果視覺化到座標上面 : 
  1. File modelFile = new File("scatters.model");  
  2. File testFile = new File("scatters_test.tf");  
  3. SVMPredict svmPredict = new SVMPredict(modelFile);  
  4. ScatterPlotDemo demo = new ScatterPlotDemo(testFile, svmPredict);  
  5. demo.pack();  
  6. demo.setVisible(true);  
執行後會由如下 UI 產出, 黃點 代表的是預測錯誤的部分 : 
 

Supplement : 
[ libsvm ] 碼上會!Java+libSVM 分析動態資料 (144行) 
[ libsvm ] piaip 的 (lib)SVM 簡易入門

21 則留言:

  1. 大大 我覺得你的文章很好 可以請你再把klibsvm.jar 載點從新掛上嗎 他失效了 我也想試試看 謝瞜

    回覆刪除
    回覆
    1. 請試試看下面的 link:
      https://www.space.ntu.edu.tw/navigate/s/CCB6C37C279644BEA9BD90640EC88C6CQQY

      頁面的 link 已經更新, 如果不行再 refresh 一下頁面試試看. ^^

      刪除
  2. 請問可以改參數等等的嗎? g c等等的

    回覆刪除
    回覆
    1. 如果 gc (Garbage Collect?) 指的是 java 的命令參數, 當然可以!
      如果指的是 libsvm 的命令列參數, 則目前支援的有:
      -o: Output file. (SINGLE)
      -k: Kernel type. 0=linear; 1=polynomial; 2=radial basis; 3=sigmoid; 4=precomputed kernel(0~4 default=2) (SINGLE)
      -p: Show predict probability. (SIGN)
      -c: Only for training. Signal for doing cross validation. (SIGN)
      -n: nr_fold for cross validation. (SINGLE)
      -a: Output answer with prediction. (SIGN)
      -s: svm_type : set type of SVM (default 0). 0=C-SVC; 1=nu-SVC; 2=one-class SVM; 3=epsilon-SVR; 4=nu-SVR. (SINGLE)
      -t: Task type. Support 'train/predict'. Default is . (SINGLE)
      -m: In training->Output model path; In prediction->Load in model path. (SINGLE)
      -i: Input file. When in train model, this argument is for corpus file; in predict model, this argument is for test file(s). When multiple files is given, separated them with ':'. (SINGLE)
      --COST: Cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1) (SINGLE)

      刪除
  3. 大大你好,請問"現在知道什麼是 training 紀錄 ; 什麼是 features 與 如何從外部載入 training data. 剩下的就是 training 與 predict. 在 training 的部分簡單到不行, 建立類別 ksvm.run.SVMTrain 物件後再將剛剛載入的 training data 傳入其建構子便完成 training 的準備階段."
    這段我看不太懂 不知道該怎麼操作~請問可以再解釋一次嗎?

    回覆刪除
    回覆
    1. "training 紀錄" 指的就是你的 training input format; " features" 指的就是你有興趣的特徵; "training data" 指的就是一堆 "training 紀錄".
      以這邊的範例來說, 一個 "training 紀錄" 就是:
      2 1:-10.000 2:48.00
      上面說的是 Label/Class=2 中有一筆紀錄是 Feature1=-10, Feature2=48.
      而一堆這樣 "training 紀錄" 的東西就叫做 training data

      刪除
    2. 所以你要使用 SVM 之前要先決定你的 Target 是什麼? 如水果甜不甜, 接著決定你的 Feature, 如硬度, 顏色深淺. 然後準備 training data:
      =====================================
      甜 硬度:1 顏色:2
      不甜 硬度:4 顏色:2
      ...
      =====================================
      因為 SVM 讀不懂中文, 因此你會將中文用數字取代如 "甜"=1, "不甜"=2, 硬度 (Feature1) = 1, 顏色 (Feature2)=2

      FYI

      刪除
  4. 感謝你的回覆,但我要問的不是這個><
    請問可以跟您留下mail我在跟您討論嗎? 感謝你

    回覆刪除
    回覆
    1. 這是我的mail: kevin80388@gmail.com

      刪除
    2. 作者已經移除這則留言。

      刪除
    3. 我 SVMTrain裡面放這個

      File trainFile = new File("scatters_train.tf");
      BasicRDIter basicRDIter = new BasicRDIter(trainFile);

      但是new BasicRDIter(trainFile); 這裡錯誤

      刪除
    4. 我現在最大的問題是不知道哪個類別要放在哪跟著主成程式一起使用 :(

      刪除
    5. 可以貼一下你的錯誤訊息嗎? 另外你的 "scatters_train.tf" 是用我的測試檔案, 還是有改成你自己的訓練資料?
      使用流程大約是:
      1. 使用 ksvm.run.SVMTrain 傳入訓練資料, 並產生訓練模型
      2. 使用 ksvm.run.SVMPredict 載入訓練模型, 並對傳入的測試資料進行預測.

      BTW, 我的 email 是 puremonkey2001@yahoo.com.tw

      刪除
    6. 一個完整的訓練與測試代碼如下:
      File trainFile = new File("scatters_train.tf");
      File modelFile = new File("scatters.model");
      File testFile = new File("scatters_test.tf");


      // Training -> Generate Model
      // 1. Feed in training data
      // 2. Output training model
      BasicRDIter basicRDIter = new BasicRDIter(trainFile); // 1) Prepare training input data iter
      SVMTrain train = new SVMTrain(basicRDIter); // 2) Prepare SVMTrain object
      //train.param.C = 10;
      if(train.start()) // 3) Start training
      {
      System.out.printf("\t[Info] Training is done!\n");
      train.saveModel(modelFile);
      }
      else
      {
      System.out.printf("\t[Info] Something wrong while training:\n");
      for(String em:train.errMsg)
      {
      System.out.printf("\t%s\n", em);
      }
      return;
      }

      // Predicting ->
      // 1. Loading mode
      // 2. Feed in testing data and output prediction result .
      File resultFile = new File("scatters_test.pid");
      SVMPredict svmPredict = new SVMPredict(modelFile);
      svmPredict.start(new BasicRDIter(testFile), resultFile);

      刪除
  5. 感謝你的回覆,我再寄MAIL給你跟您討論。謝謝!!!!

    回覆刪除
  6. 大大 我想問一下為什麼在執行 ScatterPlotDemo demo = new ScatterPlotDemo(new File("scatters_train.tf")); 的時候會
    出現 java.lang.NoClassDefFoundError: org/jfree/chart/ChartPanel 的錯誤啊??

    回覆刪除
    回覆
    1. 該套件有使用到 JFreeChart library (http://www.jfree.org/jfreechart/), 請下載後將 jar 加到執行的 classpath 中.

      刪除
  7. 作者已經移除這則留言。

    回覆刪除
  8. 作者大大你好,有辦法可以擷取出TP FP TN FN 出來嗎? 我想要另外算precision 跟 recall, 但我擷取不出來。 謝謝

    回覆刪除
    回覆
    1. 請參考下面原始碼:
      https://github.com/johnklee/klibsvm/blob/master/src/ksvm/run/SVMPredict.java

      中的 API:predict (line 90)

      刪除
    2. 謝謝,但我最後是直接拿產出的out檔跟原test做比對,也是有一樣的效果。感恩

      刪除

[Git 常見問題] error: The following untracked working tree files would be overwritten by merge

  Source From  Here 方案1: // x -----删除忽略文件已经对 git 来说不识别的文件 // d -----删除未被添加到 git 的路径中的文件 // f -----强制运行 #   git clean -d -fx 方案2: 今天在服务器上  gi...