torch_frame.nn.encoder.ExcelFormerEncoder

class ExcelFormerEncoder(out_channels: int | None = None, stats_list: list[dict[StatType, Any]] | None = None, stype: stype | None = None, post_module: torch.nn.Module | None = None, na_strategy: NAStrategy | None = None)[source]

基础类:StypeEncoder

一个基于注意力的编码器,将输入的数字特征转换为三维张量。

在输入到嵌入层之前,数值特征被归一化,分类特征通过使用Sklearn Python包实现的CatBoost编码器转换为数值特征。然后根据互信息对特征进行排序。 原始编码在“ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data”论文中有所描述。

Parameters:
  • out_channels (int) – 输出通道的维度。

  • stats_list (list[dict[StatType, Any]]) – 同一stype中每列的统计信息列表。

encode_forward(feat: Tensor, col_names: list[str] | None = None) Tensor[source]

主要的前向函数。将输入 feat 从 TensorFrame(形状 [batch_size, num_cols])映射到输出 x,其形状为 [batch_size, num_cols, out_channels]

reset_parameters() None[source]

初始化post_module的参数。