NHiTS#

class pytorch_forecasting.models.nhits.sub_modules.NHiTS(context_length, prediction_length, output_size: int, static_size, encoder_covariate_size, decoder_covariate_size, static_hidden_size, n_blocks: list, n_layers: list, hidden_size: list, pooling_sizes: list, downsample_frequencies: list, pooling_mode, interpolation_mode, dropout, activation, initialization, batch_normalization, shared_weights, naive_level: bool)[来源]#

基础: Module

N-HiTS模型。

初始化内部模块状态,由 nn.Module 和 ScriptModule 共享。

方法

create_stack(n_blocks, context_length, ...)

forward(encoder_y, encoder_mask, ...)

定义每次调用时执行的计算。

forward(encoder_y, encoder_mask, encoder_x_t, decoder_x_t, x_s)[来源]#

定义每次调用时执行的计算。

应该被所有子类重写。

注意

尽管前向传播的配方需要在此函数中定义,但之后应该调用 Module 实例而不是这个,因为前者负责运行注册的钩子,而后者则默默忽略它们。