torch_frame.nn.decoder.Decoder

class Decoder(*args, **kwargs)[来源]

基础类:Module, ABC

解码器的基类,用于将输入的列式PyTorch张量转换为应用预测头的输出张量。

abstract forward(x: Tensor, *args: Any, **kwargs: Any) Any[来源]

将形状为 [batch_size, num_cols, channels]x 解码为形状为 [batch_size, out_channels] 的输出张量。

Parameters:
  • x (torch.Tensor) – 输入列方向的张量,形状为 [batch_size, num_cols, hidden_channels]

  • args (Any) – 额外参数。

  • kwargs (Any) – 额外的关键字参数。

reset_parameters() None[来源]

重置模块的所有可学习参数。