torch_frame.nn.models.FTTransformer
- class FTTransformer(channels: int, out_channels: int, num_layers: int, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dict: dict[torch_frame.stype, StypeEncoder] | None = None)[来源]
基础类:
ModuleFT-Transformer模型在“重新审视用于表格数据的深度学习模型”论文中提出。
注意
有关使用FTTransformer的示例,请参见examples/revisiting.py。
- Parameters:
channels (int) – 隐藏通道维度
out_channels (int) – 输出通道的维度
num_layers (int) – 层数。 (默认:
3)col_stats (dict[str,dict[
torch_frame.data.stats.StatType,Any]]) – 一个将列名映射到统计信息的字典。 可作为dataset.col_stats使用。col_names_dict (dict[
torch_frame.stype, list[str]]) – 一个 将stype映射到列名列表的字典。列名根据在tensor_frame.feat_dict中出现的顺序进行排序。可通过tensor_frame.col_names_dict获取。stype_encoder_dict – (dict[
torch_frame.stype,torch_frame.nn.encoder.StypeEncoder], 可选): 一个将stypes映射到其stype编码器的字典。 (默认:None, 将调用torch_frame.nn.encoder.EmbeddingEncoder()用于分类 特征和torch_frame.nn.encoder.LinearEncoder()用于数值特征)
- forward(tf: TensorFrame) Tensor[source]
将
TensorFrame对象转换为输出预测。- Parameters:
tf (TensorFrame) – 输入的
TensorFrame对象。- Returns:
输出形状为 [batch_size, out_channels]。
- Return type: