torch_frame.transforms.FittableBaseTransform

class FittableBaseTransform[source]

基础类: BaseTransform

用于编写可拟合变换的抽象基类。 可拟合变换在变换之前必须对训练数据进行拟合。

property is_fitted: bool

转换是否已经拟合。

fit(tf: TensorFrame, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]])[source]

使用训练数据拟合转换。

Parameters:
  • tf (TensorFrame) – 输入的 TensorFrame 对象,表示训练数据。

  • col_stats (Dict[str, Dict[StatType, Any]], optional) – 输入的 TensorFrame 的列统计信息。

forward(tf: TensorFrame) TensorFrame[source]

将TensorFrame对象处理成另一个TensorFrame对象。

Parameters:

tf (TensorFrame) – 输入的 TensorFrame

Returns:

转换后的输入 TensorFrame

Return type:

TensorFrame