自定义回调函数
sklearn-genetic-opt 提供了一些预定义的回调函数, 但您也可以通过定义具有特定方法的可调用对象来自定义回调。
参数
回调函数必须是一个继承自BaseCallback类的类,并实现以下方法:
on_start: 在开始第0代之前会评估此函数,它应返回 None 或 False。该函数需要参数 estimator。
on_step: 该回调函数在每一代结束时被调用,其返回值必须为布尔类型,
True 表示优化过程应当终止,False 表示优化可以继续。
该函数接收 record、logbook 和 estimator 三个参数。
on_end: 该方法在最后一代结束时或当停止回调满足其条件后被调用。它期望接收参数 logbook 和 estimator,应该返回 None 或 False。
所有这些方法都是可选的,但至少需要定义一个。
示例
在这个示例中,我们将定义一个虚拟回调函数,当适应度值低于阈值的情况超过N次时终止优化过程。
回调函数必须包含三个参数:record、logbook 和 estimator。
它们分别对应一个字典、一个deap的Logbook对象和当前的GASearchCV(或GAFeatureSelectionCV),
其中包含当前迭代的指标、所有历史迭代的指标以及保存在估计器中的所有属性。
因此,要检查日志簿内部,我们可以定义一个这样的函数:
N=4
metric='fitness'
threshold=0.8
def on_step(record, logbook, threshold, estimator=None):
# Not enough data points
if len(logbook) <= N:
return False
# Get the last N metrics
stats = logbook.select(metric)[(-N - 1):]
n_met_condition = [x for x in stats if x < threshold]
if len(n_met_condition) > N:
return True
return False
由于sklearn-genetic-opt要求将所有逻辑封装在单个对象中,我们必须定义一个包含所有这些参数的类,因此可以改写如下:
from sklearn_genetic.callbacks.base import BaseCallback
class DummyThreshold(BaseCallback):
def __init__(self, threshold, N, metric='fitness'):
self.threshold = threshold
self.N = N
self.metric = metric
def on_step(self, record, logbook, estimator=None):
# Not enough data points
if len(logbook) <= self.N:
return False
# Get the last N metrics
stats = logbook.select(self.metric)[(-self.N - 1):]
n_met_condition = [x for x in stats if x < self.threshold]
if len(n_met_condition) > self.N:
return True
return False
现在,让我们扩展它来添加其他方法,仅用于打印一条消息:
from sklearn_genetic.callbacks.base import BaseCallback
class DummyThreshold(BaseCallback):
def __init__(self, threshold, N, metric='fitness'):
self.threshold = threshold
self.N = N
self.metric = metric
def on_start(self, estimator=None):
print("This training is starting!")
def on_step(self, record, logbook, estimator=None):
# Not enough data points
if len(logbook) <= self.N:
return False
# Get the last N metrics
stats = logbook.select(self.metric)[(-self.N - 1):]
n_met_condition = [x for x in stats if x < self.threshold]
if len(n_met_condition) > self.N:
return True
return False
def on_end(self, logbook=None, estimator=None):
print("I'm done with training!")
就是这样,现在你可以初始化 DummyThreshold 并将其传递给 GASearchCV 实例的 fit 方法:
callback = DummyThreshold(threshold=0.85, N=4, metric='fitness')
evolved_estimator.fit(X, y, callbacks=callback)
以下是该回调函数的输出示例:
请注意这里有一条额外的INFO信息,这是所有会停止训练的回调函数的通用提示。