torch_frame.nn.models.FTTransformer

class FTTransformer(channels: int, out_channels: int, num_layers: int, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dict: dict[torch_frame.stype, StypeEncoder] | None = None)[来源]

基础类: Module

FT-Transformer模型在“重新审视用于表格数据的深度学习模型”论文中提出。

注意

有关使用FTTransformer的示例,请参见examples/revisiting.py

Parameters:
forward(tf: TensorFrame) Tensor[source]

TensorFrame对象转换为输出预测。

Parameters:

tf (TensorFrame) – 输入的 TensorFrame 对象。

Returns:

输出形状为 [batch_size, out_channels]。

Return type:

torch.Tensor