变量选择网络#

class pytorch_forecasting.models.temporal_fusion_transformer.sub_modules.VariableSelectionNetwork(input_sizes: Dict[str, int], hidden_size: int, input_embedding_flags: Dict[str, bool] = None, dropout: float = 0.1, context_size: int = None, single_variable_grns: Dict[str, 门控残差网络] = None, prescalers: Dict[str, Linear] = None)[来源]#

基础: Module

计算每个大小为 input_sizenum_inputs 个变量的权重

方法

forward(x[, context])

定义每次调用时执行的计算。

forward(x: Dict[str, Tensor], context: Tensor = None)[来源]#

定义每次调用时执行的计算。

应该被所有子类重写。

注意

尽管前向传播的配方需要在此函数中定义,但之后应该调用 Module 实例而不是这个,因为前者负责运行注册的钩子,而后者则默默忽略它们。