变量选择网络#
- class pytorch_forecasting.models.temporal_fusion_transformer.sub_modules.VariableSelectionNetwork(input_sizes: Dict[str, int], hidden_size: int, input_embedding_flags: Dict[str, bool] = None, dropout: float = 0.1, context_size: int = None, single_variable_grns: Dict[str, 门控残差网络] = None, prescalers: Dict[str, Linear] = None)[来源]#
基础:
Module
计算每个大小为
input_size
的num_inputs
个变量的权重方法
forward
(x[, context])定义每次调用时执行的计算。