身份基#

class pytorch_forecasting.models.nhits.sub_modules.IdentityBasis(backcast_size: int, forecast_size: int, interpolation_mode: str)[来源]#

基础: Module

初始化内部模块状态,由 nn.Module 和 ScriptModule 共享。

方法

forward(backcast_theta, forecast_theta, ...)

定义每次调用时执行的计算。

forward(backcast_theta: Tensor, forecast_theta: Tensor, encoder_x_t: Tensor, decoder_x_t: Tensor) Tuple[Tensor, Tensor][来源]#

定义每次调用时执行的计算。

应该被所有子类重写。

注意

尽管前向传播的配方需要在此函数中定义,但之后应该调用 Module 实例而不是这个,因为前者负责运行注册的钩子,而后者则默默忽略它们。