torch_frame.nn.models.TabNet

class TabNet(out_channels: int, num_layers: int, split_feat_channels: int, split_attn_channels: int, gamma: float, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dict: dict[torch_frame.stype, StypeEncoder] | None = None, num_shared_glu_layers: int = 2, num_dependent_glu_layers: int = 2, cat_emb_channels: int = 2)[来源]

基础类: Module

TabNet模型在 “TabNet: Attentive Interpretable Tabular Learning” 论文中提出。

注意

有关使用TabNet的示例,请参见examples/tabnet.py

Parameters:
  • out_channels (int) – 输出维度

  • num_layers (int) – TabNet 层数。

  • split_feat_channels (int) – 特征通道的维度。

  • split_attn_channels (int) – 注意力通道的维度。

  • gamma (float) – 用于更新注意力掩码先验的gamma值。

  • col_stats (Dict[str,Dict[torch_frame.data.stats.StatType,Any]]) – 一个将列名映射到统计信息的字典。 可通过 dataset.col_stats 获取。

  • col_names_dict (Dict[torch_frame.stype, List[str]]) – 一个 将stype映射到列名列表的字典。列名根据 tensor_frame.feat_dict中出现的顺序进行排序。可通过 tensor_frame.col_names_dict访问。

  • stype_encoder_dict – (dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder], 可选): 一个将stypes映射到其stype编码器的字典。 (默认: None, 将调用 EmbeddingEncoder() 用于分类特征和 StackEncoder() 用于 数值特征)

  • num_shared_glu_layers (int) – 在num_layers FeatureTransformer`s 之间共享的 GLU 层数。 (默认: :obj:`2)

  • num_dependent_glu_layers (int, optional) – 在每个num_layers FeatureTransformer中使用的GLU层数。 (默认值: :obj:`2)

  • cat_emb_channels (int, optional) – 分类嵌入的维度。

forward(tf: TensorFrame, return_reg: bool = False) Tensor | tuple[Tensor, Tensor][source]

TensorFrame对象转换为输出嵌入。

Parameters:
  • tf (TensorFrame) – 输入的 TensorFrame 对象。

  • return_reg (bool) – 是否返回熵正则化。

Returns:

输出

大小为 [batch_size, out_channels] 的嵌入。 如果 return_regTrue,则同时返回熵正则化。

Return type:

联合[torch.Tensor, (torch.Tensor, torch.Tensor)]