torch_frame.nn.encoder.LinearPeriodicEncoder

class LinearPeriodicEncoder(out_channels: int | None = None, stats_list: list[dict[StatType, Any]] | None = None, stype: stype | None = None, post_module: torch.nn.Module | None = None, na_strategy: NAStrategy | None = None, n_bins: int | None = 16)[source]

基础类:StypeEncoder

一个周期性编码器,利用正弦函数将输入张量转换为三维张量。编码使用可训练的参数定义,并包括正弦和余弦函数的应用。原始编码在“On Embeddings for Numerical Features in Tabular Deep Learning”中描述。

Parameters:

n_bins (int) – 用于周期性编码的箱数。

reset_parameters() None[source]

初始化post_module的参数。

encode_forward(feat: Tensor, col_names: list[str] | None = None) Tensor[source]

主要的前向函数。将输入 feat 从 TensorFrame(形状 [batch_size, num_cols])映射到输出 x,其形状为 [batch_size, num_cols, out_channels]