定义自定义回调函数

在机器学习训练过程中,许多环节都可能出错(如驱动程序版本不匹配、指标陷入平台期等),这些情况会导致GPU资源和时间被白白浪费。Aim的回调API允许定义自定义回调函数,可在训练流程任意节点执行——通过编程方式为机器学习训练保驾护航,避免资源浪费。

回调实际上可以包含任何可编程功能,例如 记录消息和发送通知,或 在满足给定条件时终止训练过程。

回调函数

条款:

  • callback - 一个Python函数,用于实现在训练过程中特定时间点执行的自定义逻辑。

  • 回调类 - 用于分组回调函数的python类。可用于在不同回调之间共享状态(如下示例)。

  • event - 表示要绑定到训练中的事件。

回调API:

  • TrainingFlow - 定义训练流程/事件。

  • events.on.* - 用于定义回调函数执行时机的装饰器。

当前可用的训练事件列表:

  • events.on.training_started - 在训练开始后调用。

  • events.on.training_successfully_finished - 在训练成功完成后调用,这意味着没有抛出意外异常,即使是手动键盘中断(ctrl+C)。请注意,程序化的提前停止被视为成功完成。

  • events.on.training_metrics_collected - 在训练指标计算完成并准备记录时调用。通常每N个批次调用一次。

  • events.on.validation_metrics_collected - 在验证指标计算完成并准备记录时调用。通常在验证循环结束后仅调用一次。

  • events.on.init - 在回调类初始化后自动调用,且在所有其他事件之前触发。禁止手动调用。典型用例包括为回调函数初始化共享状态(如下示例)。

示例

以下示例演示了如何实现自定义回调来检查和通知,当:

  • 安装了错误的驱动程序版本。

  • gnorm指标爆炸性增长。

  • 模型开始过拟合。

定义回调函数

from aim.sdk.callbacks import events

class MyCallbacks:
    @events.on.init  # Called when initializing the TrainingFlow object
    def init_gnorm_accumulators(self, **kwargs):
        # Initialize a state to collect gnorm values over training process
        self.gnorm_sum = 0
        self.gnorm_len = 0

    @events.on.init
    def init_ppl_accumulators(self, **kwargs):
        # Initialize a state to collect ppl values over training process
        self.ppl_sum = 0
        self.ppl_len = 0

    @events.on.init
    def init_metrics_accumulators(self, **kwargs):
        # Collect only the last 100 appended values
        self.last_train_metrics = deque(maxlen=100)

    # NOTE: all the above methods can be merged into one,
    #        but are separated for readability reasons

    @events.on.training_started
    def check_cuda_version(self, run: aim.Run, **kwargs):
        if run['__system_params', 'cuda_version'] != '11.6':
            run.log_warning("Wrong CUDA version is installed!")

    @events.on.training_metrics_collected
    def check_gnorm_and_notify(
        self,
        metrics: Dict[str, Any],
        step: int,
        # always denotes the number of *training* steps
        # `1 step per 4 batches` can be in case of gradient accumulation
        epoch: int,
        run: aim.Run,
        **kwargs
    ):
        current = metrics['gnorm'] # notice that it's the last one
        # thus we need to use self.* to collect gnorm values
        self.gnorm_sum += current
        self.gnorm_len += 1
        mean = self.gnorm_sum / self.gnorm_len

        if current > 1.15 * mean:
            run.log_warning(f'gnorms have exploded. mean: {mean}, '
                             'step {step}, epoch {epoch} ...')

    @events.on.training_metrics_collected
    def check_ppl_and_notify(
        self,
        metrics: Dict[str, Any],
        step: int,
        epoch: int,
        run: aim.Run,
        **kwargs
    ):
        current = metrics['ppl'] # notice that it's the last one
        # thus we need to use self.* to collect ppl values
        self.ppl_sum += current
        self.ppl_len += 1
        mean = self.ppl_sum / self.ppl_len

        if current > 1.15 * mean:
            run.log_warning(f'ppl have exploded. mean: {mean}, '
                             'step {step}, epoch {epoch} ...')

    @events.on.training_metrics_collected
    def store_last_train_metrics(
        self,
        metrics: Dict[str, Any],
        step: int,
        epoch: int,
        **kwargs,
    ):
        self.last_train_metrics.append(metrics)

    @events.on.validation_metrics_collected
    def check_overfitting(
        self,
        metrics: Dict[str, Any],
        epoch: int = None,
        run: aim.Run,
        **kwargs,
    ):
        mean_train_ppl = sum(
            metrics['ppl'] for metrics
            in self.last_train_metrics
        ) / len(self.last_train_metrics)

        if mean_train_ppl > 1.15 * metrics['ppl']:
            run.log_warning(f'I think we are overfitting on epoch={epoch}')

注册回调函数

from aim import TrainingFlow, Run

aim_run = Run()

training_flow = TrainingFlow(run=aim_run, callbacks=[MyCallbacks()])
# or
training_flow = TrainingFlow(run=aim_run)
training_flow.register(MyCallbacks())