UMAP 支持逆变换——在低维嵌入空间中给定一个位置,生成一个高维数据样本。首先,让我们加载所有相关的库。
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
import sklearn.datasets
import umap
import umap.plot
我们需要一些数据来进行测试。首先,我们将使用MNIST数字数据集。这是一个包含70000个手写数字的数据集,这些数字被编码为28x28像素的灰度图像。我们的目标是使用UMAP将这个数据集的维度降低到较小的尺寸,然后看看我们是否可以通过从嵌入空间中采样点来生成新的数字。为了加载MNIST数据集,我们将使用sklearn的fetch_openml函数。
data, labels = sklearn.datasets.fetch_openml('mnist_784', version=1, return_X_y=True)
mapper = umap.UMAP(random_state=42).fit(data)
为了确保一切正常工作,我们可以绘制数据(因为我们将其降到了两个维度)。我们将使用umap.plot功能来实现这一点。
umap.plot.points(mapper, labels=labels)
这看起来和我们预期的差不多。不同的数字类别已经被很好地分开了。现在我们需要在嵌入空间中创建一组样本来应用inverse_transform操作。为此,我们将在四个角点之间线性插值生成一个样本网格。为了使我们的选择更有趣,我们将仔细选择跨越数据集的角点,并采样不同的数字,以便更好地看到过渡。
corners = np.array([
[-5, -10], # 1
[-7, 6], # 7
[2, -8], # 2
[12, 4], # 0
])
test_pts = np.array([
(corners[0]*(1-x) + corners[1]*x)*(1-y) +
(corners[2]*(1-x) + corners[3]*x)*y
for y in np.linspace(0, 1, 10)
for x in np.linspace(0, 1, 10)
])
现在我们可以对这套测试点应用inverse_transform方法。每个测试点都是嵌入空间中的二维点。inverse_transform方法会将其转换为高维表示的近似值,这些高维表示原本会被嵌入到这样的位置。按照sklearn API,这就像调用训练模型的inverse_transform方法并传递我们想要转换为高维表示的测试点集一样简单。需要注意的是,这在计算上可能会非常昂贵。
inv_transformed_points = mapper.inverse_transform(test_pts)
现在的目标是可视化我们的表现如何。实际上,我们想要做的是在嵌入空间中显示测试点,然后显示由逆变换生成的相应图像的网格。为了在单个matplotlib图中展示所有这些内容,需要一些设置,但这是相当容易管理的——主要是管理GridSpec格式。一旦我们完成了设置,我们只需要一个嵌入的散点图,一个测试点的散点图,最后是我们生成的图像的网格(将逆变换后的向量转换回图像只是将它们重新调整为28x28像素网格并使用imshow)。
# Set up the grid
fig = plt.figure(figsize=(12,6))
gs = GridSpec(10, 20, fig)
scatter_ax = fig.add_subplot(gs[:, :10])
digit_axes = np.zeros((10, 10), dtype=object)
for i in range(10):
for j in range(10):
digit_axes[i, j] = fig.add_subplot(gs[i, 10 + j])
# Use umap.plot to plot to the major axis
# umap.plot.points(mapper, labels=labels, ax=scatter_ax)
scatter_ax.scatter(mapper.embedding_[:, 0], mapper.embedding_[:, 1],
c=labels.astype(np.int32), cmap='Spectral', s=0.1)
scatter_ax.set(xticks=[], yticks=[])
# Plot the locations of the text points
scatter_ax.scatter(test_pts[:, 0], test_pts[:, 1], marker='x', c='k', s=15)
# Plot each of the generated digit images
for i in range(10):
for j in range(10):
digit_axes[i, j].imshow(inv_transformed_points[i*10 + j].reshape(28, 28))
digit_axes[i, j].set(xticks=[], yticks=[])
最终结果看起来相当不错——我们确实生成了看起来合理的数字图像,许多过渡(例如顶行从1到7的过渡)看起来相当自然且有意义。这可以帮助你理解1的簇的结构(它在角度上过渡,向最终成为7的方向倾斜),以及为什么7和9在嵌入中靠得很近。当然,也有一些奇怪的过渡,特别是在测试点落入嵌入中簇之间的大间隙时——在某种意义上,很难解释这些间隙中应该放什么,因为它们并不真正代表任何类似于平滑过渡的东西。
进一步说明:所选的测试点均未落在嵌入的凸包之外。这是有意为之——逆变换函数在凸包边界之外的表现不佳。请注意,如果您选择要逆变换的点位于嵌入的凸包边界之外,您可能会得到奇怪的结果(通常只是简单地捕捉到某个特定的高维源向量)。
让我们继续演示,看看Fashion MNIST数据集。和之前一样,我们可以通过sklearn加载它。
data, labels = sklearn.datasets.fetch_openml('Fashion-MNIST', version=1, return_X_y=True)
再次,我们可以使用UMAP拟合这些数据并获得一个映射器对象。
mapper = umap.UMAP(random_state=42).fit(data)
让我们绘制嵌入图,看看我们得到了什么结果:
umap.plot.points(mapper, labels=labels)
再次,我们将通过在四个角之间进行网格插值来生成一组测试点。和之前一样,我们将选择这些角,以便我们可以保持在嵌入点的凸包内,并确保逆变换不会发生任何太奇怪的事情。
corners = np.array([
[-2, -6], # bags
[-9, 3], # boots?
[7, -5], # shirts/tops/dresses
[4, 10], # pants
])
test_pts = np.array([
(corners[0]*(1-x) + corners[1]*x)*(1-y) +
(corners[2]*(1-x) + corners[3]*x)*y
for y in np.linspace(0, 1, 10)
for x in np.linspace(0, 1, 10)
])
现在我们只需像之前一样应用逆变换。再次提醒,这在计算上非常昂贵,可能需要一些时间才能完成。
inv_transformed_points = mapper.inverse_transform(test_pts)
现在我们可以使用与上面类似的代码来设置我们的嵌入图,并在其上叠加测试点,以及生成的图像。
# Set up the grid
fig = plt.figure(figsize=(12,6))
gs = GridSpec(10, 20, fig)
scatter_ax = fig.add_subplot(gs[:, :10])
digit_axes = np.zeros((10, 10), dtype=object)
for i in range(10):
for j in range(10):
digit_axes[i, j] = fig.add_subplot(gs[i, 10 + j])
# Use umap.plot to plot to the major axis
# umap.plot.points(mapper, labels=labels, ax=scatter_ax)
scatter_ax.scatter(mapper.embedding_[:, 0], mapper.embedding_[:, 1],
c=labels.astype(np.int32), cmap='Spectral', s=0.1)
scatter_ax.set(xticks=[], yticks=[])
# Plot the locations of the text points
scatter_ax.scatter(test_pts[:, 0], test_pts[:, 1], marker='x', c='k', s=15)
# Plot each of the generated digit images
for i in range(10):
for j in range(10):
digit_axes[i, j].imshow(inv_transformed_points[i*10 + j].reshape(28, 28))
digit_axes[i, j].set(xticks=[], yticks=[])
这次我们看到一些项目之间的插值看起来相当奇怪——特别是那些位于鞋子和裤子之间的点——最终它在一个困难的问题上尽了最大努力。同时,许多其他过渡似乎效果相当不错,因此它确实提供了有关嵌入结构的有用信息。