Monday, July 24, 2017

[ Scikit-learn ] FAQ - Converting LinearSVC's decision function to probabilities

Source From Here 
Question 
I use a linear SVM from scikit-learn (LinearSVC) for a binary classification problem. I understand that LinearSVC can give me the predicted labels and the decision scores, but I want probability estimates (confidence in the label). I want to keep using LinearSVC because of its speed (compared to sklearn.svm.SVC with a linear kernel). Is it reasonable to use a logistic function to convert the decision scores to probabilities? 
>>> from sklearn.datasets import make_blobs 
>>> import sklearn.svm as svm 
>>> svmLinear = svm.LinearSVC() 
>>> X, y = make_blobs(centers=4, random_state=8) 
>>> y = y % 2 
>>> from sklearn.model_selection import train_test_split 
>>> X_train, X_test, y_train, y_test = train_test_split(X, y) 
>>> svmLinear.fit(X_train, y_train) 
LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True, 
intercept_scaling=1, loss='squared_hinge', max_iter=1000, 
multi_class='ovr', penalty='l2', random_state=None, tol=0.0001, 
verbose=0)
 

>>> predicted_test = svmLinear.predict(X_test) 
>>> predicted_test_scores = svmLinear.decision_function(X_test) 
>>> for i in range(10): 
...     print("%d -> %d (%.02f)" % (y_test[i], predicted_test[i], predicted_test_scores[i])) 
... 
0 -> 1 (0.03) 
1 -> 1 (0.90) 
1 -> 0 (-0.32) 
0 -> 0 (-0.18) 
1 -> 1 (0.57) 
1 -> 1 (0.57) 
0 -> 0 (-0.18) 
0 -> 0 (-0.20) 
0 -> 0 (-0.18) 
0 -> 1 (0.03)
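The approach proposed in the question can be sketched by passing the decision scores through a fixed logistic function. This is a minimal sketch, assuming SciPy's expit (the standard logistic function 1 / (1 + exp(-x))) and the same toy data as above; the variable name naive_proba is illustrative. Because the function is monotone and maps a score of 0 to 0.5, it preserves the ranking and the predicted labels, but nothing makes the outputs calibrated probabilities: the scale of the SVM scores depends on C and on the data.

```python
from scipy.special import expit
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

# Same toy setup as the session above: 4 blobs folded into 2 classes.
X, y = make_blobs(centers=4, random_state=8)
y = y % 2
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

svm = LinearSVC(random_state=0).fit(X_train, y_train)
scores = svm.decision_function(X_test)

# Naive mapping with a FIXED logistic function (no parameters learned).
# Score 0 (the decision boundary) maps to 0.5, so thresholding at 0.5
# reproduces svm.predict, but the values are not calibrated probabilities.
naive_proba = expit(scores)

print(naive_proba[:5].round(3))
```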


How-To 
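scikit-learn provides CalibratedClassifierCV, which can be used to solve this problem: it adds probability output to LinearSVC or any other classifier that implements the decision_function method. With its default method='sigmoid' (Platt scaling), it does essentially what the question proposes, except that the parameters of the logistic function are learned from held-out data via cross-validation instead of being fixed: 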
>>> from sklearn.calibration import CalibratedClassifierCV 
>>> clf = CalibratedClassifierCV(svmLinear) 
>>> clf.fit(X_train, y_train) 
>>> predicted_test_proba = clf.predict_proba(X_test) 
>>> for i in range(10): 
...     print("%d -> %d (%s)" % (y_test[i], predicted_test[i], predicted_test_proba[i])) 
... 
0 -> 1 ([ 0.53668502 0.46331498]) 
1 -> 1 ([ 0.30148176 0.69851824]) 
1 -> 0 ([ 0.59661738 0.40338262]) 
0 -> 0 ([ 0.60080255 0.39919745]) 
1 -> 1 ([ 0.36580505 0.63419495]) 
1 -> 1 ([ 0.36516485 0.63483515]) 
0 -> 0 ([ 0.59834482 0.40165518]) 
0 -> 0 ([ 0.54233899 0.45766101]) 
0 -> 0 ([ 0.59769979 0.40230021]) 
0 -> 1 ([ 0.46477692 0.53522308])
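CalibratedClassifierCV's default calibration, method='sigmoid' (Platt scaling), can be approximated by hand: train the SVM on one part of the training data, then fit a one-feature logistic regression that maps its decision scores to labels on a held-out part. This is a simplified sketch, not scikit-learn's implementation, which cross-validates and by default averages the calibrated models over the folds; the 70/30 calibration split and the variable names are illustrative assumptions. The Brier score comparison at the end is one way to check whether the learned sigmoid improves on the fixed one.

```python
from scipy.special import expit
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import brier_score_loss
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

# Same toy data as above: 4 blobs folded into 2 classes.
X, y = make_blobs(centers=4, random_state=8)
y = y % 2
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Split the training data once more: one part trains the SVM, the
# held-out part fits the sigmoid, so the sigmoid never sees scores
# computed on the SVM's own training points.
X_fit, X_cal, y_fit, y_cal = train_test_split(
    X_train, y_train, test_size=0.3, random_state=0, stratify=y_train
)

svm = LinearSVC(random_state=0).fit(X_fit, y_fit)

# Platt-style scaling: learn P(y=1 | s) = 1 / (1 + exp(-(a*s + b)))
# by fitting a one-feature logistic regression on the decision scores s.
# Note: LogisticRegression applies L2 regularization (C=1.0) by default.
platt = LogisticRegression()
platt.fit(svm.decision_function(X_cal).reshape(-1, 1), y_cal)

test_scores = svm.decision_function(X_test)
proba = platt.predict_proba(test_scores.reshape(-1, 1))  # column 1 is P(y=1)

# Compare with the fixed logistic function (lower Brier score is better).
print("Platt sigmoid Brier:", brier_score_loss(y_test, proba[:, 1]))
print("Fixed sigmoid Brier:", brier_score_loss(y_test, expit(test_scores)))
```

CalibratedClassifierCV also accepts method='isotonic', a non-parametric alternative; the scikit-learn documentation cautions that it tends to overfit on small data sets like this one, where the sigmoid is usually the safer choice.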

