torch_frame.nn.encoder.FeatureEncoder
- class FeatureEncoder(*args, **kwargs)[来源]
-
特征编码器的基类,将输入的
torch_frame.TensorFrame转换为(x, col_names),其中x是形状为[batch_size, num_cols, channels]的列式 PyTorch 张量,而col_names是列的名称。该类包含可学习的参数和缺失值处理。- abstract forward(tf: TensorFrame) tuple[torch.Tensor, list[str]][来源]
将
TensorFrame对象编码为元组(x, col_names)。- Parameters:
tf (
torch_frame.TensorFrame) – 输入的TensorFrame对象。- Returns:
- 一个输出列的元组
torch.Tensor的形状为[batch_size, num_cols, hidden_channels]和一个x的列名列表。长度需要为num_cols。
- Return type:
(torch.Tensor, 列表[str])