torch_frame.nn.encoder.LinearEncoder
- class LinearEncoder(out_channels: int | None = None, stats_list: list[dict[StatType, Any]] | None = None, stype: stype | None = None, post_module: torch.nn.Module | None = None, na_strategy: NAStrategy | None = None)[source]
基础类:
StypeEncoder
一个基于线性函数的数值特征编码器。它对每个原始数值特征应用线性层
torch.nn.Linear(1, out_channels)
并连接输出嵌入。请注意,该实现以批处理方式对所有数值特征执行此操作。