torch_frame.nn.models.Trompt

class Trompt(channels: int, out_channels: int, num_prompts: int, num_layers: int, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dicts: list[dict[torch_frame.stype, StypeEncoder]] | None = None)[source]

基础类: Module

“Trompt: Towards a Better Deep Neural Network for Tabular Data”论文中介绍的Trompt模型。

注意

有关使用Trompt的示例,请参见examples/trompt.py

Parameters:
  • channels (int) – 隐藏通道维度

  • out_channels (int) – 输出通道的维度

  • num_prompts (int) – 提示列的数量。

  • num_layers (int, optional) – TromptConv 层的数量。 (默认: 6)

  • col_stats (Dict[str,Dict[torch_frame.data.stats.StatType,Any]]) – 一个将列名映射到统计信息的字典。 可作为 dataset.col_stats 使用。

  • col_names_dict (Dict[torch_frame.stype, List[str]]) – 一个 将stype映射到列名列表的字典。列名根据 tensor_frame.feat_dict中出现的顺序进行排序。可通过 tensor_frame.col_names_dict访问。

  • stype_encoder_dicts – (list[dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder]], 可选): 一个包含num_layers个字典的列表,每个字典将 stypes映射到它们的stype编码器。 (默认: None, 将调用EmbeddingEncoder() 用于分类特征和LinearEncoder()用于 数值特征)

forward(tf: TensorFrame) Tensor[来源]

TensorFrame对象转换为每层的输出预测序列。在训练期间用于计算逐层损失。

Parameters:

tf (torch_frame.TensorFrame) – 输入的 TensorFrame 对象。

Returns:

输出预测在层之间堆叠。其

形状为 [batch_size, num_layers, out_channels]

Return type:

torch.Tensor