定义自定义回调函数
在机器学习训练过程中,许多环节都可能出错(如驱动程序版本不匹配、指标陷入平台期等),这些情况会导致GPU资源和时间被白白浪费。Aim的回调API允许定义自定义回调函数,可在训练流程任意节点执行——通过编程方式为机器学习训练保驾护航,避免资源浪费。
回调实际上可以包含任何可编程功能,例如 记录消息和发送通知,或 在满足给定条件时终止训练过程。
回调函数
条款:
callback - 一个Python函数,用于实现在训练过程中特定时间点执行的自定义逻辑。
回调类 - 用于分组回调函数的python类。可用于在不同回调之间共享状态(如下示例)。
event - 表示要绑定到训练中的事件。
回调API:
TrainingFlow - 定义训练流程/事件。
events.on.* - 用于定义回调函数执行时机的装饰器。
当前可用的训练事件列表:
events.on.training_started - 在训练开始后调用。
events.on.training_successfully_finished - 在训练成功完成后调用,这意味着没有抛出意外异常,即使是手动键盘中断(ctrl+C)。请注意,程序化的提前停止被视为成功完成。
events.on.training_metrics_collected - 在训练指标计算完成并准备记录时调用。通常每N个批次调用一次。
events.on.validation_metrics_collected - 在验证指标计算完成并准备记录时调用。通常在验证循环结束后仅调用一次。
events.on.init - 在回调类初始化后自动调用,且在所有其他事件之前触发。禁止手动调用。典型用例包括为回调函数初始化共享状态(如下示例)。
示例
以下示例演示了如何实现自定义回调来检查和通知,当:
安装了错误的驱动程序版本。
gnorm指标爆炸性增长。
模型开始过拟合。
定义回调函数
from aim.sdk.callbacks import events
class MyCallbacks:
@events.on.init # Called when initializing the TrainingFlow object
def init_gnorm_accumulators(self, **kwargs):
# Initialize a state to collect gnorm values over training process
self.gnorm_sum = 0
self.gnorm_len = 0
@events.on.init
def init_ppl_accumulators(self, **kwargs):
# Initialize a state to collect ppl values over training process
self.ppl_sum = 0
self.ppl_len = 0
@events.on.init
def init_metrics_accumulators(self, **kwargs):
# Collect only the last 100 appended values
self.last_train_metrics = deque(maxlen=100)
# NOTE: all the above methods can be merged into one,
# but are separated for readability reasons
@events.on.training_started
def check_cuda_version(self, run: aim.Run, **kwargs):
if run['__system_params', 'cuda_version'] != '11.6':
run.log_warning("Wrong CUDA version is installed!")
@events.on.training_metrics_collected
def check_gnorm_and_notify(
self,
metrics: Dict[str, Any],
step: int,
# always denotes the number of *training* steps
# `1 step per 4 batches` can be in case of gradient accumulation
epoch: int,
run: aim.Run,
**kwargs
):
current = metrics['gnorm'] # notice that it's the last one
# thus we need to use self.* to collect gnorm values
self.gnorm_sum += current
self.gnorm_len += 1
mean = self.gnorm_sum / self.gnorm_len
if current > 1.15 * mean:
run.log_warning(f'gnorms have exploded. mean: {mean}, '
'step {step}, epoch {epoch} ...')
@events.on.training_metrics_collected
def check_ppl_and_notify(
self,
metrics: Dict[str, Any],
step: int,
epoch: int,
run: aim.Run,
**kwargs
):
current = metrics['ppl'] # notice that it's the last one
# thus we need to use self.* to collect ppl values
self.ppl_sum += current
self.ppl_len += 1
mean = self.ppl_sum / self.ppl_len
if current > 1.15 * mean:
run.log_warning(f'ppl have exploded. mean: {mean}, '
'step {step}, epoch {epoch} ...')
@events.on.training_metrics_collected
def store_last_train_metrics(
self,
metrics: Dict[str, Any],
step: int,
epoch: int,
**kwargs,
):
self.last_train_metrics.append(metrics)
@events.on.validation_metrics_collected
def check_overfitting(
self,
metrics: Dict[str, Any],
epoch: int = None,
run: aim.Run,
**kwargs,
):
mean_train_ppl = sum(
metrics['ppl'] for metrics
in self.last_train_metrics
) / len(self.last_train_metrics)
if mean_train_ppl > 1.15 * metrics['ppl']:
run.log_warning(f'I think we are overfitting on epoch={epoch}')
注册回调函数
from aim import TrainingFlow, Run
aim_run = Run()
training_flow = TrainingFlow(run=aim_run, callbacks=[MyCallbacks()])
# or
training_flow = TrainingFlow(run=aim_run)
training_flow.register(MyCallbacks())