torch_frame.nn.conv.TabTransformerConv
- class TabTransformerConv(channels: int, num_heads: int, attn_dropout: float = 0.0, ffn_dropout: float = 0.0)[source]
基础类:
TableConv
TabTransformer层在“TabTransformer: 使用上下文嵌入的表格数据建模”论文中引入。
- Parameters:
- forward(x: Tensor) Tensor [source]
将列方向的3维张量处理为另一个列方向的3维张量。
- Parameters:
x (torch.Tensor) – 输入列方向的张量,形状为
[batch_size, num_cols, hidden_channels]
。args (Any) – 额外参数。
kwargs (Any) – 额外的关键字参数。