torch_frame.transforms

Transforms

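Transforms convert data between different stypes or within the same stype. A transform takes both a TensorFrame and column statistics as input.

Let's look at an example in which we apply CatToNumTransform to convert categorical features into numerical features.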

from torch_frame.datasets import Yandex
from torch_frame.transforms import CatToNumTransform
from torch_frame import stype

dataset = Yandex(root='/tmp/adult', name='adult')
dataset.materialize()
transform = CatToNumTransform()
train_dataset = dataset.get_split('train')

train_dataset.tensor_frame.col_names_dict[stype.categorical]
>>> ['C_feature_0', 'C_feature_1', 'C_feature_2', 'C_feature_3', 'C_feature_4', 'C_feature_5', 'C_feature_6', 'C_feature_7']

test_dataset = dataset.get_split('test')
transform.fit(train_dataset.tensor_frame, dataset.col_stats)

transformed_col_stats = transform.transformed_stats

transformed_col_stats.keys()
>>> dict_keys(['C_feature_0_0', 'C_feature_1_0', 'C_feature_2_0', 'C_feature_3_0', 'C_feature_4_0', 'C_feature_5_0', 'C_feature_6_0', 'C_feature_7_0'])

transformed_col_stats['C_feature_0_0']
>>> {<StatType.MEAN: 'MEAN'>: 0.6984029484029484, <StatType.STD: 'STD'>: 0.45895127199411595, <StatType.QUANTILES: 'QUANTILES'>: [0.0, 0.0, 1.0, 1.0, 1.0]}

transform(test_dataset.tensor_frame)
>>> TensorFrame(
      num_cols=14,
      num_rows=16281,
      numerical (14): ['N_feature_0', 'N_feature_1', 'N_feature_2', 'N_feature_3', 'N_feature_4', 'N_feature_5', 'C_feature_0_0', 'C_feature_1_0', 'C_feature_2_0', 'C_feature_3_0', 'C_feature_4_0', 'C_feature_5_0', 'C_feature_6_0', 'C_feature_7_0'],
      has_target=True,
      device=cpu,
    )

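As you can see, after the transform the column names of the categorical features have changed (each gains a `_0` suffix), and the categorical features have been converted into numerical features.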

BaseTransform

An abstract base class for writing transforms.

FittableBaseTransform

An abstract base class for writing fittable transforms.
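The fit-then-apply pattern behind a fittable transform can be sketched in plain Python. This is only an illustration of the idea: `FittableTransformSketch` and `StandardizeSketch` are hypothetical names, they work on plain lists rather than a TensorFrame, and they are not the classes or signatures that torch_frame provides.

```python
from abc import ABC, abstractmethod
from statistics import mean, pstdev


class FittableTransformSketch(ABC):
    """Hypothetical base class: learn state from training data, then apply it."""

    def __init__(self):
        self._is_fitted = False

    def fit(self, train_values):
        # Learn whatever state the transform needs from the training split only.
        self._fit(train_values)
        self._is_fitted = True
        return self

    def __call__(self, values):
        # Refuse to run before fit(), so test data never defines the state.
        if not self._is_fitted:
            raise RuntimeError("Call fit() on training data before applying.")
        return self._forward(values)

    @abstractmethod
    def _fit(self, train_values):
        ...

    @abstractmethod
    def _forward(self, values):
        ...


class StandardizeSketch(FittableTransformSketch):
    """Hypothetical subclass: standardize values with training mean and std."""

    def _fit(self, train_values):
        self.mean_ = mean(train_values)
        # Fall back to 1.0 if the training column is constant.
        self.std_ = pstdev(train_values) or 1.0

    def _forward(self, values):
        return [(v - self.mean_) / self.std_ for v in values]


scaler = StandardizeSketch().fit([1.0, 2.0, 3.0])
scaled = scaler([2.0, 4.0])
```

Keeping the learned state separate from the apply step is what lets the same fitted transform be reused on the validation and test splits without leaking their statistics.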

CatToNumTransform

Transforms categorical features in a TensorFrame using target statistics.
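To make "target statistics" concrete, here is a minimal sketch of smoothed target-mean encoding on plain Python lists. The helper names `fit_target_stats` and `apply_target_stats`, the `prior_weight` parameter, and the smoothing formula are assumptions chosen for illustration. The exact formula used by CatToNumTransform may differ.

```python
from collections import defaultdict


def fit_target_stats(categories, targets, prior_weight=1.0):
    """Learn a smoothed mean target per category from training data."""
    global_mean = sum(targets) / len(targets)
    sums = defaultdict(float)
    counts = defaultdict(int)
    for cat, y in zip(categories, targets):
        sums[cat] += y
        counts[cat] += 1
    # Blend each category's target sum with the global mean (assumed smoothing),
    # so that rare categories are pulled toward the overall average.
    encoding = {
        cat: (sums[cat] + prior_weight * global_mean) / (counts[cat] + prior_weight)
        for cat in counts
    }
    return encoding, global_mean


def apply_target_stats(categories, encoding, global_mean):
    """Replace each category by its learned statistic; unseen ones get the global mean."""
    return [encoding.get(cat, global_mean) for cat in categories]


train_cats = ["a", "a", "b", "b", "b"]
train_y = [1.0, 0.0, 1.0, 1.0, 1.0]
encoding, global_mean = fit_target_stats(train_cats, train_y)
encoded_test = apply_target_stats(["a", "b", "c"], encoding, global_mean)
```

As in the example at the top of this page, the statistics are learned from the training split and then applied to the test split, so the test targets never influence the encoding.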

MutualInformationSort

A transform that sorts the input numerical features by mutual information.
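The idea of ordering features by mutual information with the target can be sketched as follows. This is a simplified illustration on discrete values: the helpers `mutual_information` and `sort_columns_by_mi` are hypothetical, the descending order is an assumption, and truly continuous features would need binning or a dedicated estimator. It does not reproduce how MutualInformationSort computes its scores.

```python
import math
from collections import Counter


def mutual_information(xs, ys):
    """Mutual information (in nats) between two discrete sequences."""
    n = len(xs)
    px = Counter(xs)
    py = Counter(ys)
    pxy = Counter(zip(xs, ys))
    mi = 0.0
    for (x, y), c in pxy.items():
        # Adds p(x, y) * log(p(x, y) / (p(x) * p(y))), written with raw counts.
        mi += (c / n) * math.log(c * n / (px[x] * py[y]))
    return mi


def sort_columns_by_mi(columns, target):
    """Return column names ordered from most to least informative (assumed order)."""
    scores = {name: mutual_information(vals, target) for name, vals in columns.items()}
    order = sorted(scores, key=scores.get, reverse=True)
    return order, scores


target = [0, 0, 1, 1]
columns = {
    "noise": [0, 1, 0, 1],   # independent of the target
    "signal": [0, 0, 1, 1],  # identical to the target
}
order, scores = sort_columns_by_mi(columns, target)
```

Here `signal` determines the target exactly and scores log 2, while `noise` is independent of the target and scores zero, so `signal` is placed first.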