torch_frame.nn.conv.TabTransformerConv

class TabTransformerConv(channels: int, num_heads: int, attn_dropout: float = 0.0, ffn_dropout: float = 0.0)[source]

基础类:TableConv

TabTransformer层在“TabTransformer: 使用上下文嵌入的表格数据建模”论文中引入。

Parameters:
  • channels (int) – 输入/输出通道的维度

  • num_heads (int) – 注意力头的数量

  • attn_dropout (float) – 注意力模块的dropout(默认值:0.

  • ffn_dropout (float) – 注意力模块的dropout(默认值:0.

forward(x: Tensor) Tensor[source]

将列方向的3维张量处理为另一个列方向的3维张量。

Parameters:
  • x (torch.Tensor) – 输入列方向的张量,形状为 [batch_size, num_cols, hidden_channels]

  • args (Any) – 额外参数。

  • kwargs (Any) – 额外的关键字参数。

reset_parameters()[source]

重置模块的所有可学习参数。