torch_frame.nn.conv.FTTransformerConvs

class FTTransformerConvs(channels: int, feedforward_channels: int | None = None, num_layers: int = 3, nhead: int = 8, dropout: float = 0.2, activation: str = 'relu')[source]

基础类:TableConv

FT-Transformer 骨干网络在 “Revisiting Deep Learning Models for Tabular Data” 论文中。

该模块将一个可学习的CLS标记嵌入x_cls连接到输入张量x上,并在连接后的张量上应用多层Transformer。在Transformer层之后,输出张量被分为两部分:(1) x,对应于原始输入张量,和(2) x_cls,对应于CLS标记张量。

Parameters:
  • channels (int) – 输入/输出通道的维度

  • feedforward_channels (int, optional) – Transformer模型的前馈网络使用的隐藏通道数。如果为None,则将其设置为channels(默认值:None

  • num_layers (int) – 变压器编码器层的数量。(默认值:3)

  • nhead (int) – 多头注意力机制中的头数(默认值:8)

  • dropout (int) – 丢弃率值(默认:0.1)

  • activation (str) – 激活函数 (默认: relu)

reset_parameters()[source]

重置模块的所有可学习参数。

forward(x: Tensor) tuple[torch.Tensor, torch.Tensor][source]

CLS-token增强的Transformer卷积。

Parameters:

x (Tensor) – 输入张量,形状为 [batch_size, num_cols, channels]

Returns:

(输出形状为[batch_size, num_cols, channels]的张量,对应于输入列,输出形状为[batch_size, channels]的张量,对应于添加的CLS标记列。)

Return type:

(torch.Tensor, torch.Tensor)