torch_frame.nn.encoder.StypeEncoder
- class StypeEncoder(out_channels: int | None = None, stats_list: list[dict[StatType, Any]] | None = None, stype: stype | None = None, post_module: torch.nn.Module | None = None, na_strategy: NAStrategy | None = None)[来源]
基础类:
Module
,ABC
stype编码器的基类。该模块将特定stype的张量,即TensorFrame.feat_dict[stype.xxx]转换为三维列式张量,该张量将输入到
TableConv
中。- Parameters:
out_channels (int) – 输出通道的维度
stats_list (list[dict[torch_frame.data.stats.StatType, Any]]) – 同一stype中每列的统计信息列表。
stype (stype) – 编码器输入的stype。
post_module (Module, 可选) – 应用于输出的后处理模块,例如激活函数和归一化。必须保持输出的形状。如果
None
,则不会对输出应用任何模块。(默认值:None
)na_strategy (NAStrategy, 可选) – 用于填补NaN值的策略。如果na_strategy为None,则对于
NaN
类别输出不可学习的全零嵌入。(默认值:None
)
- forward(feat: TensorData, col_names: list[str] | None = None) Tensor [source]
定义每次调用时执行的计算。
应该由所有子类覆盖。
注意
尽管前向传递的配方需要在此函数内定义,但之后应该调用
Module
实例而不是这个,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。
- abstract encode_forward(feat: TensorData, col_names: list[str] | None = None) Tensor [来源]
主要的前向函数。将输入
feat
从 TensorFrame(形状 [batch_size, num_cols])映射到输出x
,其形状为[batch_size, num_cols, out_channels]
。
- post_forward(out: Tensor) Tensor [source]
应用于形状为 [batch_size, num_cols, channels] 的
out
的后处理函数。它还会返回相同形状的out
。
- na_forward(feat: Union[Tensor, MultiNestedTensor, MultiEmbeddingTensor, dict[str, torch_frame.data.multi_nested_tensor.MultiNestedTensor]]) Union[Tensor, MultiNestedTensor, MultiEmbeddingTensor, dict[str, torch_frame.data.multi_nested_tensor.MultiNestedTensor]] [来源]
根据给定的
na_strategy
替换输入TensorData
中的NaN值。- Parameters:
feat (TensorData) – 输入
TensorData
.- Returns:
- 输出
TensorData
并将 NaN 替换为给定的 na_strategy
。
- 输出
- Return type:
TensorData