torch_frame.nn.models.TabTransformer
- class TabTransformer(channels: int, out_channels: int, num_layers: int, num_heads: int, encoder_pad_size: int, attn_dropout: float, ffn_dropout: float, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]], col_names_dict: dict[torch_frame._stype.stype, list[str]])[source]
基础类:
ModuleTab-Transformer模型在 “TabTransformer: Tabular Data Modeling Using Contextual Embeddings” 论文中提出。
该模型在分类特征嵌入中添加了一个列位置嵌入,并仅在分类特征上执行多层列交互建模。对于数值特征,模型简单地对输入特征应用层归一化。该模型利用MLP(多层感知器)进行解码。
注意
有关使用TabTransformer的示例,请参见examples/tabtransformer.py。
- Parameters:
channels (int) – 输入通道的维度。
out_channels (int) – 输出通道的维度。
num_layers (int) – 卷积层的数量。
num_heads (int) – 自注意力层中的头数。
encoder_pad_size (int) – 位置编码填充到分类嵌入的大小。
col_stats (Dict[str,Dict[
torch_frame.data.stats.StatType,Any]]) – 一个将列名映射到统计信息的字典。 可作为dataset.col_stats使用。col_names_dict (Dict[
torch_frame.stype, List[str]]) – 一个 将stype映射到列名列表的字典。列名根据tensor_frame.feat_dict中出现的顺序进行排序。可通过tensor_frame.col_names_dict获取。
- forward(tf: TensorFrame) Tensor[来源]
将
TensorFrame对象转换为输出预测。- Parameters:
tf (TensorFrame) – 输入的
TensorFrame对象。- Returns:
输出形状为 [batch_size, out_channels]。
- Return type: