torch_frame.nn.conv.ExcelFormerConv

class ExcelFormerConv(channels: int, num_cols: int, num_heads: int, diam_dropout: float = 0.0, aium_dropout: float = 0.0, residual_dropout: float = 0.0)[来源]

基础类:TableConv

ExcelFormer层在 “ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data” 论文中介绍。

Parameters:
  • channels (int) – 输入/输出通道的维度。

  • num_cols (int) – 列数。

  • num_heads (int) – 注意力头的数量。

  • diam_dropout (float) – diam_dropout。(默认值:0)

  • aium_dropout (float) – aium_dropout. (默认值: 0)

  • residual_dropout (float) – 残差丢弃率。(默认值:0)

reset_parameters() None[source]

重置模块的所有可学习参数。

forward(x: Tensor) Tensor[source]

将列方向的3维张量处理为另一个列方向的3维张量。

Parameters:
  • x (torch.Tensor) – 输入列方向的张量,形状为 [batch_size, num_cols, hidden_channels]

  • args (Any) – 额外参数。

  • kwargs (Any) – 额外的关键字参数。