Machine Learning with vaex.ml#

If you want to try out this notebook with a live Python kernel, please use mybinder.

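The vaex.ml package brings some machine learning algorithms to vaex. If you installed the individual subpackages (vaex-core, vaex-hdf5, …) instead of the vaex metapackage, you may need to install it by running pip install vaex-ml, or conda install -c conda-forge vaex-ml.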

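The vaex.ml API closely follows that of scikit-learn. It also offers better performance and can efficiently process data that is larger than the available memory. This page gives an overview and a brief introduction to the functionality that vaex.ml provides.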

[1]:
import vaex
vaex.multithreading.thread_count_default = 8
import vaex.ml

import numpy as np
import matplotlib.pyplot as plt

We will use the well-known Iris flower and Titanic passenger list datasets, two classic datasets for machine learning demonstrations.

[2]:
df = vaex.datasets.iris()
df
[2]:
#    sepal_length  sepal_width  petal_length  petal_width  class_
0    5.9           3.0          4.2           1.5          1
1    6.1           3.0          4.6           1.4          1
2    6.6           2.9          4.6           1.3          1
3    6.7           3.3          5.7           2.1          2
4    5.5           4.2          1.4           0.2          0
...  ...           ...          ...           ...          ...
145  5.2           3.4          1.4           0.2          0
146  5.1           3.8          1.6           0.2          0
147  5.8           2.6          4.0           1.2          1
148  5.7           3.8          1.7           0.3          0
149  6.2           2.9          4.3           1.3          1
[3]:
df.viz.scatter(df.petal_length, df.petal_width, c_expr=df.class_);
[Figure: scatter plot of petal_length versus petal_width, colored by class_]

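Preprocessing#

Numerical feature scaling#

vaex.ml packs the common numerical scalers:

  • vaex.ml.StandardScaler - scales features by removing their mean and dividing by their standard deviation;

  • vaex.ml.MinMaxScaler - scales features to a given range;

  • vaex.ml.RobustScaler - scales features by removing their median and dividing by a given percentile range;

  • vaex.ml.MaxAbsScaler - scales features by their maximum absolute value.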
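The NumPy sketch below shows what the four scalers compute. It follows the scikit-learn definitions, which vaex.ml mirrors, and is not vaex.ml's implementation. Three details are our assumptions, chosen to match scikit-learn's defaults:

  • the population standard deviation (ddof=0);

  • a (0, 1) target range;

  • the 25th to 75th percentile range.

```python
import numpy as np

# Hypothetical toy feature with one outlier.
x = np.array([1.0, 2.0, 3.0, 4.0, 10.0])

# StandardScaler: remove the mean, divide by the (population) standard deviation.
standard = (x - x.mean()) / x.std()

# MinMaxScaler: map [min, max] linearly onto the target range (lo, hi).
lo, hi = 0.0, 1.0
minmax = (x - x.min()) / (x.max() - x.min()) * (hi - lo) + lo

# RobustScaler: remove the median, divide by the 25th-75th percentile range.
q25, q75 = np.percentile(x, [25, 75])
robust = (x - np.median(x)) / (q75 - q25)

# MaxAbsScaler: divide by the largest absolute value, so results lie in [-1, 1].
maxabs = x / np.abs(x).max()

print(standard, minmax, robust, maxabs, sep="\n")
```

The outlier at 10 squeezes the min-max and max-abs results for the other values toward 0. The robust scaler instead keeps the bulk of the data at roughly unit spread.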

The usage is very similar to scikit-learn: each transformer implements the .fit and .transform methods.

[4]:
features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
scaler = vaex.ml.StandardScaler(features=features, prefix='scaled_')
scaler.fit(df)
df_trans = scaler.transform(df)
df_trans
[4]:
#    sepal_length  sepal_width  petal_length  petal_width  class_  scaled_petal_length     scaled_petal_width      scaled_sepal_length     scaled_sepal_width
0    5.9           3.0          4.2           1.5          1       0.25096730693923325     0.39617188299171285     0.06866179325140277     -0.12495760117130607
1    6.1           3.0          4.6           1.4          1       0.4784301228962429      0.26469891297233916     0.3109975341387059      -0.12495760117130607
2    6.6           2.9          4.6           1.3          1       0.4784301228962429      0.13322594295296575     0.9168368863569659      -0.3563605663033572
3    6.7           3.3          5.7           2.1          2       1.1039528667780207      1.1850097031079545      1.0380047568006185      0.5692512942248463
4    5.5           4.2          1.4           0.2          0       -1.341272404759837      -1.3129767272601438     -0.4160096885232057     2.6518779804133055
...  ...           ...          ...           ...          ...     ...                     ...                     ...                     ...
145  5.2           3.4          1.4           0.2          0       -1.341272404759837      -1.3129767272601438     -0.7795132998541615     0.8006542593568975
146  5.1           3.8          1.6           0.2          0       -1.2275409967813318     -1.3129767272601438     -0.9006811702978141     1.726266119885101
147  5.8           2.6          4.0           1.2          1       0.13723589896072813     0.0017529729335920385   -0.052506077192249874   -1.0505694616995096
148  5.7           3.8          1.7           0.3          0       -1.1706752927920796     -1.18150375724077       -0.17367394763590144    1.726266119885101
149  6.2           2.9          4.3           1.3          1       0.30783301092848553     0.13322594295296575     0.4321654045823586      -0.3563605663033572

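The .transform method of any vaex.ml transformer returns a shallow copy of the DataFrame. The copy contains the resulting features of the transformation in addition to the original columns. "Shallow copy" means that the new DataFrame only references the original one, so no extra memory is used. In addition, the resulting features (here, the scaled numerical features) are virtual columns: they take up no memory and are computed on the fly when needed. This approach is ideal for working with very large datasets.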
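The toy sketch below makes the shallow-copy and virtual-column idea concrete in plain Python and NumPy. `ToyFrame` and its methods are hypothetical and only illustrate the concept. vaex implements this very differently and at much larger scale.

```python
import numpy as np

class ToyFrame:
    """Hypothetical toy frame: stored columns plus lazily evaluated virtual columns."""

    def __init__(self, columns, virtual=None):
        self.columns = columns               # name -> NumPy array
        self.virtual = dict(virtual or {})   # name -> function(frame) -> array

    def shallow_copy(self):
        # New dicts, but the arrays themselves are shared, not copied.
        return ToyFrame(dict(self.columns), self.virtual)

    def add_virtual_column(self, name, func):
        # Only the recipe is stored; no array is allocated here.
        self.virtual[name] = func

    def __getitem__(self, name):
        if name in self.columns:
            return self.columns[name]
        # Virtual columns are recomputed on every access.
        return self.virtual[name](self)


df = ToyFrame({'x': np.array([1.0, 2.0, 3.0])})
mean, std = df['x'].mean(), df['x'].std()          # the "fit" step

df_trans = df.shallow_copy()                        # the "transform" step
df_trans.add_virtual_column('scaled_x', lambda f: (f['x'] - mean) / std)

print(df_trans['x'] is df['x'])   # the underlying data is shared
print(df_trans['scaled_x'])       # computed on the fly
```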

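Encoding of categorical features#

vaex.ml contains several categorical encoders:

  • vaex.ml.LabelEncoder - encodes features with integer values, one per category, starting from 0;

  • vaex.ml.OneHotEncoder - encodes features according to the one-hot scheme;

  • vaex.ml.MultiHotEncoder - encodes features according to the multi-hot scheme (binary vector);

  • vaex.ml.FrequencyEncoder - encodes features by the frequency of their respective categories;

  • vaex.ml.BayesianTargetEncoder - encodes categories with the mean of their target value, smoothed toward the overall target mean;

  • vaex.ml.WeightOfEvidenceEncoder - encodes categories by their weight of evidence value.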
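The label, one-hot, multi-hot and frequency encodings are simple enough to sketch in a few lines of plain Python. The sketch uses a hypothetical toy column. The category order (sorted, with the missing value last) and the bit order of the multi-hot code are our own choices and may not match what vaex.ml produces.

```python
from collections import Counter

# Hypothetical toy column with one missing value.
embarked = ['S', 'C', 'S', 'Q', 'S', None]

# Label encoding: one integer per distinct category, starting from 0.
categories = sorted(c for c in set(embarked) if c is not None) + [None]
label = {c: i for i, c in enumerate(categories)}
label_encoded = [label[c] for c in embarked]

# One-hot encoding: one 0/1 indicator column per category.
one_hot = {c: [int(v == c) for v in embarked] for c in categories}

# Multi-hot encoding: the integer label written as binary digits
# (least significant bit first), so k categories need only about log2(k) columns.
n_bits = max(1, (len(categories) - 1).bit_length())
multi_hot = [[(code >> bit) & 1 for bit in range(n_bits)] for code in label_encoded]

# Frequency encoding: each category replaced by its relative frequency.
counts = Counter(embarked)
freq_encoded = [counts[c] / len(embarked) for c in embarked]

print(label_encoded, multi_hot, freq_encoded, sep="\n")
```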

Here is a quick example using the Titanic dataset.

[5]:
df = vaex.datasets.titanic()
df.head(5)
[5]:
#  pclass  survived  name                                             sex     age     sibsp  parch  ticket  fare     cabin    embarked  boat  body  home_dest
0  1       True      Allen, Miss. Elisabeth Walton                    female  29      0      0      24160   211.338  B5       S         2     nan   St Louis, MO
1  1       True      Allison, Master. Hudson Trevor                   male    0.9167  1      2      113781  151.55   C22 C26  S         11    nan   Montreal, PQ / Chesterville, ON
2  1       False     Allison, Miss. Helen Loraine                     female  2       1      2      113781  151.55   C22 C26  S         --    nan   Montreal, PQ / Chesterville, ON
3  1       False     Allison, Mr. Hudson Joshua Creighton             male    30      1      2      113781  151.55   C22 C26  S         --    135   Montreal, PQ / Chesterville, ON
4  1       False     Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female  25      1      2      113781  151.55   C22 C26  S         --    nan   Montreal, PQ / Chesterville, ON
[6]:
label_encoder = vaex.ml.LabelEncoder(features=['embarked'])
one_hot_encoder = vaex.ml.OneHotEncoder(features=['embarked'])
multi_hot_encoder = vaex.ml.MultiHotEncoder(features=['embarked'])
freq_encoder = vaex.ml.FrequencyEncoder(features=['embarked'])
bayes_encoder = vaex.ml.BayesianTargetEncoder(features=['embarked'], target='survived')
woe_encoder = vaex.ml.WeightOfEvidenceEncoder(features=['embarked'], target='survived')

df = label_encoder.fit_transform(df)
df = one_hot_encoder.fit_transform(df)
df = multi_hot_encoder.fit_transform(df)
df = freq_encoder.fit_transform(df)
df = bayes_encoder.fit_transform(df)
df = woe_encoder.fit_transform(df)

df.head(5)
[6]:
#  pclass  survived  name                                             sex     age     sibsp  parch  ticket  fare     cabin    embarked  boat  body  home_dest                          label_encoded_embarked  embarked_missing  embarked_C  embarked_Q  embarked_S  embarked_0  embarked_1  embarked_2  frequency_encoded_embarked  mean_encoded_embarked  woe_encoded_embarked
0  1       True      Allen, Miss. Elisabeth Walton                    female  29      0      0      24160   211.338  B5       S         2     nan   St Louis, MO                       1                       0                 0           0           1           1           0           0           0.698243                    0.337472               -0.696431
1  1       True      Allison, Master. Hudson Trevor                   male    0.9167  1      2      113781  151.55   C22 C26  S         11    nan   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           1           0           0           0.698243                    0.337472               -0.696431
2  1       False     Allison, Miss. Helen Loraine                     female  2       1      2      113781  151.55   C22 C26  S         --    nan   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           1           0           0           0.698243                    0.337472               -0.696431
3  1       False     Allison, Mr. Hudson Joshua Creighton             male    30      1      2      113781  151.55   C22 C26  S         --    135   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           1           0           0           0.698243                    0.337472               -0.696431
4  1       False     Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female  25      1      2      113781  151.55   C22 C26  S         --    nan   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           1           0           0           0.698243                    0.337472               -0.696431

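Notice that all the transformed features are included in the resulting DataFrame and are appropriately named. This is very useful for building various diagnostic plots and for engineering more complex features. Because the resulting (encoded) features take up no memory, you can try out or combine a variety of preprocessing steps without spending any extra memory.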
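The frequency_encoded_embarked, mean_encoded_embarked and woe_encoded_embarked values for port S in the table above can be reproduced by hand. The sketch below rests on these assumptions:

  • the data is the standard 1,309-passenger Titanic data, in which 914 passengers embarked at S, 304 of those survived, and 500 passengers survived overall;

  • the Bayesian target encoder blends the category mean with the global mean using a weight of 100 (our reading of vaex.ml's default weight);

  • the weight of evidence is the log-odds of survival within the category.

Under these assumptions, the results agree with the table to six decimals.

```python
import math

# Assumed counts for the Titanic data (not computed from the DataFrame here).
n_total, n_survived_total = 1309, 500
n_s, n_survived_s = 914, 304

# Frequency encoding: share of all rows that fall into the category.
freq_s = n_s / n_total

# Bayesian target encoding: category mean shrunk toward the global mean.
# `weight` acts like a number of pseudo-observations at the global mean (assumed 100).
weight = 100
global_mean = n_survived_total / n_total
mean_s = n_survived_s / n_s
bayes_s = (n_s * mean_s + weight * global_mean) / (n_s + weight)

# Weight of evidence, taken here as the log-odds of the target within the category.
woe_s = math.log(mean_s / (1 - mean_s))

print(round(freq_s, 6), round(bayes_s, 6), round(woe_s, 6))
# 0.698243 0.337472 -0.696431
```

The weight controls the shrinkage: small categories are pulled strongly toward the global mean, while large ones like port S move only slightly.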

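Feature Engineering#

KBinsDiscretizer#

With the KBinsDiscretizer you can convert a continuous feature into a discrete one by binning the data into specified intervals. You can specify the number of bins and the strategy used to determine their size:

  • “uniform” - all bins have the same width;

  • “quantile” - all bins contain (approximately) the same number of samples;

  • “kmeans” - the values in each bin belong to the same one-dimensional cluster, as determined by the KMeans algorithm.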
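The difference between the “uniform” and “quantile” strategies is easiest to see on skewed data. The NumPy sketch below computes bin edges both ways for a hypothetical feature with one large outlier, then assigns each value to a bin. It illustrates the idea rather than vaex.ml's code, and it leaves out the “kmeans” strategy.

```python
import numpy as np

# Hypothetical skewed feature: nine small values and one outlier.
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 100], dtype=float)
n_bins = 5

# "uniform": equally wide bins between the minimum and the maximum.
uniform_edges = np.linspace(x.min(), x.max(), n_bins + 1)

# "quantile": edges at evenly spaced quantiles, so bins hold similar counts.
quantile_edges = np.quantile(x, np.linspace(0, 1, n_bins + 1))

def assign_bins(values, edges):
    # Use only the inner edges, so the maximum value falls into the last bin.
    return np.digitize(values, edges[1:-1])

uniform_bins = assign_bins(x, uniform_edges)
quantile_bins = assign_bins(x, quantile_edges)

print(np.bincount(uniform_bins, minlength=n_bins))   # [9 0 0 0 1]
print(np.bincount(quantile_bins, minlength=n_bins))  # [2 2 2 2 2]
```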

[7]:
kbdisc = vaex.ml.KBinsDiscretizer(features=['age'], n_bins=5, strategy='quantile')
df = kbdisc.fit_transform(df)
df.head(5)
/home/jovan/vaex/packages/vaex-core/vaex/ml/transformations.py:1089: UserWarning: Bins whose width are too small (i.e., <= 1e-8) in   age are removed.Consider decreasing the number of bins.
  warnings.warn(f'Bins whose width are too small (i.e., <= 1e-8) in   {feat} are removed.'
[7]:
#  pclass  survived  name                                             sex     age     sibsp  parch  ticket  fare     cabin    embarked  boat  body  home_dest                          label_encoded_embarked  embarked_missing  embarked_C  embarked_Q  embarked_S  frequency_encoded_embarked  mean_encoded_embarked  woe_encoded_embarked  binned_age
0  1       True      Allen, Miss. Elisabeth Walton                    female  29      0      0      24160   211.338  B5       S         2     nan   St Louis, MO                       1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0
1  1       True      Allison, Master. Hudson Trevor                   male    0.9167  1      2      113781  151.55   C22 C26  S         11    nan   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0
2  1       False     Allison, Miss. Helen Loraine                     female  2       1      2      113781  151.55   C22 C26  S         --    nan   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0
3  1       False     Allison, Mr. Hudson Joshua Creighton             male    30      1      2      113781  151.55   C22 C26  S         --    135   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0
4  1       False     Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female  25      1      2      113781  151.55   C22 C26  S         --    nan   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0

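GroupBy Transformer#

The GroupByTransformer is a handy feature in vaex-ml. It lets you perform a groupby aggregation on the training data and then use those aggregations as features in both the training and the test sets.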

[8]:
gbt = vaex.ml.GroupByTransformer(by='pclass', agg={'age': ['mean', 'std'],
                                                   'fare': ['mean', 'std'],
                                                  })
df = gbt.fit_transform(df)
df.head(5)
[8]:
#  pclass  survived  name                                             sex     age     sibsp  parch  ticket  fare     cabin    embarked  boat  body  home_dest                          label_encoded_embarked  embarked_missing  embarked_C  embarked_Q  embarked_S  frequency_encoded_embarked  mean_encoded_embarked  woe_encoded_embarked  binned_age  age_mean  age_std  fare_mean  fare_std
0  1       True      Allen, Miss. Elisabeth Walton                    female  29      0      0      24160   211.338  B5       S         2     nan   St Louis, MO                       1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0           39.1599   14.5224  87.509     80.3226
1  1       True      Allison, Master. Hudson Trevor                   male    0.9167  1      2      113781  151.55   C22 C26  S         11    nan   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0           39.1599   14.5224  87.509     80.3226
2  1       False     Allison, Miss. Helen Loraine                     female  2       1      2      113781  151.55   C22 C26  S         --    nan   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0           39.1599   14.5224  87.509     80.3226
3  1       False     Allison, Mr. Hudson Joshua Creighton             male    30      1      2      113781  151.55   C22 C26  S         --    135   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0           39.1599   14.5224  87.509     80.3226
4  1       False     Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female  25      1      2      113781  151.55   C22 C26  S         --    nan   Montreal, PQ / Chesterville, ON    1                       0                 0           0           1           0.698243                    0.337472               -0.696431             0           39.1599   14.5224  87.509     80.3226
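Conceptually, the transformer learns one set of statistics per group during fit and looks them up during transform. The NumPy sketch below shows that idea on toy data, with hypothetical helper names. Two choices in it are our assumptions rather than documented vaex.ml behavior: it uses the population standard deviation, and it maps groups unseen during fit to NaN.

```python
import numpy as np

def fit_group_stats(by, values):
    """'fit': per-group mean and population std of `values`, keyed by group label."""
    by = np.asarray(by)
    values = np.asarray(values, dtype=float)
    return {key.item(): (values[by == key].mean(), values[by == key].std())
            for key in np.unique(by)}

def transform_group_stats(by, stats):
    """'transform': look up each row's group statistics; unseen groups get NaN."""
    missing = (np.nan, np.nan)
    means = np.array([stats.get(key, missing)[0] for key in by])
    stds = np.array([stats.get(key, missing)[1] for key in by])
    return means, stds


# Hypothetical toy training data: passenger class and age.
train_pclass = [1, 1, 2, 2, 3]
train_age = [40.0, 30.0, 20.0, 30.0, 25.0]
stats = fit_group_stats(train_pclass, train_age)

# The statistics learned on the training set are reused on new data.
test_pclass = [3, 1, 2, 4]
age_mean, age_std = transform_group_stats(test_pclass, stats)
print(age_mean)
print(age_std)
```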

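Cycle Transformer#

The CycleTransformer provides a strategy for transforming cyclical features, such as angles or time. It treats each feature as the angular coordinate of a polar coordinate system and converts it to Cartesian coordinates. As a result, values at the two ends of the cycle (for example, the last and the first day of the week) end up close to each other. This can help certain machine learning models achieve better performance.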

[9]:
df = vaex.from_arrays(days=[0, 1, 2, 3, 4, 5, 6])
cyctrans = vaex.ml.CycleTransformer(n=7, features=['days'])
cyctrans.fit_transform(df)
[9]:
#  days  days_x     days_y
0  0     1          0
1  1     0.62349    0.781831
2  2     -0.222521  0.974928
3  3     -0.900969  0.433884
4  4     -0.900969  -0.433884
5  5     -0.222521  -0.974928
6  6     0.62349    -0.781831
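The values in the table above are what you get by placing each day on the unit circle: with period n, a value v becomes (cos(2πv/n), sin(2πv/n)). The short NumPy check below reproduces them. The function name is our own, not part of vaex.ml.

```python
import numpy as np

def to_unit_circle(values, n):
    """Map a feature with period n to (cos, sin) coordinates on the unit circle."""
    angle = 2 * np.pi * np.asarray(values, dtype=float) / n
    return np.cos(angle), np.sin(angle)


days = np.arange(7)
days_x, days_y = to_unit_circle(days, n=7)

# Day 6 and day 0 end up as neighbors on the circle, unlike on the raw 0..6 scale.
gap_raw = abs(days[6] - days[0])
gap_circle = np.hypot(days_x[6] - days_x[0], days_y[6] - days_y[0])
print(gap_raw, round(gap_circle, 4))
```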

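Dimensionality reduction#

Principal Component Analysis#

The PCA implemented in vaex.ml can scale to a very large number of samples, even if the data we want to transform does not fit into RAM. To demonstrate this, let us perform a PCA transformation on the Iris dataset. For this example, we have replicated the 150-row dataset 6.7 million times, so that it contains just over 1 billion samples.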

[10]:
df = vaex.datasets.iris_1e9()
n_samples = len(df)
print(f'Number of samples in DataFrame: {n_samples:,}')
Number of samples in DataFrame: 1,005,000,000
[11]:
features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
pca = vaex.ml.PCA(features=features, n_components=4)
pca.fit(df, progress='widget')

The PCA transformer implemented in vaex.ml can be fit in well under a minute, even though the data comprises 4 columns and over 1 billion rows.
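One reason PCA can scale to data that does not fit in memory is that fitting needs only the feature means and the d × d covariance matrix. Both can be accumulated in a single pass over chunks of rows. The NumPy sketch below illustrates this. We assume, but have not confirmed, that vaex.ml works along similar lines, and the helper name is hypothetical.

```python
import numpy as np

def pca_from_chunks(chunks, n_components):
    """Fit PCA from an iterable of row chunks, keeping only O(d^2) state in memory."""
    n, total, total_outer = 0, None, None
    for X in chunks:
        X = np.asarray(X, dtype=float)
        if total is None:
            total = np.zeros(X.shape[1])
            total_outer = np.zeros((X.shape[1], X.shape[1]))
        n += X.shape[0]
        total += X.sum(axis=0)          # running sum of rows
        total_outer += X.T @ X          # running sum of outer products
    mean = total / n
    cov = total_outer / n - np.outer(mean, mean)   # population covariance
    eigvals, eigvecs = np.linalg.eigh(cov)          # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1][:n_components]
    return mean, eigvals[order], eigvecs[:, order]


rng = np.random.default_rng(0)
X = rng.normal(size=(10_000, 4)) * [3.0, 2.0, 1.0, 0.5]
chunks = np.array_split(X, 20)                      # stand-in for out-of-core reads

mean, variances, components = pca_from_chunks(chunks, n_components=4)
scores = (X - mean) @ components                    # the transform step
```

The one-pass formula E[x x^T] - mu mu^T can lose precision when the feature means are large compared with their spread. A production implementation would guard against that, for example by centering each chunk with a preliminary mean estimate.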

[12]:
df_trans = pca.transform(df)
df_trans
[12]:
#              sepal_length  sepal_width  petal_length  petal_width  class_  PCA_0                 PCA_1                  PCA_2                  PCA_3
0              5.9           3.0          4.2           1.5          1       -0.5110980605065719   0.10228410590320294    0.13232789125239366    -0.05010053260756789
1              6.1           3.0          4.6           1.4          1       -0.8901604456484571   0.03381244269907491    -0.009768028904991795  0.1534482059864868
2              6.6           2.9          4.6           1.3          1       -1.0432977809309918   -0.2289569106597803    -0.41481456509035997   0.03752354509774891
3              6.7           3.3          5.7           2.1          2       -2.275853649246034    -0.3333865237191275    0.28467815436304544    0.062230281630705805
4              5.5           4.2          1.4           0.2          0       2.5971594768136956    -1.1000219282272325    0.16358191524058419    0.09895807321522321
...            ...           ...          ...           ...          ...     ...                   ...                    ...                    ...
1,004,999,995  5.2           3.4          1.4           0.2          0       2.6398212682948925    -0.3192900674870881    -0.1392533720548284    -0.06514104909063131
1,004,999,996  5.1           3.8          1.6           0.2          0       2.537573370908207     -0.5103675457748862    0.17191840236558648    0.19216594960009262
1,004,999,997  5.8           2.6          4.0           1.2          1       -0.2288790498772652   0.4022576190683287     -0.22736270650701024   -0.01862045442675292
1,004,999,998  5.7           3.8          1.7           0.3          0       2.199077961161723     -0.8792440894091085    -0.11452146077196179   -0.025326942106218664
1,004,999,999  6.2           2.9          4.3           1.3          1       -0.6416902782168139   -0.019071177408365406  -0.20417287674016232   0.02050967222367117

Recall that the transformed DataFrame, which includes the PCA components, takes up no extra memory.

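Incremental PCA#

The PCA implementation in vaex is very fast, especially for "tall" DataFrames, i.e. DataFrames that have many rows but relatively few columns. For DataFrames with hundreds of columns, an incremental PCA method is more efficient. vaex.ml provides a convenient method that essentially wraps sklearn.decomposition.IncrementalPCA, whose fitting is more efficient for "wide" DataFrames.

The usage is practically identical to the regular PCA method. Consider the following example: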

[13]:
n_samples = 100_000
n_columns = 50
data_dict = {f'feat_{i}': np.random.normal(0, i+1, size=n_samples) for i in range(n_columns)}
df = vaex.from_dict(data_dict)


features = df.get_column_names()
pca = vaex.ml.PCAIncremental(n_components=10, features=features, batch_size=42_000)
pca.fit(df, progress='widget')
pca.transform(df)
[13]:
#       feat_0                feat_1                ...  feat_49               PCA_0                 PCA_1                 PCA_2                 ...
0       0.21916619701436382   -1.1435438188965208   ...  -64.45083314406774    -67.59977551168708    9.37969253153906      -96.6057651764448     ...
1       -0.4207695878149816   2.3850692704428043    ...  -39.819369217021595   -38.74298336407881    12.412960580526423    -16.791761088244527   ...
2       -0.5024797409195991   0.9897062935454243    ...  -44.45310265105787    -42.34987503786076    -74.13417710288375    -94.54423466637282    ...
3       0.12617606561304665   -0.9172822637869823   ...  -58.49057458242536    -63.82740201878286    -78.04284003367316    6.898497938656784     ...
4       -1.5391949931048126   -0.8424386233860871   ...  52.702201767487       56.4647664098084      27.388702583308334    47.716980722531005    ...
...     ...                   ...                   ...  ...                   ...                   ...                   ...                   ...
99,995  -1.160081518789358    -1.5967802399231468   ...  36.34299568143191     34.5263251515797      -74.93722963856585    -51.83676476605647    ...
99,996  0.13322166185560574   2.0608209742055763    ...  -29.32574935340052    -21.826064525895305   25.41085028514592     70.39416642353444     ...
99,997  1.011157114782744     -0.8004098626963071   ...  -10.604821752483073   -11.558372427683931   -3.6825332773046875   -23.548620629546026   ...
99,998  0.9852518578365336    0.8203281912686264    ...  69.46209659371226     66.36652900849339     -25.104537169591715   79.18237523289388     ...
99,999  0.28017247799931055   0.8792488188373339    ...  12.050979659904716    3.057710144296827     -86.50100893855216    23.845662599505307    ...

Note that scikit-learn is needed only to fit the PCAIncremental transformer; the transform method does not require scikit-learn to be installed.
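Because the fit is delegated to scikit-learn, it can help to see the batched fit in scikit-learn itself. The sketch below calls sklearn.decomposition.IncrementalPCA.partial_fit directly on synthetic data shaped like the example above. It does not show how vaex.ml splits a DataFrame into batches of batch_size rows.

```python
import numpy as np
from sklearn.decomposition import IncrementalPCA

rng = np.random.default_rng(42)
n_samples, n_columns = 10_000, 50
# Column i has standard deviation i + 1, mirroring the example above.
X = rng.normal(0, np.arange(1, n_columns + 1), size=(n_samples, n_columns))

ipca = IncrementalPCA(n_components=10)
# Feed the data in batches, as one would when it does not fit in memory.
# Each batch must contain at least n_components rows.
for batch in np.array_split(X, 5):
    ipca.partial_fit(batch)

X_reduced = ipca.transform(X)
print(X_reduced.shape)
```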

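Random projections#

Random projection is another popular way of reducing dimensionality, especially when the data has a very large number of dimensions. vaex.ml conveniently wraps both sklearn.random_projection.GaussianRandomProjection and sklearn.random_projection.SparseRandomProjection in a single vaex.ml transformer.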
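The idea behind the Gaussian variant fits in a few lines. Multiply the data by a random matrix whose entries are drawn from N(0, 1/k), where k is the number of components. By the Johnson-Lindenstrauss lemma, pairwise distances are then approximately preserved. The NumPy sketch below uses a hypothetical function name. scikit-learn's GaussianRandomProjection draws the same kind of matrix, but this is not its code.

```python
import numpy as np

def gaussian_random_projection(X, n_components, seed=0):
    """Project the rows of X onto n_components random Gaussian directions."""
    rng = np.random.default_rng(seed)
    # Entries ~ N(0, 1/n_components), so squared norms are preserved in expectation.
    R = rng.normal(0.0, 1.0 / np.sqrt(n_components), size=(n_components, X.shape[1]))
    return X @ R.T


data_rng = np.random.default_rng(1)
X = data_rng.normal(size=(100, 1_000))              # 100 points in 1,000 dimensions
Y = gaussian_random_projection(X, n_components=300)

# Ratio of one pairwise distance after vs. before the projection; close to 1.
ratio = np.linalg.norm(Y[0] - Y[1]) / np.linalg.norm(X[0] - X[1])
print(Y.shape, round(ratio, 3))
```

With fewer components the projection is cheaper but the distances are distorted more. The spread of the distance ratio shrinks roughly like 1/sqrt(k).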

[14]:
rand_proj = vaex.ml.RandomProjections(features=features, n_components=10)
rand_proj.fit(df)
rand_proj.transform(df)
[14]:
#       feat_0                feat_1                ...  feat_49               rand_proj_0           rand_proj_1           rand_proj_2           ...
0       0.21916619701436382   -1.1435438188965208   ...  -64.45083314406774    -50.62485790513975    -8.969974902164104    -75.59787959901278    ...
1       -0.4207695878149816   2.3850692704428043    ...  -39.819369217021595   -24.167592671736728   -83.6194525409906     -31.474566122257382   ...
2       -0.5024797409195991   0.9897062935454243    ...  -44.45310265105787    -30.370561351797924   -69.21024877654797    -131.21336032017504   ...
3       0.12617606561304665   -0.9172822637869823   ...  -58.49057458242536    125.12748803342656    -25.206573635553035   61.805492059522535    ...
4       -1.5391949931048126   -0.8424386233860871   ...  52.702201767487       9.100443729937155     -98.2487363365348     -86.04861549617408    ...
...     ...                   ...                   ...  ...                   ...                   ...                   ...                   ...
99,995  -1.160081518789358    -1.5967802399231468   ...  36.34299568143191     79.74173570372625     -120.99425995411295   -158.6863110682003    ...
99,996  0.13322166185560574   2.0608209742055763    ...  -29.32574935340052    12.801266126889404    17.612236115044166    -31.111396519869256   ...
99,997  1.011157114782744     -0.8004098626963071   ...  -10.604821752483073   -2.4863267734391865   -10.434958342024952   -37.55392055999496    ...
99,998  0.9852518578365336    0.8203281912686264    ...  69.46209659371226     -70.00012029923253    198.0368255008663     129.3714720510582     ...
99,999  0.28017247799931055   0.8792488188373339    ...  12.050979659904716    45.50866581430373     33.59123204918983     66.48747993035953     ...

Clustering

K-Means

vaex.ml implements a fast and scalable K-Means clustering algorithm. Its usage is similar to that of scikit-learn.
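For intuition, K-Means implementations are typically built on Lloyd's iteration. Each sample is assigned to its nearest centroid, and each centroid is then moved to the mean of its samples. This repeats until the inertia (the sum of squared distances from samples to their assigned centroid) stops decreasing. The NumPy sketch below is an in-memory illustration under that assumption, not vaex.ml's out-of-core implementation. The function name `kmeans_lloyd` and the random-sample initialisation are choices made here for illustration.

```python
import numpy as np

def kmeans_lloyd(X, n_clusters, max_iter=100, random_state=42):
    """Plain Lloyd's algorithm: alternate assignment and centroid-update steps."""
    rng = np.random.default_rng(random_state)
    # Initialise the centroids with distinct, randomly chosen samples (a copy of those rows)
    centroids = X[rng.choice(len(X), size=n_clusters, replace=False)].astype(float)
    prev_inertia = np.inf
    for _ in range(max_iter):
        # Squared Euclidean distance of every sample to every centroid, shape (n, k)
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        inertia = d2[np.arange(len(X)), labels].sum()
        # Move each centroid to the mean of its members (keep it in place if it has none)
        for k in range(n_clusters):
            members = X[labels == k]
            if len(members):
                centroids[k] = members.mean(axis=0)
        if inertia == prev_inertia:  # the objective no longer changes: converged
            break
        prev_inertia = inertia
    return centroids, labels, inertia

# Two well-separated toy blobs of three points each
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [10.0, 10.0], [10.1, 10.0], [10.0, 10.1]])
centroids, labels, inertia = kmeans_lloyd(X, n_clusters=2, random_state=0)
print(labels, inertia)
```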

[15]:
import vaex.ml.cluster

df = vaex.datasets.iris()

features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
kmeans = vaex.ml.cluster.KMeans(features=features, n_clusters=3, max_iter=100, verbose=True, random_state=42)
kmeans.fit(df)

df_trans = kmeans.transform(df)
df_trans
Iteration    0, inertia  519.0500000000001
Iteration    1, inertia  156.70447116074328
Iteration    2, inertia  88.70688235734133
Iteration    3, inertia  80.23054939305554
Iteration    4, inertia  79.28654263977778
Iteration    5, inertia  78.94084142614601
Iteration    6, inertia  78.94084142614601
[15]:
# sepal_length sepal_width petal_length petal_width class_ prediction_kmeans
0 5.9 3.0 4.2 1.5 1 0
1 6.1 3.0 4.6 1.4 1 0
2 6.6 2.9 4.6 1.3 1 0
3 6.7 3.3 5.7 2.1 2 1
4 5.5 4.2 1.4 0.2 0 2
... ... ... ... ... ... ...
145 5.2 3.4 1.4 0.2 0 2
146 5.1 3.8 1.6 0.2 0 2
147 5.8 2.6 4.0 1.2 1 0
148 5.7 3.8 1.7 0.3 0 2
149 6.2 2.9 4.3 1.3 1 0

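K-Means is an unsupervised algorithm, meaning that the predicted cluster labels in the transformed dataset do not necessarily correspond to the class labels. We can map the predicted cluster identifiers to the class labels to make it easier to construct diagnostic plots.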
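The mapper used in the next cell was chosen by hand. One way to derive such a mapping automatically is a majority vote: map each cluster to the class that occurs most often among its members. This assumes the cluster ids and true classes are available as NumPy arrays, which is cheap for the 150-row Iris set. `majority_mapping` is a hypothetical helper written for this illustration, not part of vaex.ml.

```python
import numpy as np

def majority_mapping(cluster_ids, class_labels):
    """Map each cluster id to the most frequent true class among its members."""
    cluster_ids = np.asarray(cluster_ids)
    class_labels = np.asarray(class_labels)
    mapping = {}
    for c in np.unique(cluster_ids):
        # Count how often each (non-negative integer) class occurs inside cluster c
        counts = np.bincount(class_labels[cluster_ids == c])
        mapping[int(c)] = int(counts.argmax())
    return mapping

# Made-up cluster ids and true classes for eight samples
clusters = [0, 0, 0, 1, 1, 2, 2, 2]
classes = [1, 1, 2, 2, 2, 0, 0, 1]
print(majority_mapping(clusters, classes))  # {0: 1, 1: 2, 2: 0}
```

A majority vote can send two clusters to the same class when the clusters are impure. If a one-to-one mapping is required, the Hungarian algorithm (for example `scipy.optimize.linear_sum_assignment` on the cluster-by-class count matrix) is a common alternative.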

[16]:
df_trans['predicted_kmean_map'] = df_trans.prediction_kmeans.map(mapper={0: 1, 1: 2, 2: 0})
df_trans
[16]:
# sepal_length sepal_width petal_length petal_width class_ prediction_kmeans predicted_kmean_map
0 5.9 3.0 4.2 1.5 1 0 1
1 6.1 3.0 4.6 1.4 1 0 1
2 6.6 2.9 4.6 1.3 1 0 1
3 6.7 3.3 5.7 2.1 2 1 2
4 5.5 4.2 1.4 0.2 0 2 0
... ... ... ... ... ... ... ...
145 5.2 3.4 1.4 0.2 0 2 0
146 5.1 3.8 1.6 0.2 0 2 0
147 5.8 2.6 4.0 1.2 1 0 1
148 5.7 3.8 1.7 0.3 0 2 0
149 6.2 2.9 4.3 1.3 1 0 1

Now we can construct simple scatter plots and see that, for the Iris dataset, K-Means does a pretty good job of splitting the data into 3 classes.

[17]:
fig = plt.figure(figsize=(12, 5))

plt.subplot(121)
df_trans.viz.scatter(df_trans.petal_length, df_trans.petal_width, c_expr=df_trans.class_)
plt.title('Original classes')

plt.subplot(122)
df_trans.viz.scatter(df_trans.petal_length, df_trans.petal_width, c_expr=df_trans.predicted_kmean_map)
plt.title('Predicted classes')

plt.tight_layout()
plt.show()
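[Figure: petal_width against petal_length, coloured by class. Left panel "Original classes", right panel "Predicted classes" (mapped K-Means clusters).]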

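As with any algorithm implemented in vaex.ml, K-Means can be used on billions of samples. Fitting the oversampled Iris dataset, which has over 1 billion samples, takes well under 2 minutes. In the run below the wall time was about 20 seconds, with 2 min 39 s of total CPU time spread over several threads.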

[18]:
df = vaex.datasets.iris_1e9()
n_samples = len(df)
print(f'Number of samples in DataFrame: {n_samples:,}')
Number of samples in DataFrame: 1,005,000,000
[19]:
%%time

features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
kmeans = vaex.ml.cluster.KMeans(features=features, n_clusters=3, max_iter=100, verbose=True, random_state=31)
kmeans.fit(df)
Iteration    0, inertia  838974000.0037192
Iteration    1, inertia  535903134.000306
Iteration    2, inertia  530190921.4848897
Iteration    3, inertia  528931941.03372437
Iteration    4, inertia  528931941.0337243
CPU times: user 2min 37s, sys: 1.26 s, total: 2min 39s
Wall time: 19.9 s

Supervised learning

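While vaex.ml does not yet implement any supervised machine learning models, it does provide wrappers for several popular libraries, such as scikit-learn, XGBoost, LightGBM and CatBoost.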

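The main benefit of these wrappers is that they turn the models into vaex.ml transformers. The models therefore become part of the DataFrame state and can be serialized, and their predictions can be returned as virtual columns. This is especially useful for creating diagnostic plots, for evaluating performance metrics at no memory cost, and for building ensembles.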

Scikit-Learn example

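The vaex.ml.sklearn module provides convenient wrappers around scikit-learn estimators. In fact, these wrappers can be used with any library that follows the API convention established by scikit-learn, i.e. that implements the .fit and .predict methods.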

Here is an example:

[20]:
from vaex.ml.sklearn import Predictor
from sklearn.ensemble import GradientBoostingClassifier

df = vaex.datasets.iris()

features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
target = 'class_'

model = GradientBoostingClassifier(random_state=42)
vaex_model = Predictor(features=features, target=target, model=model, prediction_name='prediction')

vaex_model.fit(df=df)

df = vaex_model.transform(df)
df
[20]:
# sepal_length sepal_width petal_length petal_width class_ prediction
0 5.9 3.0 4.2 1.5 1 1
1 6.1 3.0 4.6 1.4 1 1
2 6.6 2.9 4.6 1.3 1 1
3 6.7 3.3 5.7 2.1 2 2
4 5.5 4.2 1.4 0.2 0 0
... ... ... ... ... ... ...
145 5.2 3.4 1.4 0.2 0 0
146 5.1 3.8 1.6 0.2 0 0
147 5.8 2.6 4.0 1.2 1 1
148 5.7 3.8 1.7 0.3 0 0
149 6.2 2.9 4.3 1.3 1 1
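scikit-learn's estimator convention is a `fit(X, y)` method that returns `self`, plus a `predict(X)` method. Any object that follows it should, in principle, be wrappable in the same way as `GradientBoostingClassifier` above. The toy nearest-centroid classifier below is written only to illustrate that convention, and the class name is hypothetical. Whether `Predictor` needs anything beyond `fit` and `predict` is an assumption not verified here. scikit-learn itself ships a real `sklearn.neighbors.NearestCentroid`.

```python
import numpy as np

class NearestCentroidClassifier:
    """Toy estimator following the scikit-learn fit/predict convention (illustrative only)."""

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        # One centroid per class: the mean feature vector of that class
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self  # scikit-learn estimators return self from fit

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        # Squared distance of every sample to every class centroid, shape (n, n_classes)
        d2 = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)
        return self.classes_[d2.argmin(axis=1)]

X = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [6.0, 5.0]])
y = np.array([0, 0, 1, 1])
clf = NearestCentroidClassifier().fit(X, y)
print(clf.predict(np.array([[0.2, 0.3], [5.5, 5.2]])))  # [0 1]
```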

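One can still train a predictive model on a dataset that is too large to fit into memory by leveraging the online learners provided by scikit-learn. vaex.ml.sklearn.IncrementalPredictor conveniently wraps these learners and provides control over how the data is passed to them from a vaex DataFrame.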

Let us train a model on the oversampled Iris dataset, which comprises over 1 billion samples.
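A quick back-of-the-envelope check helps here. The DataFrame has 1,005,000,000 rows, as printed in the K-Means section, and the cell below uses batch_size=500_000, so one pass over the data is split into 2,010 batches. Assuming each batch is handed once per pass to the scikit-learn model's `partial_fit`, the model receives 2,010 incremental updates. This is also why `classes` is passed through partial_fit_kwargs: scikit-learn's `partial_fit` needs the full set of classes on its first call, and a single batch may not contain all of them. `batch_bounds` is a hypothetical helper for this illustration.

```python
import math

def batch_bounds(n_rows, batch_size):
    """Yield (start, stop) row ranges that cover n_rows in chunks of at most batch_size."""
    for start in range(0, n_rows, batch_size):
        yield start, min(start + batch_size, n_rows)

n_rows = 1_005_000_000   # rows in vaex.datasets.iris_1e9(), per the output above
batch_size = 500_000     # value passed to IncrementalPredictor below
bounds = list(batch_bounds(n_rows, batch_size))
print(len(bounds), math.ceil(n_rows / batch_size))  # 2010 2010
```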

[21]:
from vaex.ml.sklearn import IncrementalPredictor
from sklearn.linear_model import SGDClassifier

df = vaex.datasets.iris_1e9()

features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
target = 'class_'

model = SGDClassifier(learning_rate='constant', eta0=0.0001, random_state=42)
vaex_model = IncrementalPredictor(features=features, target=target, model=model,
                                  batch_size=500_000, partial_fit_kwargs={'classes':[0, 1, 2]})

vaex_model.fit(df=df, progress='widget')

df = vaex_model.transform(df)
df
[21]:
# sepal_length sepal_width petal_length petal_width class_ prediction
0 5.9 3.0 4.2 1.5 1 1
1 6.1 3.0 4.6 1.4 1 1
2 6.6 2.9 4.6 1.3 1 1
3 6.7 3.3 5.7 2.1 2 2
4 5.5 4.2 1.4 0.2 0 0
... ... ... ... ... ... ...
1,004,999,995 5.2 3.4 1.4 0.2 0 0
1,004,999,996 5.1 3.8 1.6 0.2 0 0
1,004,999,997 5.8 2.6 4.0 1.2 1 1
1,004,999,998 5.7 3.8 1.7 0.3 0 0
1,004,999,999 6.2 2.9 4.3 1.3 1 1

XGBoost example

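Libraries such as XGBoost provide more options, such as validation during training and early stopping. In addition to the scikit-learn API, we provide wrappers that stay close to the native APIs of these libraries.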

While the following example showcases the XGBoost wrapper, vaex.ml implements similar wrappers for LightGBM and CatBoost.
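The early-stopping rule can be sketched in plain Python. In XGBoost's native training API, training stops once the validation metric has not improved for `early_stopping_rounds` consecutive rounds. When several evaluation sets are given, the last entry in `evals` is the one monitored, which here would be `df_test`, assuming the wrapper forwards `evals` unchanged. The function below mimics that rule on a precomputed list of validation losses. It is an illustration, not XGBoost's code, and `early_stopping_round` is a hypothetical name.

```python
def early_stopping_round(val_losses, patience=5):
    """Return (best_round, stop_round) for 'stop after `patience` rounds without improvement'."""
    best_round, best_loss = 0, float("inf")
    for i, loss in enumerate(val_losses):
        if loss < best_loss:
            best_round, best_loss = i, loss   # new best validation loss
        elif i - best_round >= patience:
            return best_round, i              # training halts here; best_round is kept
    return best_round, len(val_losses) - 1    # ran out of rounds without triggering

losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.6]
print(early_stopping_round(losses, patience=5))  # (2, 7): the later 0.6 is never reached
```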

[22]:
from vaex.ml.xgboost import XGBoostModel

df = vaex.datasets.iris_1e5()
df_train, df_test = df.ml.train_test_split(test_size=0.2, verbose=False)

features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
target = 'class_'

params = {'learning_rate': 0.1,
          'max_depth': 3,
          'num_class': 3,
          'objective': 'multi:softmax',
          'subsample': 1,
          'random_state': 42,
          'n_jobs': -1}


booster = XGBoostModel(features=features, target=target, num_boost_round=500, params=params)
booster.fit(df=df_train, evals=[(df_train, 'train'), (df_test, 'test')], early_stopping_rounds=5)

df_train = booster.transform(df_train)
df_train
[13:41:31] WARNING: /home/conda/feedstock_root/build_artifacts/xgboost_1607604574104/work/src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'multi:softmax' was changed from 'merror' to 'mlogloss'. Explicitly set eval_metric if you'd like to restore the old behavior.
[22]:
# sepal_length sepal_width petal_length petal_width class_ xgboost_prediction
0 5.9 3.0 4.2 1.5 1 1.0
1 6.1 3.0 4.6 1.4 1 1.0
2 6.6 2.9 4.6 1.3 1 1.0
3 6.7 3.3 5.7 2.1 2 2.0
4 5.5 4.2 1.4 0.2 0 0.0
... ... ... ... ... ... ...
80,395 5.2 3.4 1.4 0.2 0 0.0
80,396 5.1 3.8 1.6 0.2 0 0.0
80,397 5.8 2.6 4.0 1.2 1 1.0
80,398 5.7 3.8 1.7 0.3 0 0.0
80,399 6.2 2.9 4.3 1.3 1 1.0

CatBoost example

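The CatBoost library supports summing up models. With this feature, we can use CatBoost to train a model on data that is otherwise too large to fit in memory. The idea is to train a single CatBoost model per chunk of data, and then sum up the individual models to create a master model. To use this feature via vaex.ml, just specify the batch_size argument in the CatBoostModel wrapper. One can also specify additional options, such as the strategy for summing up the individual models, or how they should be weighted.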
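Two quick calculations make the batching concrete. The output below ends at row 80,399,999, so df_train has 80,400,000 rows. With batch_size=11_000_000 that gives ceil(80,400,000 / 11,000,000) = 8 per-batch models, the last one trained on the remaining 3,400,000 rows. Summing models amounts, roughly, to adding up the raw (pre-softmax) scores of the individual models, optionally with weights. The NumPy sketch below illustrates that combination step under this assumption. It is not CatBoost's `sum_models` implementation, and the equal 1/m weighting, which turns the sum into an average, is a choice made here.

```python
import math
import numpy as np

n_train = 80_400_000       # rows in df_train, read off the output table below
batch_size = 11_000_000    # value passed to CatBoostModel below
n_batches = math.ceil(n_train / batch_size)
last_batch = n_train - (n_batches - 1) * batch_size
print(n_batches, last_batch)  # 8 3400000

def combine_raw_scores(per_model_scores, weights=None):
    """Weighted sum of raw scores; input shape (n_models, n_samples, n_classes)."""
    scores = np.asarray(per_model_scores, dtype=float)
    if weights is None:
        weights = np.full(len(scores), 1.0 / len(scores))  # equal weights: a plain average
    # Contract over the model axis; the result has shape (n_samples, n_classes)
    return np.tensordot(np.asarray(weights, dtype=float), scores, axes=1)

# Two made-up models scoring one sample over two classes
scores = [[[1.0, 0.0]], [[3.0, 2.0]]]
combined = combine_raw_scores(scores)
print(combined, combined.argmax(axis=1))  # predicted class = highest combined score
```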

[23]:
from vaex.ml.catboost import CatBoostModel

df = vaex.datasets.iris_1e8()
df_train, df_test = df.ml.train_test_split(test_size=0.2, verbose=False)

features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
target = 'class_'

params = {
    'leaf_estimation_method': 'Gradient',
    'learning_rate': 0.1,
    'max_depth': 3,
    'bootstrap_type': 'Bernoulli',
    'subsample': 0.8,
    'sampling_frequency': 'PerTree',
    'colsample_bylevel': 0.8,
    'reg_lambda': 1,
    'objective': 'MultiClass',
    'eval_metric': 'MultiClass',
    'random_state': 42,
    'verbose': 0,
}

booster = CatBoostModel(features=features, target=target, num_boost_round=23,
                        params=params, prediction_type='Class', batch_size=11_000_000)
booster.fit(df=df_train, progress='widget')

df_train = booster.transform(df_train)
df_train
[23]:
# sepal_length sepal_width petal_length petal_width class_ catboost_prediction
0 5.9 3.0 4.2 1.5 1 array([1])
1 6.1 3.0 4.6 1.4 1 array([1])
2 6.6 2.9 4.6 1.3 1 array([1])
3 6.7 3.3 5.7 2.1 2 array([2])
4 5.5 4.2 1.4 0.2 0 array([0])
... ... ... ... ... ... ...
80,399,995 5.2 3.4 1.4 0.2 0 array([0])
80,399,996 5.1 3.8 1.6 0.2 0 array([0])
80,399,997 5.8 2.6 4.0 1.2 1 array([1])
80,399,998 5.7 3.8 1.7 0.3 0 array([0])
80,399,999 6.2 2.9 4.3 1.3 1 array([1])

Keras example

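Keras is one of the most popular high-level APIs for building neural network models, with TensorFlow as its backend. Neural networks can have very diverse and complicated architectures, and their training loops can be both simple and sophisticated. This is why, at least for now, we let users train their Keras models as they normally would, and vaex-ml provides a simple wrapper for the serialization and lazy evaluation of those models. In addition, vaex-ml provides a convenient method to turn a DataFrame into a generator suitable for training Keras models. See the example below.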
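The idea behind such a generator can be sketched without vaex or TensorFlow. It yields `(features, target)` batches in a loop that never ends, and with an endless generator Keras relies on `steps_per_epoch` to know where an epoch ends. Covering every row once per epoch then takes ceil(n_rows / batch_size) steps. For example, if the training split has 264,000 rows (an assumption: 80% of a 330,000-row DataFrame), batch_size=512 gives 516 steps, the value recommended in the output below. The helpers here are hypothetical illustrations, not vaex's `to_keras_generator`.

```python
import math
import numpy as np

def batch_generator(X, y, batch_size):
    """Endlessly yield (X_batch, y_batch) tuples, cycling over the data."""
    n = len(X)
    while True:  # never exhausts; the consumer decides how many batches make an epoch
        for start in range(0, n, batch_size):
            yield X[start:start + batch_size], y[start:start + batch_size]

def steps_per_epoch(n_rows, batch_size):
    """Batches needed to visit every row once per epoch (the last batch may be smaller)."""
    return math.ceil(n_rows / batch_size)

print(steps_per_epoch(264_000, 512), steps_per_epoch(33_000, 512))  # 516 65

gen = batch_generator(np.arange(10.0).reshape(5, 2), np.arange(5.0), batch_size=2)
for _ in range(4):
    xb, yb = next(gen)
    print(xb.shape, yb)  # batches of 2, 2 and 1 rows, then it starts over
```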

[24]:
import vaex.ml.tensorflow
import tensorflow.keras as K

df = vaex.example()
df_train, df_valid, df_test = df.split_random([0.8, 0.1, 0.1], random_state=42)

features = ['x', 'y', 'z', 'vx', 'vy', 'vz']
target = 'FeH'

# Scaling the features
df_train = df_train.ml.minmax_scaler(features=features)
features = df_train.get_column_names(regex='^minmax_')

# Apply the same preprocessing to the validation set
state_prep = df_train.state_get()
df_valid.state_set(state_prep)

# Generators for the train and validation sets
gen_train = df_train.ml.tensorflow.to_keras_generator(features=features, target=target, batch_size=512)
gen_valid = df_valid.ml.tensorflow.to_keras_generator(features=features, target=target, batch_size=512)

# Create and fit a simple Sequential Keras model
nn_model = K.Sequential()
nn_model.add(K.layers.Dense(3, activation='tanh'))
nn_model.add(K.layers.Dense(1, activation='linear'))
nn_model.compile(optimizer='sgd', loss='mse')
nn_model.fit(x=gen_train, validation_data=gen_valid, epochs=11, steps_per_epoch=516, validation_steps=65)

# Wrap the trained model so it becomes part of the DataFrame state (serializable, evaluated lazily)
keras_model = vaex.ml.tensorflow.KerasModel(features=features, prediction_name='keras_pred', model=nn_model)
df_train = keras_model.transform(df_train)

# Apply all the transformations to the test set
state = df_train.state_get()
df_test.state_set(state)

# Preview the results
df_test.head(5)
Recommended "steps_per_epoch" arg: 516.0
Recommended "steps_per_epoch" arg: 65.0
Epoch 1/11
 11/516 [..............................] - ETA: 2s - loss: 1.7922
516/516 [==============================] - 3s 6ms/step - loss: 0.2172 - val_loss: 0.1724
Epoch 2/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1736 - val_loss: 0.1715
Epoch 3/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1729 - val_loss: 0.1705
Epoch 4/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1725 - val_loss: 0.1707
Epoch 5/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1722 - val_loss: 0.1708
Epoch 6/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1720 - val_loss: 0.1701
Epoch 7/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1718 - val_loss: 0.1697
Epoch 8/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1717 - val_loss: 0.1706
Epoch 9/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1715 - val_loss: 0.1698
Epoch 10/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1714 - val_loss: 0.1702
Epoch 11/11
516/516 [==============================] - 3s 6ms/step - loss: 0.1713 - val_loss: 0.1701
[24]:
# id x y z vx vy vz E L Lz FeH minmax_scaled_x minmax_scaled_y minmax_scaled_z minmax_scaled_vx minmax_scaled_vy minmax_scaled_vz keras_pred
0 23 0.137403 -5.07974 1.40165 111.828 62.8776 -88.121 -134786 700.236 576.698 -1.7935 0.375163 0.72055 0.397008 0.570648 0.56065 0.414253 array([-1.6143968], dtype=float32)
1 31 -1.95543 -0.840676 1.26239 -259.282 20.8279 -148.457 -134990 676.813 -258.7 -0.623007 0.365132 0.738746 0.395427 0.266912 0.5249 0.357964 array([-1.509573], dtype=float32)
2 22 2.33077 -0.570014 0.761285 -53.4566 -43.377 -71.3196 -177062 196.209 -131.573 -0.889463 0.385676 0.739908 0.389737 0.43537 0.470313 0.429927 array([-1.5752358], dtype=float32)
3 26 0.777881 -2.83258 0.0797214 256.427 202.451 -12.76 -125176 884.581 883.833 -1.65996 0.378233 0.730196 0.381998 0.688994 0.679314 0.484558 array([-1.6558373], dtype=float32)
4 1 3.37429 2.62885 -0.797169 300.697 153.772 83.9173 -97150.4 681.868 -271.616 -1.6496 0.390678 0.753639 0.372041 0.725228 0.637928 0.574749 array([-1.6719546], dtype=float32)

River example

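River is an up-and-coming library for online learning that provides a variety of models that can learn incrementally. While most river models currently support training one sample at a time, only a few support mini-batch training, which is extremely fast. This makes those models a great fit for doing machine learning with vaex.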
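Mini-batch training updates the model once per batch, using the average gradient over that batch, rather than once per sample. The NumPy sketch below shows this for linear regression trained with plain SGD on the mean squared error, on synthetic noise-free data. It is a generic illustration of the technique, not river's implementation, and the function name is hypothetical. River's `LinearRegression` also has its own intercept learning rate (`intercept_lr`) and other details not modelled here.

```python
import numpy as np

def minibatch_sgd_linreg(X, y, lr=0.1, batch_size=32, n_epochs=50, seed=0):
    """Fit y ≈ X @ w + b by mini-batch SGD on 0.5 * mean squared error."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w, b = np.zeros(d), 0.0
    for _ in range(n_epochs):
        order = rng.permutation(n)               # reshuffle the rows every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            err = X[idx] @ w + b - y[idx]        # residuals of the current mini-batch
            w -= lr * X[idx].T @ err / len(idx)  # averaged gradient w.r.t. the weights
            b -= lr * err.mean()                 # averaged gradient w.r.t. the intercept
    return w, b

# Synthetic data generated from y = 3*x0 - 2*x1 + 0.5
rng = np.random.default_rng(1)
X = rng.uniform(-1, 1, size=(1000, 2))
y = 3 * X[:, 0] - 2 * X[:, 1] + 0.5
w, b = minibatch_sgd_linreg(X, y)
print(w, b)
```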

[25]:
from vaex.ml.incubator.river import RiverModel
from river.linear_model import LinearRegression
from river import optim


df = vaex.datasets.iris_1e9()
df_train, df_test = df.ml.train_test_split(test_size=0.2, verbose=False)

features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
target = 'class_'

river_model = RiverModel(features=features,
                         target=target,
                         model=LinearRegression(optimizer=optim.SGD(0.001), intercept_lr=0.001),
                         prediction_name='prediction_raw',
                         batch_size=500_000)
river_model.fit(df_train, progress='widget')
river_model.transform(df_test)
[25]:
# sepal_length sepal_width petal_length petal_width class_ prediction_raw
0 5.9 3.0 4.2 1.5 1 1.2262451850482554
1 6.1 3.0 4.6 1.4 1 1.3372106202149072
2 6.6 2.9 4.6 1.3 1 1.3080263625894342
3 6.7 3.3 5.7 2.1 2 1.8246442870772779
4 5.5 4.2 1.4 0.2 0 -0.1719159051653813
... ... ... ... ... ... ...
200,999,995 5.2 3.4 1.4 0.2 0 -0.06961837848289065
200,999,996 5.1 3.8 1.6 0.2 0 -0.04133966888449841
200,999,997 5.8 2.6 4.0 1.2 1 1.1380612859534056
200,999,998 5.7 3.8 1.7 0.3 0 -0.005633275295105093
200,999,999 6.2 2.9 4.3 1.3 1 1.2171097577656713

Metrics

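vaex-ml also provides several of the most common evaluation metrics for classification and regression tasks. These metrics are implemented in vaex-ml and are therefore evaluated out-of-core, so you do not need to materialize the target and prediction columns.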

Here is a list of the currently supported metrics:

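  • Classification (binary, and macro-averaged for multiclass problems):

    • Accuracy

    • Precision

    • Recall

    • F1-score

    • Confusion matrix

    • Classification report (a convenience method that prints the accuracy, precision, recall and F1-score together)

    • Matthews correlation coefficient

  • Regression:

    • Mean absolute error

    • Mean squared error

    • R2 score (coefficient of determination)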

Here is a simple example:

[26]:
import vaex.ml.metrics
from sklearn.linear_model import LogisticRegression

df = vaex.datasets.iris()
df_train, df_test = df.split_random([0.8, 0.2], random_state=55)

features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']
target = 'class_'

model = LogisticRegression(random_state=42)
vaex_model = Predictor(features=features, target=target, model=model, prediction_name='pred')

vaex_model.fit(df=df_train)

df_test = vaex_model.transform(df_test)

print(df_test.ml.metrics.classification_report(df_test.class_, df_test.pred, average='macro'))

        Classification report:

        Accuracy:  0.933
        Precision: 0.928
        Recall:    0.928
        F1:        0.928

/home/jovan/vaex/packages/vaex-core/vaex/dataframe.py:5516: UserWarning: It seems your column class_ is already ordinal encoded (values between 0 and 2), automatically switching to use df.categorize
  warnings.warn(f'It seems your column {column} is already ordinal encoded (values between {min_value} and {max_value}), automatically switching to use df.categorize')
/home/jovan/vaex/packages/vaex-core/vaex/dataframe.py:5516: UserWarning: It seems your column pred is already ordinal encoded (values between 0 and 2), automatically switching to use df.categorize
  warnings.warn(f'It seems your column {column} is already ordinal encoded (values between {min_value} and {max_value}), automatically switching to use df.categorize')
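Metrics like these only need a handful of aggregated counts, which is what makes chunked, out-of-core evaluation possible. The confusion matrix can be accumulated batch by batch, and accuracy, precision, recall and F1 all follow from it. The NumPy sketch below illustrates that idea on in-memory arrays processed in chunks; it is not vaex-ml's implementation, and the helper names are hypothetical. Its macro average is the unweighted mean of the per-class scores, as in scikit-learn with average='macro'. Whether vaex-ml averages F1 in exactly the same way is an assumption.

```python
import numpy as np

def confusion_matrix_chunked(y_true, y_pred, n_classes, chunk_size=1000):
    """Accumulate a confusion matrix (rows = true, columns = predicted) chunk by chunk."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for start in range(0, len(y_true), chunk_size):
        t = np.asarray(y_true[start:start + chunk_size])
        p = np.asarray(y_pred[start:start + chunk_size])
        # Encode each (true, pred) pair as a single integer and count occurrences
        cm += np.bincount(t * n_classes + p, minlength=n_classes ** 2).reshape(n_classes, n_classes)
    return cm

def macro_scores(cm):
    """Accuracy and macro-averaged precision, recall and F1 from a confusion matrix."""
    tp = np.diag(cm).astype(float)
    precision = tp / cm.sum(axis=0)  # column sums; a never-predicted class would give 0/0
    recall = tp / cm.sum(axis=1)     # row sums: how often each class actually occurs
    f1 = 2 * precision * recall / (precision + recall)
    return tp.sum() / cm.sum(), precision.mean(), recall.mean(), f1.mean()

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix_chunked(y_true, y_pred, n_classes=3, chunk_size=4)
print(cm)
print(macro_scores(cm))
```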

State transfer - pipelines made easy

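Each vaex DataFrame consists of two parts: data and state. The data is immutable, and any operation such as filtering, adding new columns, or applying transformers or predictive models only modifies the state. This is an extremely powerful concept that can completely redefine how we think about machine learning pipelines.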

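As an example, let us once again create a model based on the Iris dataset. Here, we will create a couple of new features, do a PCA transformation, and finally train a predictive model.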

[27]:
# Load data and split it in train and test sets
df = vaex.datasets.iris()
df_train, df_test = df.ml.train_test_split(test_size=0.2, verbose=False)

# Create new features
df_train['petal_ratio'] = df_train.petal_length / df_train.petal_width
df_train['sepal_ratio'] = df_train.sepal_length / df_train.sepal_width

# Do a PCA transformation
features = ['petal_length', 'petal_width', 'sepal_length', 'sepal_width', 'petal_ratio', 'sepal_ratio']
pca = vaex.ml.PCA(features=features, n_components=6)
df_train = pca.fit_transform(df_train)

# Display the training DataFrame at this stage
df_train
[27]:
# sepal_length sepal_width petal_length petal_width class_ petal_ratio sepal_ratio PCA_0 PCA_1 PCA_2 PCA_3 PCA_4 PCA_5
0 5.4 3.0 4.5 1.5 1 3.0 1.8 -1.510547480171215 0.3611524321126822 -0.4005106138591812 0.5491844107628985 0.21135370342329635 -0.009542243224854377
1 4.8 3.4 1.6 0.2 0 8.0 1.411764705882353 4.447550641536847 0.2799644730487585 -0.04904458661276928 0.18719360579644695 0.10928493945448532 0.005228919010020094
2 6.9 3.1 4.9 1.5 1 3.266666666666667 2.2258064516129035 -1.777649528149752 -0.6082889770845891 0.48007833550651513 -0.37762011866831335 0.05174472701894024 -0.04673816474220924
3 4.4 3.2 1.3 0.2 0 6.5 1.375 3.400548263702555 1.437036928591846 -0.3662652846960042 0.23420836198441913 0.05750021481634099 -0.023055011653267066
4 5.6 2.8 4.9 2.0 2 2.45 2.0 -2.3245098766222094 0.14710673877401348 -0.5150809942258257 0.5471824391426298 -0.12154714382375817 0.0044686197532133876
... ... ... ... ... ... ... ... ... ... ... ... ... ...
115 5.2 3.4 1.4 0.2 0 6.999999999999999 1.5294117647058825 3.623794583238953 0.8255759252729563 0.23453320686724874 -0.17599408825208826 -0.04687036865354327 -0.02424621891240747
116 5.1 3.8 1.6 0.2 0 8.0 1.3421052631578947 4.42115266246093 0.22287505533663704 0.4450642830179705 0.2184424557783562 0.14504752606375293 0.07229123907677276
117 5.8 2.6 4.0 1.2 1 3.3333333333333335 2.230769230769231 -1.069062832993727 0.3874258314654399 -0.4471767749236783 -0.2956609879568117 -0.0010695982441835394 -0.0065225306610744715
118 5.7 3.8 1.7 0.3 0 5.666666666666667 1.5000000000000002 2.2846521048417037 1.1920826609681359 0.8273738848637026 -0.21048946462725737 0.03381892388998425 0.018792165273013528
119 6.2 2.9 4.3 1.3 1 3.3076923076923075 2.137931034482759 -1.2988229958748452 0.06960434514054464 -0.0012167985718341268 -0.24072255219180883 0.05282732890885841 -0.032459999314411514
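The PCA step in the cell above can be sketched in NumPy using the textbook definition. Centre the features, compute their covariance matrix, and project onto its eigenvectors ordered by decreasing eigenvalue, so that PCA_0 carries the most variance. This is an assumption about the method, not vaex.ml.PCA's code, and the function names are hypothetical. Eigenvector signs are arbitrary, so a sketch like this can produce columns that differ in sign from the table above.

```python
import numpy as np

def pca_fit(X):
    """Return (mean, eigenvalues, eigenvectors) of the feature covariance, largest first."""
    mean = X.mean(axis=0)
    cov = np.cov(X - mean, rowvar=False)     # features are columns
    eigvals, eigvecs = np.linalg.eigh(cov)   # eigh: symmetric input, ascending eigenvalues
    order = np.argsort(eigvals)[::-1]        # reorder so component 0 has the largest variance
    return mean, eigvals[order], eigvecs[:, order]

def pca_transform(X, mean, eigvecs):
    """Project centred data onto the principal axes (columns PCA_0, PCA_1, ...)."""
    return (X - mean) @ eigvecs

# Random correlated toy data with 3 features
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3)) @ np.array([[2.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.5, 0.2, 0.3]])
mean, eigvals, eigvecs = pca_fit(X)
Z = pca_transform(X, mean, eigvecs)
print(eigvals)                               # variance captured by each component
print(np.round(np.cov(Z, rowvar=False), 6))  # ~diagonal: the components are uncorrelated
```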

At this point we are ready to train a predictive model. In this example, let us use LightGBM with its scikit-learn API.

[28]:
import lightgbm

features = df_train.get_column_names(regex='^PCA')

booster = lightgbm.LGBMClassifier()

vaex_model = Predictor(model=booster, features=features, target='class_')

vaex_model.fit(df=df_train)
df_train = vaex_model.transform(df_train)

df_train
[28]:
# sepal_length sepal_width petal_length petal_width class_ petal_ratio sepal_ratio PCA_0 PCA_1 PCA_2 PCA_3 PCA_4 PCA_5 prediction
0 5.4 3.0 4.5 1.5 1 3.0 1.8 -1.510547480171215 0.3611524321126822 -0.4005106138591812 0.5491844107628985 0.21135370342329635 -0.009542243224854377 1
1 4.8 3.4 1.6 0.2 0 8.0 1.411764705882353 4.447550641536847 0.2799644730487585 -0.04904458661276928 0.18719360579644695 0.10928493945448532 0.005228919010020094 0
2 6.9 3.1 4.9 1.5 1 3.266666666666667 2.2258064516129035 -1.777649528149752 -0.6082889770845891 0.48007833550651513 -0.37762011866831335 0.05174472701894024 -0.04673816474220924 1
3 4.4 3.2 1.3 0.2 0 6.5 1.375 3.400548263702555 1.437036928591846 -0.3662652846960042 0.23420836198441913 0.05750021481634099 -0.023055011653267066 0
4 5.6 2.8 4.9 2.0 2 2.45 2.0 -2.3245098766222094 0.14710673877401348 -0.5150809942258257 0.5471824391426298 -0.12154714382375817 0.0044686197532133876 2
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
115 5.2 3.4 1.4 0.2 0 6.999999999999999 1.5294117647058825 3.623794583238953 0.8255759252729563 0.23453320686724874 -0.17599408825208826 -0.04687036865354327 -0.02424621891240747 0
116 5.1 3.8 1.6 0.2 0 8.0 1.3421052631578947 4.42115266246093 0.22287505533663704 0.4450642830179705 0.2184424557783562 0.14504752606375293 0.07229123907677276 0
117 5.8 2.6 4.0 1.2 1 3.3333333333333335 2.230769230769231 -1.069062832993727 0.3874258314654399 -0.4471767749236783 -0.2956609879568117 -0.0010695982441835394 -0.0065225306610744715 1
118 5.7 3.8 1.7 0.3 0 5.666666666666667 1.5000000000000002 2.2846521048417037 1.1920826609681359 0.8273738848637026 -0.21048946462725737 0.03381892388998425 0.018792165273013528 0
119 6.2 2.9 4.3 1.3 1 3.3076923076923075 2.137931034482759 -1.2988229958748452 0.06960434514054464 -0.0012167985718341268 -0.24072255219180883 0.05282732890885841 -0.032459999314411514 1

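The final df_train DataFrame contains all the features we created, including the predictions at the very end. Now we would like to apply the same transformations to the test set. All we need to do is extract the state from df_train and apply it to df_test. This propagates all the changes made to the training set onto the test set.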
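The essential point is that every parameter is learned once on df_train and then reused unchanged on df_test; the test set never contributes to its own transformation. A toy, dict-based sketch of this fit-once, apply-anywhere idea follows. The state layout and helper names are hypothetical and far simpler than a real vaex state, which, per the description above, also captures filters, new columns and fitted transformers.

```python
import numpy as np

def fit_scaler_state(train, features):
    """Learn per-feature mean and standard deviation on the training data only."""
    return {f: {"mean": float(np.mean(train[f])), "std": float(np.std(train[f]))}
            for f in features}

def apply_state(data, state):
    """Apply a previously fitted state to any dataset without refitting."""
    out = dict(data)
    for f, p in state.items():
        out["scaled_" + f] = (np.asarray(data[f], dtype=float) - p["mean"]) / p["std"]
    return out

train = {"x": [1.0, 2.0, 3.0]}
test = {"x": [2.0, 5.0]}
state = fit_scaler_state(train, ["x"])   # mean 2.0, population std sqrt(2/3)
print(apply_state(test, state)["scaled_x"])
```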

[29]:
state = df_train.state_get()

df_test.state_set(state)
df_test
[29]:
# sepal_length sepal_width petal_length petal_width class_ petal_ratio sepal_ratio PCA_0 PCA_1 PCA_2 PCA_3 PCA_4 PCA_5 prediction
0 5.9 3.0 4.2 1.5 1 2.8000000000000003 1.9666666666666668 -1.642627940409072 0.49931302910747727 -0.06308800806664466 0.10842057110641677 -0.03924298664189224 -0.027394439700272822 1
1 6.1 3.0 4.6 1.4 1 3.2857142857142856 2.033333333333333 -1.445047446393471 -0.1019091578746504 -0.01899012239493801 0.020980767646090408 0.1614215276667148 -0.02716639637934938 1
2 6.6 2.9 4.6 1.3 1 3.538461538461538 2.2758620689655173 -1.330564613235537 -0.41978474749131267 0.1759590589290671 -0.4631301992308477 0.08304243689815374 -0.033351733677429274 1
3 6.7 3.3 5.7 2.1 2 2.7142857142857144 2.0303030303030303 -2.6719170661531013 -0.9149428897499291 0.4156162725009377 0.34633692661436644 0.03742964707590906 -0.013254286196245774 2
4 5.5 4.2 1.4 0.2 0 6.999999999999999 1.3095238095238095 3.6322930267831404 0.8198526437905096 1.046277579362938 0.09738737839850209 0.09412658096734221 0.1329137026697501 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
25 5.5 2.5 4.0 1.3 1 3.0769230769230766 2.2 -1.2523120088600896 0.5975071562677784 -0.7019801415469216 -0.11489031841855571 -0.03615945782087869 0.005496321827264977 1
26 5.8 2.7 3.9 1.2 1 3.25 2.148148148148148 -1.0792352165904657 0.5236883751378523 -0.34037717939532286 -0.23743695029955128 -0.00936891422024664 -0.02184110533380834 1
27 4.4 2.9 1.4 0.2 0 6.999999999999999 1.517241379310345 3.7422969192506095 1.048460304741977 -0.636475521315278 0.07623157913054074 0.004215355833312173 -0.06354157393133958 0
28 4.5 2.3 1.3 0.3 0 4.333333333333334 1.956521739130435 1.4537380535696471 2.4197864889383505 -1.0301500321688102 -0.5150263062576134 -0.2631218962099228 -0.06608059456656257 0
29 6.9 3.2 5.7 2.3 2 2.4782608695652177 2.15625 -2.963110301521378 -0.924626055589704 0.44833006106219797 0.20994670504662372 -0.2012725506779131 -0.018900414287719353 2

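And just like that, df_test contains all the columns, transformations and the prediction we modelled on the training set. The state can easily be serialized to disk as a JSON file. This makes deploying a machine learning model as simple as copying a JSON file from one environment to another.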
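Because the state is plain JSON, deployment reduces to writing a file in one environment and reading it back in another. The round trip can be sketched with Python's standard json module. The state dictionary here is a hypothetical, heavily simplified stand-in with made-up values, not the format that vaex's state_write produces.

```python
import json
import os
import tempfile

# Hypothetical, simplified state: a virtual column expression plus made-up fitted parameters
state = {
    "virtual_columns": {"petal_ratio": "petal_length / petal_width"},
    "scaler": {"petal_length": {"mean": 3.7, "std": 1.8}},
}

path = os.path.join(tempfile.mkdtemp(), "iris_model.json")
with open(path, "w") as f:
    json.dump(state, f, indent=2)   # "training" environment: write the state

with open(path) as f:
    restored = json.load(f)         # "deployment" environment: read it back

print(restored == state)  # True
```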

[30]:
df_train.state_write('./iris_model.json')

df_test.state_load('./iris_model.json')
df_test
[30]:
# sepal_length sepal_width petal_length petal_width class_ petal_ratio sepal_ratio PCA_0 PCA_1 PCA_2 PCA_3 PCA_4 PCA_5 prediction
0 5.9 3.0 4.2 1.5 1 2.8000000000000003 1.9666666666666668 -1.642627940409072 0.49931302910747727 -0.06308800806664466 0.10842057110641677 -0.03924298664189224 -0.027394439700272822 1
1 6.1 3.0 4.6 1.4 1 3.2857142857142856 2.033333333333333 -1.445047446393471 -0.1019091578746504 -0.01899012239493801 0.020980767646090408 0.1614215276667148 -0.02716639637934938 1
2 6.6 2.9 4.6 1.3 1 3.538461538461538 2.2758620689655173 -1.330564613235537 -0.41978474749131267 0.1759590589290671 -0.4631301992308477 0.08304243689815374 -0.033351733677429274 1
3 6.7 3.3 5.7 2.1 2 2.7142857142857144 2.0303030303030303 -2.6719170661531013 -0.9149428897499291 0.4156162725009377 0.34633692661436644 0.03742964707590906 -0.013254286196245774 2
4 5.5 4.2 1.4 0.2 0 6.999999999999999 1.3095238095238095 3.6322930267831404 0.8198526437905096 1.046277579362938 0.09738737839850209 0.09412658096734221 0.1329137026697501 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
25 5.5 2.5 4.0 1.3 1 3.0769230769230766 2.2 -1.2523120088600896 0.5975071562677784 -0.7019801415469216 -0.11489031841855571 -0.03615945782087869 0.005496321827264977 1
26 5.8 2.7 3.9 1.2 1 3.25 2.148148148148148 -1.0792352165904657 0.5236883751378523 -0.34037717939532286 -0.23743695029955128 -0.00936891422024664 -0.02184110533380834 1
27 4.4 2.9 1.4 0.2 0 6.999999999999999 1.517241379310345 3.7422969192506095 1.048460304741977 -0.636475521315278 0.07623157913054074 0.004215355833312173 -0.06354157393133958 0
28 4.5 2.3 1.3 0.3 0 4.333333333333334 1.956521739130435 1.4537380535696471 2.4197864889383505 -1.0301500321688102 -0.5150263062576134 -0.2631218962099228 -0.06608059456656257 0
29 6.9 3.2 5.7 2.3 2 2.4782608695652177 2.15625 -2.963110301521378 -0.924626055589704 0.44833006106219797 0.20994670504662372 -0.2012725506779131 -0.018900414287719353 2