预测回调#

class pytorch_forecasting.models.base_model.PredictCallback(mode: str | Tuple[str, str] = 'prediction', return_index: bool = False, return_decoder_lengths: bool = False, return_y: bool = False, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'batch', return_x: bool = False, mode_kwargs: Dict[str, Any] = None, output_dir: str | None = None, predict_kwargs: Dict[str, Any] = None)[来源]#

基类: BasePredictionWriter

内部使用的回调,用于捕获预测并可选择性地将其写入磁盘。

Inherited-members:

方法

load_state_dict(state_dict)

在加载检查点时被调用,实现以重新加载回调状态,给定回调的 state_dict

on_after_backward(trainer, pl_module)

loss.backward() 之后调用,并在优化器更新之前。

on_before_backward(培训者, pl_module, 损失)

loss.backward() 之前调用。

on_before_optimizer_step(trainer, pl_module, ...)

optimizer.step()之前调用。

on_before_zero_grad(trainer, pl_module, ...)

optimizer.zero_grad() 之前调用。

on_exception(训练器, pl_module, 异常)

当任何训练器执行因异常中断时调用。

on_fit_end(trainer, pl_module)

当拟合结束时调用。

on_fit_start(训练器, pl_module)

当拟合开始时调用。

on_load_checkpoint(训练器, pl_module, ...)

在加载模型检查点时调用,用于重新加载状态。

on_predict_batch_end(trainer, pl_module, ...)

预测批次结束时调用。

on_predict_batch_start(训练器, pl_module, ...)

当预测批次开始时调用。

on_predict_end(trainer, pl_module)

预测结束时调用。

on_predict_epoch_end(trainer, pl_module)

当预测周期结束时调用。

on_predict_epoch_start(trainer, pl_module)

在预测周期开始时调用。

on_predict_start(trainer, pl_module)

当预测开始时调用。

on_sanity_check_end(trainer, pl_module)

在验证完 sanity 检查结束时被调用。

on_sanity_check_start(trainer, pl_module)

在验证的完整性检查开始时被调用。

on_save_checkpoint(trainer, pl_module, ...)

在保存检查点时调用,以让您有机会存储任何其他您可能想要保存的内容。

on_test_batch_end(trainer, pl_module, ...[, ...])

在测试批次结束时调用。

on_test_batch_start(trainer, pl_module, ...)

当测试批次开始时调用。

on_test_end(训练器, pl_module)

在测试结束时被调用。

on_test_epoch_end(trainer, pl_module)

在测试周期结束时被调用。

on_test_epoch_start(trainer, pl_module)

在测试纪元开始时被调用。

on_test_start(trainer, pl_module)

在测试开始时调用。

on_train_batch_end(训练器, pl_module, ...)

在训练批次结束时调用。

on_train_batch_start(trainer, pl_module, ...)

在训练批次开始时调用。

on_train_end(trainer, pl_module)

训练结束时调用。

on_train_epoch_end(trainer, pl_module)

当训练周期结束时调用。

on_train_epoch_start(trainer, pl_module)

当训练周期开始时调用。

on_train_start(trainer, pl_module)

当训练开始时被调用。

on_validation_batch_end(训练器, pl_module, ...)

当验证批次结束时调用。

on_validation_batch_start(trainer, ...[, ...])

当验证批处理开始时调用。

on_validation_end(训练器, 模型)

在验证周期结束时调用。

on_validation_epoch_end(trainer, pl_module)

在验证周期结束时调用。

on_validation_epoch_start(trainer, pl_module)

在验证周期开始时调用。

on_validation_start(训练者, pl_module)

当验证循环开始时调用。

setup(trainer, pl_module, stage)

在开始拟合、验证、测试、预测或调优时被调用。

state_dict()

在保存检查点时调用,实现以生成回调的 state_dict

teardown(trainer, pl_module, stage)

在 fit、validate、test、predict 或 tune 结束时调用。

write_on_batch_end(trainer, pl_module, ...)

覆盖写入单个批次的逻辑。

write_on_epoch_end(trainer, pl_module, ...)

用来重写写入所有批次的逻辑。

属性

result

state_key

回调状态的标识符。

on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[来源]#

预测批次结束时调用。

on_predict_epoch_end(trainer: Trainer, pl_module: LightningModule) None[来源]#

当预测周期结束时调用。

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)[来源]#

用逻辑重写以写入单个批次。

write_on_epoch_end(trainer, pl_module, predictions, batch_indices)[来源]#

覆盖写入所有批次的逻辑。