torch_frame.nn.encoder.MultiCategoricalEmbeddingEncoder
- class MultiCategoricalEmbeddingEncoder(out_channels: int | None = None, stats_list: list[dict[StatType, Any]] | None = None, stype: stype | None = None, post_module: torch.nn.Module | None = None, na_strategy: NAStrategy | None = None, mode: str = 'mean')[来源]
基础类:
StypeEncoder一个基于嵌入查找的多分类特征编码器。它为每个分类特征应用
torch.nn.EmbeddingBag并连接输出嵌入。- Parameters:
mode (str) – “sum”, “mean” 或 “max”。 指定减少包的方式。(默认:
mean)