预测回调#
- class pytorch_forecasting.models.base_model.PredictCallback(mode: str | Tuple[str, str] = 'prediction', return_index: bool = False, return_decoder_lengths: bool = False, return_y: bool = False, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'batch', return_x: bool = False, mode_kwargs: Dict[str, Any] = None, output_dir: str | None = None, predict_kwargs: Dict[str, Any] = None)[来源]#
基类:
BasePredictionWriter
内部使用的回调,用于捕获预测并可选择性地将其写入磁盘。
- Inherited-members:
方法
load_state_dict
(state_dict)在加载检查点时被调用,实现以重新加载回调状态,给定回调的
state_dict
。on_after_backward
(trainer, pl_module)在
loss.backward()
之后调用,并在优化器更新之前。on_before_backward
(培训者, pl_module, 损失)在
loss.backward()
之前调用。on_before_optimizer_step
(trainer, pl_module, ...)在
optimizer.step()
之前调用。on_before_zero_grad
(trainer, pl_module, ...)在
optimizer.zero_grad()
之前调用。on_exception
(训练器, pl_module, 异常)当任何训练器执行因异常中断时调用。
on_fit_end
(trainer, pl_module)当拟合结束时调用。
on_fit_start
(训练器, pl_module)当拟合开始时调用。
on_load_checkpoint
(训练器, pl_module, ...)在加载模型检查点时调用,用于重新加载状态。
on_predict_batch_end
(trainer, pl_module, ...)预测批次结束时调用。
on_predict_batch_start
(训练器, pl_module, ...)当预测批次开始时调用。
on_predict_end
(trainer, pl_module)预测结束时调用。
on_predict_epoch_end
(trainer, pl_module)当预测周期结束时调用。
on_predict_epoch_start
(trainer, pl_module)在预测周期开始时调用。
on_predict_start
(trainer, pl_module)当预测开始时调用。
on_sanity_check_end
(trainer, pl_module)在验证完 sanity 检查结束时被调用。
on_sanity_check_start
(trainer, pl_module)在验证的完整性检查开始时被调用。
on_save_checkpoint
(trainer, pl_module, ...)在保存检查点时调用,以让您有机会存储任何其他您可能想要保存的内容。
on_test_batch_end
(trainer, pl_module, ...[, ...])在测试批次结束时调用。
on_test_batch_start
(trainer, pl_module, ...)当测试批次开始时调用。
on_test_end
(训练器, pl_module)在测试结束时被调用。
on_test_epoch_end
(trainer, pl_module)在测试周期结束时被调用。
on_test_epoch_start
(trainer, pl_module)在测试纪元开始时被调用。
on_test_start
(trainer, pl_module)在测试开始时调用。
on_train_batch_end
(训练器, pl_module, ...)在训练批次结束时调用。
on_train_batch_start
(trainer, pl_module, ...)在训练批次开始时调用。
on_train_end
(trainer, pl_module)训练结束时调用。
on_train_epoch_end
(trainer, pl_module)当训练周期结束时调用。
on_train_epoch_start
(trainer, pl_module)当训练周期开始时调用。
on_train_start
(trainer, pl_module)当训练开始时被调用。
on_validation_batch_end
(训练器, pl_module, ...)当验证批次结束时调用。
on_validation_batch_start
(trainer, ...[, ...])当验证批处理开始时调用。
on_validation_end
(训练器, 模型)在验证周期结束时调用。
on_validation_epoch_end
(trainer, pl_module)在验证周期结束时调用。
on_validation_epoch_start
(trainer, pl_module)在验证周期开始时调用。
on_validation_start
(训练者, pl_module)当验证循环开始时调用。
setup
(trainer, pl_module, stage)在开始拟合、验证、测试、预测或调优时被调用。
state_dict
()在保存检查点时调用,实现以生成回调的
state_dict
。teardown
(trainer, pl_module, stage)在 fit、validate、test、predict 或 tune 结束时调用。
write_on_batch_end
(trainer, pl_module, ...)覆盖写入单个批次的逻辑。
write_on_epoch_end
(trainer, pl_module, ...)用来重写写入所有批次的逻辑。
属性
result
state_key
回调状态的标识符。
- on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None [来源]#
预测批次结束时调用。