NHiTSBlock#

class pytorch_forecasting.models.nhits.sub_modules.NHiTSBlock(context_length: int, prediction_length: int, output_size: int, encoder_covariate_size: int, decoder_covariate_size: int, static_size: int, static_hidden_size: int, n_theta: int, hidden_size: List[int], pooling_sizes: int, pooling_mode: str, basis: Module, n_layers: int, batch_normalization: bool, dropout: float, activation: str)[来源]#

基础: Module

N-HiTS块,它将基函数作为参数。

初始化内部模块状态,由 nn.Module 和 ScriptModule 共享。

方法

forward(encoder_y, encoder_x_t, decoder_x_t, x_s)

定义每次调用时执行的计算。

forward(encoder_y: Tensor, encoder_x_t: Tensor, decoder_x_t: Tensor, x_s: Tensor) Tuple[Tensor, Tensor][来源]#

定义每次调用时执行的计算。

应该被所有子类重写。

注意

尽管前向传播的配方需要在此函数中定义,但之后应该调用 Module 实例而不是这个,因为前者负责运行注册的钩子,而后者则默默忽略它们。