NHiTS
- class pytorch_forecasting.models.nhits.sub_modules.NHiTS(context_length, prediction_length, output_size: int, static_size, encoder_covariate_size, decoder_covariate_size, static_hidden_size, n_blocks: list, n_layers: list, hidden_size: list, pooling_sizes: list, downsample_frequencies: list, pooling_mode, interpolation_mode, dropout, activation, initialization, batch_normalization, shared_weights, naive_level: bool)
Bases: Module
N-HiTS model.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
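N-HiTS lets each block look at the encoder window at a different resolution by pooling it first, which is what the `pooling_sizes` and `pooling_mode` arguments control. The sketch below illustrates that idea on a plain Python list. It is not the library's code: `pool_1d` and `pooled_input_lengths` are hypothetical helpers, we assume `pooling_mode` takes `"max"` or `"average"`, and we assume the pooling is non-overlapping (stride equal to kernel) with a trailing partial window kept, as in ceil-mode pooling. This sketch averages only over the values actually present in that last window.

```python
import math
from typing import List


def pool_1d(series: List[float], kernel: int, mode: str = "max") -> List[float]:
    """Non-overlapping 1-D pooling (stride == kernel) over one series.

    A trailing partial window is kept and pooled over the values it contains,
    approximating ceil-mode pooling (an assumption of this sketch).
    """
    if kernel < 1:
        raise ValueError("kernel must be >= 1")
    if mode not in ("max", "average"):
        raise ValueError("mode must be 'max' or 'average'")
    pooled = []
    for start in range(0, len(series), kernel):
        window = series[start:start + kernel]
        pooled.append(max(window) if mode == "max" else sum(window) / len(window))
    return pooled


def pooled_input_lengths(context_length: int, pooling_sizes: List[int]) -> List[int]:
    """Length of the pooled encoder window each block would see (hypothetical helper)."""
    return [math.ceil(context_length / k) for k in pooling_sizes]


if __name__ == "__main__":
    history = [1.0, 3.0, 2.0, 5.0, 4.0, 6.0, 0.0, 7.0]
    print(pool_1d(history, 2))             # [3.0, 5.0, 6.0, 7.0]
    print(pool_1d(history, 3))             # [3.0, 6.0, 7.0]
    print(pool_1d(history, 2, "average"))  # [2.0, 3.5, 5.0, 3.5]
    print(pooled_input_lengths(8, [4, 2, 1]))  # [2, 4, 8]
```

Larger pooling sizes give a block a shorter, smoother view of the history, so early blocks can focus on low-frequency structure while blocks with a pooling size of 1 see the raw window.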
Methods
create_stack(n_blocks, context_length, ...)
forward(encoder_y, encoder_mask, ...): Defines the computation performed at every call.
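The N-HiTS design combines two ideas in its forward pass: doubly residual stacking (each block subtracts its backcast from the input the next block sees, and block forecasts are summed) and hierarchical interpolation (each block emits a reduced number of forecast "knots" that are upsampled to the full horizon). The sketch below shows both on a single univariate series. It ignores covariates, static features and `encoder_mask`, and it does not reproduce the real `forward` signature, which is elided above. `linear_interpolate`, `make_mean_block` and `stack_forward` are hypothetical names; we assume the knot count is `ceil(prediction_length / downsample_frequency)` and that `naive_level=True` starts the forecast from the last observed value. The interpolation is endpoint-aligned, which may differ from the sampling grid the library uses.

```python
import math
from typing import Callable, List, Tuple

# A block maps the current residual to (backcast, forecast knots).
Block = Callable[[List[float]], Tuple[List[float], List[float]]]


def linear_interpolate(knots: List[float], size: int) -> List[float]:
    """Endpoint-aligned linear upsampling of `knots` to `size` points."""
    if size < 1:
        raise ValueError("size must be >= 1")
    if len(knots) == 1 or size == 1:
        return [knots[0]] * size
    scale = (len(knots) - 1) / (size - 1)
    out = []
    for i in range(size):
        pos = i * scale
        left = min(int(pos), len(knots) - 2)
        frac = pos - left
        out.append(knots[left] + frac * (knots[left + 1] - knots[left]))
    return out


def make_mean_block(prediction_length: int, downsample_frequency: int) -> Block:
    """Hypothetical toy block: explains the window by its mean and forecasts that mean."""
    n_knots = math.ceil(prediction_length / downsample_frequency)

    def block(residual: List[float]) -> Tuple[List[float], List[float]]:
        m = sum(residual) / len(residual)
        return [m] * len(residual), [m] * n_knots

    return block


def stack_forward(
    encoder_y: List[float],
    blocks: List[Block],
    prediction_length: int,
    naive_level: bool = True,
) -> Tuple[List[float], List[float]]:
    """Doubly residual stacking for one series; returns (forecast, final residual)."""
    residual = list(encoder_y)
    # Assumed meaning of naive_level: start from the last observed value.
    forecast = [encoder_y[-1] if naive_level else 0.0] * prediction_length
    for block in blocks:
        backcast, knots = block(residual)
        # The next block only sees what this block failed to explain.
        residual = [r - b for r, b in zip(residual, backcast)]
        # Coarse knots are upsampled to the full horizon and accumulated.
        coarse = linear_interpolate(knots, prediction_length)
        forecast = [f + c for f, c in zip(forecast, coarse)]
    return forecast, residual


if __name__ == "__main__":
    print(linear_interpolate([0.0, 10.0], 5))  # [0.0, 2.5, 5.0, 7.5, 10.0]
    blocks = [make_mean_block(4, 2), make_mean_block(4, 1)]
    y = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    print(stack_forward(y, blocks, 4, naive_level=False))
    # ([3.5, 3.5, 3.5, 3.5], [-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
```

With the toy mean blocks, the first block removes the level of the window, so the second block sees a zero-mean residual and adds nothing; a trained block would instead emit a learned backcast and knots.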