torch_frame.nn.encoder.StypeWiseFeatureEncoder
- class StypeWiseFeatureEncoder(out_channels: int, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]], col_names_dict: dict[torch_frame._stype.stype, list[str]], stype_encoder_dict: dict[torch_frame._stype.stype, torch_frame.nn.encoder.stype_encoder.StypeEncoder])[source]
基础类:
FeatureEncoder
特征编码器,将每个stype张量转换为嵌入并执行最终的连接。
- Parameters:
out_channels (int) – 输出维度。
col_stats – (dict[str, dict[
torch_frame.data.stats.StatType
, Any]]): 一个将列名映射到统计信息的字典。可通过dataset.col_stats
获取。col_names_dict (dict[
torch_frame.stype
, list[str]]) – 一个 将stype映射到列名列表的字典。列名根据在tensor_frame.feat_dict
中出现的顺序进行排序。 可作为tensor_frame.col_names_dict
使用。stype_encoder_dict – (dict[
torch_frame.stype
,torch_frame.nn.encoder.StypeEncoder
]): 一个将torch_frame.stype
映射到torch_frame.nn.encoder.StypeEncoder
类的字典。仅支持 父stypes
作为键。
- forward(tf: TensorFrame) tuple[torch.Tensor, list[str]] [source]
将
TensorFrame
对象编码为元组(x, col_names)
。- Parameters:
tf (
torch_frame.TensorFrame
) – 输入的TensorFrame
对象。- Returns:
- 一个输出列的元组
torch.Tensor
的形状为[batch_size, num_cols, hidden_channels]
和一个x
的列名列表。长度需要为num_cols
。
- Return type:
(torch.Tensor, 列表[str])