手写数字分类
[1]:
from sklearn_genetic import GASearchCV
from sklearn_genetic.space import Categorical, Integer, Continuous
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn_genetic.callbacks import DeltaThreshold, TimerStopping
导入数据并将其分割为训练集和测试集
[2]:
# Load the digits dataset bundle and pull out the pieces used below.
data = load_digits()
X, y = data["data"], data["target"]
label_names = data["target_names"]

# Hold out 33% of the samples for the final evaluation; the fixed seed keeps
# the train/test partition identical between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.33,
    random_state=42,
)
定义需要调优的分类器及其参数网格
[3]:
# Base estimator whose hyperparameters are tuned. A fixed random_state makes
# tree construction repeatable, so a candidate's score does not change merely
# because the tree was rebuilt.
clf = DecisionTreeClassifier(random_state=42)

# Search space: one entry per hyperparameter, each described by the range or
# set of choices the search is allowed to draw from.
params_grid = {
    "min_weight_fraction_leaf": Continuous(0, 0.5),
    "criterion": Categorical(["gini", "entropy"]),
    "max_depth": Integer(2, 20),
    "max_leaf_nodes": Integer(2, 30),
}
创建交叉验证策略，并（可选）添加回调函数
[4]:
# Shuffled 3-fold stratified CV. The seed is pinned so that every call to
# split() yields the same folds: all candidates are then scored on identical
# data and the run can be reproduced.
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

# Optional stopping criteria, passed to fit() as a list.
delta_callback = DeltaThreshold(threshold=0.001, metric="fitness")
timer_callback = TimerStopping(total_seconds=60)
callbacks = [delta_callback, timer_callback]
定义GASearchCV选项
[5]:
# Configure the genetic hyperparameter search around the base classifier.
evolved_estimator = GASearchCV(
    # What is tuned, over which space, and how candidates are scored.
    estimator=clf,
    param_grid=params_grid,
    cv=cv,
    scoring="accuracy",
    # Evolutionary settings.
    algorithm="eaSimple",
    population_size=16,
    generations=30,
    crossover_probability=0.9,
    mutation_probability=0.05,
    # Execution settings.
    n_jobs=-1,
    verbose=True,
)
拟合模型并查看一些结果
[6]:
# Run the search on the training split; the callbacks may end it early.
evolved_estimator.fit(X_train, y_train, callbacks=callbacks)

# Predict on the held-out split and score those predictions.
y_predict_ga = evolved_estimator.predict(X_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_predict_ga)
gen nevals fitness fitness_std fitness_max fitness_min
0 16 0.363259 0.136399 0.639235 0.189526
1 14 0.450592 0.119266 0.620116 0.27847
2 12 0.54707 0.1376 0.75478 0.26517
3 12 0.625052 0.113433 0.768911 0.346633
4 16 0.667654 0.11493 0.755611 0.400665
5 14 0.727504 0.0156019 0.759767 0.689111
6 16 0.71462 0.0486477 0.758105 0.607648
7 14 0.701164 0.132646 0.764755 0.190357
8 12 0.735661 0.0115332 0.758936 0.715711
9 16 0.735141 0.00947264 0.748961 0.704073
INFO: DeltaThreshold callback met its criteria
INFO: Stopping the algorithm
[7]:
# Show the selected hyperparameters, then the held-out accuracy to two decimals.
print(evolved_estimator.best_params_)
print(f"accuracy score:  {accuracy:.2f}")
{'min_weight_fraction_leaf': 0.027793264515431237, 'criterion': 'entropy', 'max_depth': 17, 'max_leaf_nodes': 26}
accuracy score: 0.77