torch_frame.nn.encoder.TimestampEncoder

class TimestampEncoder(out_channels: int | None = None, stats_list: list[dict[StatType, Any]] | None = None, stype: stype | None = None, post_module: torch.nn.Module | None = None, na_strategy: NAStrategy | None = NAStrategy.MEDIAN_TIMESTAMP, out_size: int = 8)[source]

基础类:StypeEncoder

TimestampEncoder 用于时间戳类型。年份使用 torch_frame.nn.encoding.PositionalEncoding 进行编码。其他 特征,包括月份、日期、星期几、小时、分钟和秒, 使用 torch_frame.nn.encoding.CyclicEncoding 进行编码。 它以批处理方式对每列应用线性层。TimestampEncoder 不支持 NaN 时间戳,因为 torch_frame.nn.encoding.PositionalEncoding 不支持 负张量值。因此 torch_frame.NAStrategy.MEDIAN_TIMESTAMP 被应用为默认的 NAStrategy

Parameters:

out_size (int) – 位置编码和循环编码的输出维度。

reset_parameters() None[source]

初始化post_module的参数。

encode_forward(feat: Tensor, col_names: list[str] | None = None) Tensor[source]

主要的前向函数。将输入 feat 从 TensorFrame(形状 [batch_size, num_cols])映射到输出 x,其形状为 [batch_size, num_cols, out_channels]