torch_frame.nn.conv.TromptConv

class TromptConv(channels: int, num_cols: int, num_prompts: int, num_groups: int = 2)[source]

基础类:TableConv

“Trompt: Towards a Better Deep Neural Network for Tabular Data”论文中介绍的Trompt单元。

Parameters:
  • channels (int) – 输入/输出通道的维度

  • num_cols (int) – 列数

  • num_prompts (int) – 提示列的数量。

  • num_groups (int) – 组归一化中的组数。(默认值:2

reset_parameters()[来源]

重置模块的所有可学习参数。

forward(x: Tensor, x_prompt: Tensor) Tensor[source]

xx_prompt 转换为下一层的 x_prompt

Parameters:
  • x (torch.Tensor) – 基于特征的嵌入,形状为 [batch_size, num_cols, channels]

  • x_prompt (torch.Tensor) – 输入提示嵌入的形状为 [batch_size, num_prompts, channels]

Returns:

输出下一层的提示嵌入。其

形状为 [batch_size, num_prompts, channels]

Return type:

torch.Tensor