NHiTSBlock
- class pytorch_forecasting.models.nhits.sub_modules.NHiTSBlock(context_length: int, prediction_length: int, output_size: int, encoder_covariate_size: int, decoder_covariate_size: int, static_size: int, static_hidden_size: int, n_theta: int, hidden_size: List[int], pooling_sizes: int, pooling_mode: str, basis: Module, n_layers: int, batch_normalization: bool, dropout: float, activation: str)
Bases: Module
N-HiTS block that takes a basis function as an argument.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
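The docstring only states that the block takes a basis function as an argument. The pure-Python sketch below illustrates that idea under stated assumptions and is not the library's implementation. Following the N-HiTS design, it max-pools the encoder target (the role of `pooling_sizes`), maps the pooled values to a backcast coefficient vector and a short vector of forecast knots (loosely the role of `n_theta`), and hands both to an injected basis, which expands them into a backcast and a forecast. `ToyNHiTSBlock`, `identity_basis`, and `toy_theta_fn` are hypothetical names. `toy_theta_fn` stands in for the real MLP (`hidden_size`, `n_layers`, `activation`, `dropout`), and covariates and static inputs are ignored.

```python
from typing import Callable, List, Sequence, Tuple

# A "basis" maps (backcast_theta, forecast_theta) to (backcast, forecast).
Basis = Callable[[List[float], List[float]], Tuple[List[float], List[float]]]
# Hypothetical stand-in for the block's MLP: pooled input -> (backcast_theta, forecast_theta).
ThetaFn = Callable[[List[float]], Tuple[List[float], List[float]]]


def max_pool(x: Sequence[float], kernel: int) -> List[float]:
    """Non-overlapping max pooling; a shorter final window is kept."""
    return [max(x[i:i + kernel]) for i in range(0, len(x), kernel)]


def interpolate_linear(knots: Sequence[float], size: int) -> List[float]:
    """Linearly resample `knots` onto `size` evenly spaced points (endpoints aligned)."""
    if len(knots) == 1 or size == 1:
        return [float(knots[0])] * size
    scale = (len(knots) - 1) / (size - 1)
    out = []
    for i in range(size):
        pos = i * scale
        lo = min(int(pos), len(knots) - 2)  # index of the left knot
        frac = pos - lo
        out.append(knots[lo] * (1 - frac) + knots[lo + 1] * frac)
    return out


def identity_basis(prediction_length: int) -> Basis:
    """Hypothetical identity-style basis: backcast theta is used as-is,
    forecast knots are interpolated up to the full horizon."""
    def basis(backcast_theta: List[float], forecast_theta: List[float]):
        return list(backcast_theta), interpolate_linear(forecast_theta, prediction_length)
    return basis


class ToyNHiTSBlock:
    """Pooling -> theta -> injected basis (hypothetical sketch)."""

    def __init__(self, context_length: int, pooling_size: int,
                 theta_fn: ThetaFn, basis: Basis):
        self.context_length = context_length
        self.pooling_size = pooling_size
        self.theta_fn = theta_fn
        self.basis = basis

    def forward(self, encoder_y: Sequence[float]) -> Tuple[List[float], List[float]]:
        if len(encoder_y) != self.context_length:
            raise ValueError("encoder_y must contain context_length values")
        pooled = max_pool(encoder_y, self.pooling_size)  # coarser view of the input
        backcast_theta, forecast_theta = self.theta_fn(pooled)
        return self.basis(backcast_theta, forecast_theta)


def toy_theta_fn(pooled: List[float]) -> Tuple[List[float], List[float]]:
    # Hypothetical: repeat each pooled value twice (pool size 2) to rebuild the
    # context length, and use the last two pooled values as forecast knots.
    backcast_theta = [float(v) for v in pooled for _ in range(2)]
    forecast_theta = [float(v) for v in pooled[-2:]]
    return backcast_theta, forecast_theta


block = ToyNHiTSBlock(context_length=8, pooling_size=2,
                      theta_fn=toy_theta_fn, basis=identity_basis(prediction_length=3))
backcast, forecast = block.forward([1, 2, 3, 4, 5, 6, 7, 8])
print(backcast)  # [2.0, 2.0, 4.0, 4.0, 6.0, 6.0, 8.0, 8.0]
print(forecast)  # [6.0, 7.0, 8.0]
```

Passing the basis in as a parameter keeps the block generic: swapping `identity_basis` for a different expansion changes how the coefficients become a forecast without touching the pooling or the coefficient network.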
Methods
forward(encoder_y, encoder_x_t, decoder_x_t, x_s)
Defines the computation performed at every call.
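This page does not describe how blocks are combined. In the N-HiTS architecture they are chained in a doubly residual fashion: each block receives the residual left by the previous blocks, its backcast is subtracted from that residual, and its forecast is added to a running total. A minimal pure-Python sketch, assuming each block's forward yields a `(backcast, forecast)` pair; `stack_forward`, `level_block`, and `last_value_block` are hypothetical illustrations, while the real model works on batched tensors and also passes covariates (`encoder_x_t`, `decoder_x_t`) and static inputs (`x_s`).

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical block type: residual input -> (backcast, forecast).
Block = Callable[[List[float]], Tuple[List[float], List[float]]]


def stack_forward(blocks: Sequence[Block], encoder_y: Sequence[float],
                  prediction_length: int) -> List[float]:
    residual = [float(v) for v in encoder_y]
    forecast = [0.0] * prediction_length
    for block in blocks:
        backcast, block_forecast = block(residual)
        # Remove what this block explained before passing the input on ...
        residual = [r - b for r, b in zip(residual, backcast)]
        # ... and add its contribution to the final forecast.
        forecast = [f + bf for f, bf in zip(forecast, block_forecast)]
    return forecast


def level_block(residual: List[float]) -> Tuple[List[float], List[float]]:
    # Toy block: explains the mean level and forecasts it flat for two steps.
    mean = sum(residual) / len(residual)
    return [mean] * len(residual), [mean, mean]


def last_value_block(residual: List[float]) -> Tuple[List[float], List[float]]:
    # Toy block: explains nothing in the backcast, repeats the last residual for two steps.
    return [0.0] * len(residual), [residual[-1], residual[-1]]


print(stack_forward([level_block, last_value_block], [1, 2, 3, 4], prediction_length=2))
# [4.0, 4.0]
```

Here the first block captures the level 2.5, leaving the residual [-1.5, -0.5, 0.5, 1.5]; the second block forecasts the last residual value 1.5, so the stacked forecast is 2.5 + 1.5 = 4.0 per step.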