torch_frame.nn.encoder.FeatureEncoder
- class FeatureEncoder(*args, **kwargs)[来源]
-
特征编码器的基类,将输入的
torch_frame.TensorFrame
转换为(x, col_names)
,其中x
是形状为[batch_size, num_cols, channels]
的列式 PyTorch 张量,而col_names
是列的名称。该类包含可学习的参数和缺失值处理。- abstract forward(tf: TensorFrame) tuple[torch.Tensor, list[str]] [来源]
将
TensorFrame
对象编码为元组(x, col_names)
。- Parameters:
tf (
torch_frame.TensorFrame
) – 输入的TensorFrame
对象。- Returns:
- 一个输出列的元组
torch.Tensor
的形状为[batch_size, num_cols, hidden_channels]
和一个x
的列名列表。长度需要为num_cols
。
- Return type:
(torch.Tensor, 列表[str])