torch_frame.nn.encoder.FeatureEncoder

class FeatureEncoder(*args, **kwargs)[来源]

基础类:Module, ABC

特征编码器的基类,将输入的 torch_frame.TensorFrame 转换为 (x, col_names),其中 x 是形状为 [batch_size, num_cols, channels] 的列式 PyTorch 张量,而 col_names 是列的名称。该类包含可学习的参数和缺失值处理。

abstract forward(tf: TensorFrame) tuple[torch.Tensor, list[str]][来源]

TensorFrame对象编码为元组(x, col_names)

Parameters:

tf (torch_frame.TensorFrame) – 输入的 TensorFrame 对象。

Returns:

一个输出列的元组

torch.Tensor 的形状为 [batch_size, num_cols, hidden_channels] 和一个 x 的列名列表。长度需要为 num_cols

Return type:

(torch.Tensor, 列表[str])

reset_parameters() None[来源]

重置模块的所有可学习参数。