torch_frame.nn.encoder.StypeEncoder

class StypeEncoder(out_channels: int | None = None, stats_list: list[dict[StatType, Any]] | None = None, stype: stype | None = None, post_module: torch.nn.Module | None = None, na_strategy: NAStrategy | None = None)[来源]

基础类:Module, ABC

stype编码器的基类。该模块将特定stype的张量,即TensorFrame.feat_dict[stype.xxx]转换为三维列式张量,该张量将输入到TableConv中。

Parameters:
  • out_channels (int) – 输出通道的维度

  • stats_list (list[dict[torch_frame.data.stats.StatType, Any]]) – 同一stype中每列的统计信息列表。

  • stype (stype) – 编码器输入的stype。

  • post_module (Module, 可选) – 应用于输出的后处理模块,例如激活函数和归一化。必须保持输出的形状。如果 None,则不会对输出应用任何模块。(默认值:None

  • na_strategy (NAStrategy, 可选) – 用于填补NaN值的策略。如果na_strategy为None,则对于NaN类别输出不可学习的全零嵌入。(默认值:None

abstract reset_parameters()[source]

初始化post_module的参数。

forward(feat: TensorData, col_names: list[str] | None = None) Tensor[source]

定义每次调用时执行的计算。

应该由所有子类覆盖。

注意

尽管前向传递的配方需要在此函数内定义,但之后应该调用Module实例而不是这个,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。

abstract encode_forward(feat: TensorData, col_names: list[str] | None = None) Tensor[来源]

主要的前向函数。将输入 feat 从 TensorFrame(形状 [batch_size, num_cols])映射到输出 x,其形状为 [batch_size, num_cols, out_channels]

post_forward(out: Tensor) Tensor[source]

应用于形状为 [batch_size, num_cols, channels] 的 out 的后处理函数。它还会返回相同形状的 out

na_forward(feat: Union[Tensor, MultiNestedTensor, MultiEmbeddingTensor, dict[str, torch_frame.data.multi_nested_tensor.MultiNestedTensor]]) Union[Tensor, MultiNestedTensor, MultiEmbeddingTensor, dict[str, torch_frame.data.multi_nested_tensor.MultiNestedTensor]][来源]

根据给定的na_strategy替换输入TensorData中的NaN值。

Parameters:

feat (TensorData) – 输入 TensorData.

Returns:

输出 TensorData 并将 NaN 替换为给定的

na_strategy

Return type:

TensorData