torch_frame.nn.decoder.TromptDecoder

class TromptDecoder(in_channels: int, out_channels: int, num_prompts: int)[source]

基础类:Decoder

“Trompt: Towards a Better Deep Neural Network for Tabular Data”论文中介绍的Trompt下游。

Parameters:
  • in_channels (int) – 输入通道维度

  • out_channels (int) – 输出通道的维度

  • num_prompts (int) – 提示列的数量。

reset_parameters() None[source]

重置模块的所有可学习参数。

forward(x: Tensor) Tensor[source]

将形状为 [batch_size, num_cols, channels]x 解码为形状为 [batch_size, out_channels] 的输出张量。

Parameters:
  • x (torch.Tensor) – 输入列方向的张量,形状为 [batch_size, num_cols, hidden_channels]

  • args (Any) – 额外参数。

  • kwargs (Any) – 额外的关键字参数。