torch_frame.nn.models.TabTransformer

class TabTransformer(channels: int, out_channels: int, num_layers: int, num_heads: int, encoder_pad_size: int, attn_dropout: float, ffn_dropout: float, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]], col_names_dict: dict[torch_frame._stype.stype, list[str]])[source]

基础类: Module

Tab-Transformer模型在 “TabTransformer: Tabular Data Modeling Using Contextual Embeddings” 论文中提出。

该模型在分类特征嵌入中添加了一个列位置嵌入,并仅在分类特征上执行多层列交互建模。对于数值特征,模型简单地对输入特征应用层归一化。该模型利用MLP(多层感知器)进行解码。

注意

有关使用TabTransformer的示例,请参见examples/tabtransformer.py

Parameters:
  • channels (int) – 输入通道的维度。

  • out_channels (int) – 输出通道的维度。

  • num_layers (int) – 卷积层的数量。

  • num_heads (int) – 自注意力层中的头数。

  • encoder_pad_size (int) – 位置编码填充到分类嵌入的大小。

  • col_stats (Dict[str,Dict[torch_frame.data.stats.StatType,Any]]) – 一个将列名映射到统计信息的字典。 可作为 dataset.col_stats 使用。

  • col_names_dict (Dict[torch_frame.stype, List[str]]) – 一个 将stype映射到列名列表的字典。列名根据 tensor_frame.feat_dict中出现的顺序进行排序。可通过 tensor_frame.col_names_dict获取。

forward(tf: TensorFrame) Tensor[来源]

TensorFrame对象转换为输出预测。

Parameters:

tf (TensorFrame) – 输入的 TensorFrame 对象。

Returns:

输出形状为 [batch_size, out_channels]。

Return type:

torch.Tensor