torch_frame.nn.encoder.EmbeddingEncoder

class EmbeddingEncoder(out_channels: int | None = None, stats_list: list[dict[StatType, Any]] | None = None, stype: stype | None = None, post_module: torch.nn.Module | None = None, na_strategy: NAStrategy | None = None)[来源]

基础类:StypeEncoder

一种基于嵌入查找的分类特征编码器。它为每个分类特征应用torch.nn.Embedding并连接输出嵌入。

reset_parameters()[来源]

初始化post_module的参数。

encode_forward(feat: Tensor, col_names: list[str] | None = None) Tensor[来源]

主要的前向函数。将输入 feat 从 TensorFrame(形状 [batch_size, num_cols])映射到输出 x,其形状为 [batch_size, num_cols, out_channels]