Preface
scikit-learn 是最受歡迎的 Python 的機器學習庫本章我們將使用 scikit-learn 調用 Keras 生成的模型. 本章將:
Introduction
Keras 在深度學習很受歡迎,但是只能做深度學習:Keras 是最小化的深度學習庫,目標在於快速搭建深度學習模型基於 SciPy 的 scikit-learn,數值運算效率很高,適用於普遍的機器學習任務,提供很多機器學習工具,包括但不限於:
Keras 為 scikit-learn 了封裝 KerasClassifier:
使用 Cross Validation 驗證深度學習模型
Keras 的 KerasClassifier 與 KerasRegressor 兩個類接受 build_fn 參數,傳入函數用以建立模型; 接著我們加入 epochs=150 與 batch_size=10 這兩個參數:這兩個參數會傳入模型的 fit()。方法我們用 scikit-learn 的 StratifiedKFold 類進行 10-Fold Cross validation,測試模型在未知數據的性能,使用並 cross_val_score() 函數檢測模型,打印結果:
- train.py
- #!/usr/bin/env python
- # MLP for Pima Indians Dataset with 10-fold cross validation via sklearn
- from keras.models import Sequential
- from keras.layers import Dense
- from keras.wrappers.scikit_learn import KerasClassifier
- from sklearn.model_selection import StratifiedKFold
- from sklearn.model_selection import cross_val_score
- import numpy
- import pandas
- # Function to create model, required for KerasClassifier
- def create_model():
- # create model
- model = Sequential()
- model.add(Dense(12, input_dim=8, init='uniform', activation='relu'))
- model.add(Dense(8, init='uniform', activation='relu'))
- model.add(Dense(1, init='uniform', activation='sigmoid'))
- # Compile model
- model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
- return model
- # fix random seed for reproducibility
- seed = 7
- numpy.random.seed(seed)
- # load pima indians dataset
- dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
- # split into input (X) and output (Y) variables
- X = dataset[:,0:8]
- Y = dataset[:,8]
- # create model
- model = KerasClassifier(build_fn=create_model, epochs=150, batch_size=10)
- # evaluate using 10-fold cross validation
- kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
- results = cross_val_score(model, X, Y, cv=kfold)
- print(results.mean())
每輪訓練會輸出一次結果,加上最終的平均性能:
比起手工測試,使用 scikit-learn 容易的多。
使用 GridSearch 調整深度學習模型的參數
使用 scikit-learn 封裝 Keras 的模型十分簡單, 進一步想:我們可以給 fit() 方法傳入參數,KerasClassifier 的 build_fn 方法也可以傳入參數可以利用這點進一步調整模型。我們可以用 GridSearch 測試不同參數的性能:create_model() 函數可以傳入optimizer:init 參數,雖然都有默認值, 但我們可以用不同的優化算法和初始權優化網絡. 具體來說, 我們希望搜索:
所有的參數組成一個字典,傳入 scikit-learn 的 GridSearchCV 類:GridSearchCV 類 會對每組參數(2×3×3×3)進行訓練,進行 3-Fold Cross validation. 這樣做的計算量巨大! 耗時巨長如果模型小還可以取一部分數據試試看,因為這裡的 數據集 與 網絡 都不大(1000 個 數據內,9個參數)所以可以在可接受時間下輸出最好的參數和模型,以及平均值:
- train_gs.py
- #!/usr/bin/env python
- # MLP for Pima Indians Dataset with grid search via sklearn
- from keras.models import Sequential
- from keras.layers import Dense
- from keras.wrappers.scikit_learn import KerasClassifier
- from sklearn.model_selection import GridSearchCV
- import numpy
- import pandas
- # Function to create model, required for KerasClassifier
- def create_model(optimizer='rmsprop', init='glorot_uniform'):
- # create model
- model = Sequential()
- model.add(Dense(12, input_dim=8, init=init, activation='relu'))
- model.add(Dense(8, init=init, activation='relu'))
- model.add(Dense(1, init=init, activation='sigmoid'))
- # Compile model
- model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
- return model
- # fix random seed for reproducibility
- seed = 7
- numpy.random.seed(seed)
- # load pima indians dataset
- dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
- # split into input (X) and output (Y) variables
- X = dataset[:,0:8]
- Y = dataset[:,8]
- # create model
- model = KerasClassifier(build_fn=create_model)
- # grid search epochs, batch size and optimizer
- optimizers = ['rmsprop', 'adam']
- init = ['glorot_uniform', 'normal', 'uniform']
- epochs = numpy.array([50, 100, 150])
- batches = numpy.array([5, 10, 20])
- param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, init=init)
- grid = GridSearchCV(estimator=model, param_grid=param_grid)
- grid_result = grid.fit(X, Y)
- # summarize results
- print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
- best_model = grid.best_estimator_
- print('Best model={}'.format(best_model.__class__))
Supplement
* Keras Doc - Wrappers for the Scikit-Learn API
* Ch10 - 多類花朵分類 (iris dataset)
* ML CheatSheet - Optimizer
沒有留言:
張貼留言