torch_frame.nn.encoder.StypeWiseFeatureEncoder

class StypeWiseFeatureEncoder(out_channels: int, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]], col_names_dict: dict[torch_frame._stype.stype, list[str]], stype_encoder_dict: dict[torch_frame._stype.stype, torch_frame.nn.encoder.stype_encoder.StypeEncoder])[source]

基础类:FeatureEncoder

特征编码器,将每个stype张量转换为嵌入并执行最终的连接。

Parameters:
forward(tf: TensorFrame) tuple[torch.Tensor, list[str]][source]

TensorFrame对象编码为元组(x, col_names)

Parameters:

tf (torch_frame.TensorFrame) – 输入的 TensorFrame 对象。

Returns:

一个输出列的元组

torch.Tensor 的形状为 [batch_size, num_cols, hidden_channels] 和一个 x 的列名列表。长度需要为 num_cols

Return type:

(torch.Tensor, 列表[str])