torch_frame.nn.encoder.LinearModelEncoder
- class LinearModelEncoder(out_channels: int | None = None, stats_list: list[dict[StatType, Any]] | None = None, stype: stype | None = None, post_module: torch.nn.Module | None = None, na_strategy: NAStrategy | None = None, col_to_model_cfg: dict[str, ModelConfig] | None = None)[来源]
基础类:
StypeEncoder
基于线性函数的编码器,具有指定的模型输出嵌入特征。它在每个嵌入特征上应用一个线性层
torch.nn.Linear(in_channels, out_channels)
(in_channels
是嵌入的维度)并连接输出嵌入。model
也将与线性层一起训练。 请注意,实现是以批处理方式对所有列进行此操作的。- Parameters:
col_to_model_cfg (dict) – 一个将列名映射到
ModelConfig
的字典,它指定了一个模型 将形状为TensorData
的单列对象[batch_size, 1, *]
映射为形状为[batch_size, 1, model_out_channels]
的行嵌入。