训练回调

class TrainingCallback[source]

基础类:object

训练回调的接口。

初始化回调。

属性摘要

loss

通过训练循环访问的损失。

model

通过训练循环访问的模型。

optimizer

优化器,通过训练循环访问。

result_tracker

结果跟踪器,通过训练循环访问。

training_loop

训练循环。

方法总结

on_batch(epoch, batch, batch_loss, **kwargs)

调用训练批次。

post_batch(epoch, batch, **kwargs)

调用训练批次。

post_epoch(epoch, epoch_loss, **kwargs)

在epoch之后调用。

post_train(losses, **kwargs)

训练后调用。

pre_batch(**kwargs)

在训练批次之前调用。

pre_step(**kwargs)

在优化器的步骤之前调用。

register_training_loop(training_loop)

注册训练循环。

属性文档

loss

通过训练循环访问的损失。

model

通过训练循环访问的模型。

optimizer

优化器,通过训练循环访问。

result_tracker

结果跟踪器,通过训练循环访问。

training_loop

训练循环。

方法文档

on_batch(epoch: int, batch, batch_loss: float, **kwargs: Any) None[源代码]

调用训练批次。

Parameters:
Return type:

post_batch(epoch: int, batch, **kwargs: Any) None[源代码]

调用训练批次。

Parameters:
Return type:

post_epoch(epoch: int, epoch_loss: float, **kwargs: Any) None[source]

在epoch之后调用。

Parameters:
Return type:

post_train(losses: list[float], **kwargs: Any) None[来源]

训练后调用。

Parameters:
Return type:

pre_batch(**kwargs: Any) None[source]

在训练批次之前调用。

Parameters:

kwargs (Any)

Return type:

pre_step(**kwargs: Any) None[源代码]

在优化器的步骤之前调用。

Parameters:

kwargs (Any)

Return type:

register_training_loop(training_loop: TrainingLoop) None[source]

注册训练循环。

Parameters:

training_loop (TrainingLoop)

Return type: