如何使用AlignedUMAP
import numpy as np
import sklearn.datasets
import umap
import umap.plot
import umap.utils as utils
import umap.aligned_umap
import matplotlib.pyplot as plt
在我们的演示中,我们将仅使用来自sklearn的pendigits数据集。
digits = sklearn.datasets.load_digits()
为了创建一系列数据集,每个不同数据集之间有一些共享点,我们首先会对数据进行排序,以便有一个大致合理的进展。在这种情况下,我们将根据手写数字中的“墨水”总量进行排序。这并不是为了有意义,而是为了提供一些有用的东西,以便将其分割成重叠的块,我们希望这些块能够单独嵌入,同时保持对齐。
ordered_digits = digits.data[np.argsort(digits.data.sum(axis=1))]
ordered_target = digits.target[np.argsort(digits.data.sum(axis=1))]
plt.matshow(ordered_digits[-1].reshape((8,8)))
然后我们可以将数据集分成400个样本的切片,以150个样本的步长移动,以确保连续切片之间有重叠。这将给我们提供一个包含十个不同数据集的列表,我们可以对其进行嵌入,目标是确保嵌入中点的位置相对一致。
slices = [ordered_digits[150 * i:min(ordered_digits.shape[0], 150 * i + 400)] for i in range(10)]
为了确保一致性,AlignedUMAP 需要比仅仅数据集更多的信息——我们还需要一些关于数据集之间如何相互关联的信息。这些信息以字典的形式存在,将一个数据集的索引与另一个数据集的索引关联起来。目前,AlignedUMAP 仅支持具有序列中每对连续数据集之间关系的数据集序列。为了构建这个数据集的关系,我们注意到一个数据集的最后250个样本将与下一个数据集的前250个样本相同——这使得构建字典变得容易:它是映射
150 --> 0
151 --> 1
...
398 --> 248
399 --> 249
我们可以使用字典推导式轻松构建。我们将在每对连续对之间具有相同的关系,因此为了制作对之间的关系列表,我们只需将构建的关系复制所需的次数。
relation_dict = {i+150:i for i in range(400-150)}
relation_dicts = [relation_dict.copy() for i in range(len(slices) - 1)]
现在我们有了数据切片的列表和连续对之间的关系列表,我们可以使用AlignedUMAP类来生成嵌入列表。AlignedUMAP类接受UMAP接受的大多数参数。主要的区别在于fit方法需要一个数据集的列表,以及一个关键字参数relations,该参数指定了连续对数据集之间的关系字典。除此之外,基本上都是自动完成的。
%%time
aligned_mapper = umap.AlignedUMAP().fit(slices, relations=relation_dicts)
CPU times: user 57.4 s, sys: 8.43 s, total: 1min 5s
Wall time: 57.4 s
下一步是查看结果。为了确保我们生成的图表具有一致的x轴和y轴,我们将使用一个小函数来计算一组用于绘图的轴边界。
def axis_bounds(embedding):
left, right = embedding.T[0].min(), embedding.T[0].max()
bottom, top = embedding.T[1].min(), embedding.T[1].max()
adj_h, adj_v = (right - left) * 0.1, (top - bottom) * 0.1
return [left - adj_h, right + adj_h, bottom - adj_v, top + adj_v]
现在只需在十个不同的散点图中绘制结果。我们可以直接使用matplotlib轻松完成此操作,设置一个绘图网格。请注意,进度是按行然后按列进行的,因此请像阅读文本页面一样阅读进度(从左到右,然后向下)。
fig, axs = plt.subplots(5,2, figsize=(10, 20))
ax_bound = axis_bounds(np.vstack(aligned_mapper.embeddings_))
for i, ax in enumerate(axs.flatten()):
current_target = ordered_target[150 * i:min(ordered_target.shape[0], 150 * i + 400)]
ax.scatter(*aligned_mapper.embeddings_[i].T, s=2, c=current_target, cmap="Spectral")
ax.axis(ax_bound)
ax.set(xticks=[], yticks=[])
plt.tight_layout()
因此,尽管在不同的数据集上嵌入不同,聚类仍然保持其总体对齐——左上图和右下图在特定数字聚类的位置上大致相同。我们还可以在一定程度上看到结构在不同切片过程中的变化。因此,我们保持了各种嵌入的对齐,但允许由每个不同数据切片的不同结构所决定的变化。
在线更新对齐嵌入
可能的情况是,我们有传入的时间数据,并且希望有时间窗口的嵌入,理想情况下,这些嵌入与之前时间窗口的嵌入对齐。只要我们重叠使用的时间窗口以允许时间窗口之间的关系,那么这是可能的——除了之前的代码要求所有时间窗口一次性输入以进行拟合。我们更希望训练一个初始模型,然后随着时间推移更新它。这可以通过update方法实现,我们将在下面演示。
首先我们需要拟合一个基础的AlignedUMAP模型;我们将使用前两个切片和第一个关系字典来完成这一步骤。
%%time
updating_mapper = umap.AlignedUMAP().fit(slices[:2], relations=relation_dicts[:1])
CPU times: user 9.32 s, sys: 1.47 s, total: 10.8 s
Wall time: 9.17 s
请注意,这相当快,因为我们只拟合了两个切片。
给定训练好的模型,更新方法需要添加一个新的数据切片,
以及一个关系字典(通过relations关键字参数传递,与fit相同)。这将为模型的新数据在embeddings_属性中追加一个新的嵌入,与迄今为止看到的内容对齐。
for i in range(2,len(slices)):
%time updating_mapper.update(slices[i], relations={v:k for k,v in relation_dicts[i-1].items()})
CPU times: user 7.78 s, sys: 1.15 s, total: 8.93 s
Wall time: 7.92 s
CPU times: user 6.64 s, sys: 1.17 s, total: 7.81 s
Wall time: 6.6 s
CPU times: user 6.94 s, sys: 1.17 s, total: 8.11 s
Wall time: 6.81 s
CPU times: user 6.45 s, sys: 1.51 s, total: 7.96 s
Wall time: 6.45 s
CPU times: user 7.44 s, sys: 1.32 s, total: 8.76 s
Wall time: 7.16 s
CPU times: user 7.68 s, sys: 1.73 s, total: 9.41 s
Wall time: 7.59 s
CPU times: user 7.88 s, sys: 1.65 s, total: 9.54 s
Wall time: 7.39 s
CPU times: user 7.82 s, sys: 1.98 s, total: 9.8 s
Wall time: 7.7 s
请注意,每个新的切片所需的时间相对较短,正如我们所希望的那样。正如你可以想象的,这样做的缺点是,我们没有“向前”的关系——切片上的窗口只能向后看。这意味着结果不那么好,但我们正在用这一点来换取能够快速轻松地更新。
我们可以使用与之前基本相同的代码来查看我们的表现。
fig, axs = plt.subplots(5,2, figsize=(10, 20))
ax_bound = axis_bounds(np.vstack(updating_mapper.embeddings_))
for i, ax in enumerate(axs.flatten()):
current_target = ordered_target[150 * i:min(ordered_target.shape[0], 150 * i + 400)]
ax.scatter(*updating_mapper.embeddings_[i].T, s=2, c=current_target, cmap="Spectral")
ax.axis(ax_bound)
ax.set(xticks=[], yticks=[])
plt.tight_layout()
我们看到对齐确实有效,因此新的切片仍然可以与之前训练的切片进行比较。正如所指出的,整体对齐和进展不如之前的版本好,但它确实有一个显著的优点,即允许您采用逐步更新的方法。
请注意,目前这个模型保留了所有之前的数据,因此它只能在批量流处理的方法中真正有效,即偶尔训练一个新模型,在继续更新之前丢弃一些历史数据。
对齐不同的参数
可以对齐使用不同参数而不是数据的UMAP嵌入。为了演示这是如何工作的,我们将继续使用pendigits数据集,但不像之前那样切片数据,而是使用完整的数据集。这意味着我们的数据集之间的关系只是恒定的关系。我们可以提前构建这些关系:
constant_dict = {i:i for i in range(digits.data.shape[0])}
constant_relations = [constant_dict for i in range(9)]
要在多个参数范围内运行AlignedUMAP,您只需传入您希望使用的参数序列的列表。您可以对几个不同的参数执行此操作——只需确保所有列表的长度相同!在这种情况下,我们将尝试查看如果我们更改n_neighbors和min_dist,嵌入如何变化。这意味着当我们创建AlignedUMAP对象时,我们向这些参数传递一个列表,而不是单个值。为了使可视化更有趣,我们还将改变一些对齐参数(只有两个主要影响)。具体来说,我们将调整alignment_window_size,它控制在执行对齐时我们在数据集中向前和向后查看的距离,以及alignment_regularisation,它控制我们对齐方面与UMAP布局的权重。较大的alignment_regularisation值将更努力地保持嵌入中的点对齐(以每个切片的嵌入质量为代价),而较小的值将允许优化更多地关注单个嵌入,并减少对嵌入之间对齐的重视。
给定一个模型,我们可以随后对其进行拟合。和之前一样,我们需要提供一组数据集和一组关系。由于我们每次使用相同的数据(并改变参数),我们可以简单地重复完整的pendigits数据集。请注意,数据集的数量需要与使用的参数值的数量相匹配。关系数量也是如此(比参数值的数量少一个)。
neighbors_mapper = umap.AlignedUMAP(
n_neighbors=[3,4,5,7,11,16,22,29,37,45,54],
min_dist=[0.01,0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45],
alignment_window_size=2,
alignment_regularisation=1e-3,
).fit(
[digits.data for i in range(10)], relations=constant_relations
)
和之前一样,我们可以通过绘制每个嵌入来查看结果。
fig, axs = plt.subplots(5,2, figsize=(10, 20))
ax_bound = axis_bounds(np.vstack(neighbors_mapper.embeddings_))
for i, ax in enumerate(axs.flatten()):
ax.scatter(*neighbors_mapper.embeddings_[i].T, s=2, c=digits.target, cmap="Spectral")
ax.axis(ax_bound)
ax.set(xticks=[], yticks=[])
plt.tight_layout()
为了更好地感受嵌入随参数值变化的演变,我们可以在三维空间中绘制数据,其中第三维是所选的参数值。为了更好地展示嵌入中的数据点如何随着参数的变化而移动,我们可以将它们绘制为连接每个连续嵌入中相同点的曲线,而不是点。对于这样的三维图,我们将使用plotly绘图库。
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
我们首先要做的是将数据整理成适合plotly的格式。这就是为什么我们也加载了pandas——plotly喜欢数据框。这包括将所有嵌入堆叠在一起,然后根据我们所在的嵌入分配一个额外的z值。为了可视化的目的,我们将为z坐标设置一个从0到1的线性比例,长度适当。
n_embeddings = len(neighbors_mapper.embeddings_)
es = neighbors_mapper.embeddings_
embedding_df = pd.DataFrame(np.vstack(es), columns=('x', 'y'))
embedding_df['z'] = np.repeat(np.linspace(0, 1.0, n_embeddings), es[0].shape[0])
embedding_df['id'] = np.tile(np.arange(es[0].shape[0]), n_embeddings)
embedding_df['digit'] = np.tile(digits.target, n_embeddings)
import scipy.interpolate
插值模块有一个函数 interp1d,它生成一个(向量)平滑函数,给定一组一维数据点,它需要通过这些数据点。我们可以为每个pendigit样本的x和y坐标生成单独的函数,从而允许我们在三维空间中生成平滑曲线。
fx = scipy.interpolate.interp1d(
embedding_df.z[embedding_df.id == 0], embedding_df.x.values.reshape(n_embeddings, digits.data.shape[0]).T, kind="cubic"
)
fy = scipy.interpolate.interp1d(
embedding_df.z[embedding_df.id == 0], embedding_df.y.values.reshape(n_embeddings, digits.data.shape[0]).T, kind="cubic"
)
z = np.linspace(0, 1.0, 100)
有了这些,剩下的就是绘制所有的曲线。在plotly的术语中,每条曲线被称为一个“轨迹”,我们分别生成每条轨迹(并赋予由样本代表的数字所对应的合适颜色)。然后我们将所有轨迹添加到一个图形中,并显示该图形。
palette = px.colors.diverging.Spectral
interpolated_traces = [fx(z), fy(z)]
traces = [
go.Scatter3d(
x=interpolated_traces[0][i],
y=interpolated_traces[1][i],
z=z*3.0,
mode="lines",
line=dict(
color=palette[digits.target[i]],
width=3.0
),
opacity=1.0,
)
for i in range(digits.data.shape[0])
]
fig = go.Figure(data=traces)
fig.update_layout(
width=800,
height=700,
autosize=False,
showlegend=False,
)
fig.show()
由于在文档中嵌入交互式plotly图表较为复杂,我们在此提供了一个静态图像,但如果您自己运行此代码,您将能够完全交互式地查看数据。
from IPython.display import display, Image, HTML
from matplotlib import animation
fig = plt.figure(figsize=(4, 4), dpi=150)
ax = fig.add_subplot(1, 1, 1)
scat = ax.scatter([], [], s=2)
scat.set_array(digits.target)
scat.set_cmap('Spectral')
text = ax.text(ax_bound[0] + 0.5, ax_bound[2] + 0.5, '')
ax.axis(ax_bound)
ax.set(xticks=[], yticks=[])
plt.tight_layout()
offsets = np.array(interpolated_traces).T
num_frames = offsets.shape[0]
def animate(i):
scat.set_offsets(offsets[i])
text.set_text(f'Frame {i}')
return scat
anim = animation.FuncAnimation(
fig,
init_func=None,
func=animate,
frames=num_frames,
interval=40)
然后我们可以将动画保存为GIF并关闭我们的动画。根据你的机器,你可能需要更改保存方法使用的写入器。
anim.save("aligned_umap_pendigits_anim.gif", writer="pillow")
plt.close(anim._fig)
最后,我们可以读取渲染的GIF并在笔记本中显示它。
with open("aligned_umap_pendigits_anim.gif", "rb") as f:
display(Image(f.read()))