自定义回调函数

sklearn-genetic-opt 提供了一些预定义的回调函数, 但您也可以通过定义具有特定方法的可调用对象来自定义回调。

参数

回调函数必须是一个继承自BaseCallback类的类,并实现以下方法:

on_start: 在开始第0代之前会评估此函数,它应返回 NoneFalse。该函数需要参数 estimator

on_step: 该回调函数在每一代结束时被调用,其返回值必须为布尔类型, True 表示优化过程应当终止,False 表示优化可以继续。 该函数接收 recordlogbookestimator 三个参数。

on_end: 该方法在最后一代结束时或当停止回调满足其条件后被调用。它期望接收参数 logbookestimator,应该返回 NoneFalse

所有这些方法都是可选的,但至少需要定义一个。

示例

在这个示例中,我们将定义一个虚拟回调函数,当适应度值低于阈值的情况超过N次时终止优化过程。

回调函数必须包含三个参数:recordlogbookestimator。 它们分别对应一个字典、一个deap的Logbook对象和当前的GASearchCV(或GAFeatureSelectionCV), 其中包含当前迭代的指标、所有历史迭代的指标以及保存在估计器中的所有属性。

因此,要检查日志簿内部,我们可以定义一个这样的函数:

N=4
metric='fitness'
threshold=0.8

def on_step(record, logbook, threshold, estimator=None):
    # Not enough data points
    if len(logbook) <= N:
        return False
    # Get the last N metrics
    stats = logbook.select(metric)[(-N - 1):]

    n_met_condition = [x for x in stats if x < threshold]

    if len(n_met_condition) > N:
        return True

    return False

由于sklearn-genetic-opt要求将所有逻辑封装在单个对象中,我们必须定义一个包含所有这些参数的类,因此可以改写如下:

from sklearn_genetic.callbacks.base import BaseCallback

class DummyThreshold(BaseCallback):
    def __init__(self, threshold, N, metric='fitness'):
        self.threshold = threshold
        self.N = N
        self.metric = metric

    def on_step(self, record, logbook, estimator=None):
        # Not enough data points
        if len(logbook) <= self.N:
            return False
        # Get the last N metrics
        stats = logbook.select(self.metric)[(-self.N - 1):]

        n_met_condition = [x for x in stats if x < self.threshold]

        if len(n_met_condition) > self.N:
            return True

        return False

现在,让我们扩展它来添加其他方法,仅用于打印一条消息:

from sklearn_genetic.callbacks.base import BaseCallback

class DummyThreshold(BaseCallback):
    def __init__(self, threshold, N, metric='fitness'):
        self.threshold = threshold
        self.N = N
        self.metric = metric

    def on_start(self, estimator=None):
        print("This training is starting!")

    def on_step(self, record, logbook, estimator=None):
        # Not enough data points
        if len(logbook) <= self.N:
            return False
        # Get the last N metrics
        stats = logbook.select(self.metric)[(-self.N - 1):]

        n_met_condition = [x for x in stats if x < self.threshold]

        if len(n_met_condition) > self.N:
            return True

        return False

    def on_end(self, logbook=None, estimator=None):
        print("I'm done with training!")

就是这样,现在你可以初始化 DummyThreshold 并将其传递给 GASearchCV 实例的 fit 方法:

callback = DummyThreshold(threshold=0.85, N=4, metric='fitness')
evolved_estimator.fit(X, y, callbacks=callback)

以下是该回调函数的输出示例:

../_images/custom_callback_dummy_0.png

请注意这里有一条额外的INFO信息,这是所有会停止训练的回调函数的通用提示。