torch_frame.nn.conv.FTTransformerConvs
- class FTTransformerConvs(channels: int, feedforward_channels: int | None = None, num_layers: int = 3, nhead: int = 8, dropout: float = 0.2, activation: str = 'relu')[source]
基础类:
TableConvFT-Transformer 骨干网络在 “Revisiting Deep Learning Models for Tabular Data” 论文中。
该模块将一个可学习的CLS标记嵌入
x_cls连接到输入张量x上,并在连接后的张量上应用多层Transformer。在Transformer层之后,输出张量被分为两部分:(1)x,对应于原始输入张量,和(2)x_cls,对应于CLS标记张量。- Parameters:
- forward(x: Tensor) tuple[torch.Tensor, torch.Tensor][source]
CLS-token增强的Transformer卷积。
- Parameters:
x (Tensor) – 输入张量,形状为 [batch_size, num_cols, channels]
- Returns:
(输出形状为[batch_size, num_cols, channels]的张量,对应于输入列,输出形状为[batch_size, channels]的张量,对应于添加的CLS标记列。)
- Return type: