绘制UMAP结果

UMAP 通常用于通过将数据降维到二维来进行可视化。 由于这是一个非常常见的用例,umap 包现在包含了一些实用程序, 使得绘制 UMAP 结果变得简单,并提供了多种查看和诊断结果的方式。 与其试图提供一个涵盖所有可能绘图需求的全面解决方案, 这个 umap 扩展旨在提供一个简单易用的界面, 使大多数绘图需求变得容易,并尽可能提供合理的绘图选择。 要开始查看绘图选项,让我们加载一些数据来进行操作。

import sklearn.datasets
import pandas as pd
import numpy as np
import umap
pendigits = sklearn.datasets.load_digits()
mnist = sklearn.datasets.fetch_openml('mnist_784')
fmnist = sklearn.datasets.fetch_openml('Fashion-MNIST')

首先,我们将为pendigits数据拟合一个UMAP模型。这就像运行fit方法并将结果赋值给一个变量一样简单。

mapper = umap.UMAP().fit(pendigits.data)

如果我们想要进行绘图,我们将需要umap.plot包。虽然umap包的依赖项相对较少,但值得注意的是,如果你想使用umap.plot,你将需要一些额外的库,这些库不在umap的默认依赖项中。特别是你需要:

所有内容都应该可以通过pip或conda安装。有了这些工具,你可以导入umap.plot包。

import umap.plot

现在我们已加载了包,如何使用它呢?最直接的方法是绘制umap结果为点。我们可以通过函数umap.plot.points来实现。在最基本的形式中,你可以简单地将训练好的UMAP模型传递给umap.plot.points

umap.plot.points(mapper)
_images/plotting_8_2.png

如你所见,我们立即得到了UMAP嵌入的散点图。 请注意,该函数会根据数据密度自动选择点的大小,并在图像上添加使用过的UMAP参数的水印(如果度量标准是非标准的,也会包括在内)。 该函数还返回与图表相关联的matplotlib轴对象,因此如果需要,用户可以应用进一步的matplotlib函数,例如添加标题、轴标签等。

传递给UMAP的数据通常带有一组相关的标签,这些标签可能来自真实数据、聚类或其他方式。在这种情况下,希望能够根据标签对散点图进行着色。我们可以通过简单地将标签信息数组与labels关键字一起传递来实现这一点。umap.plot.points函数将根据提供的标签使用分类颜色映射对数据进行着色。

umap.plot.points(mapper, labels=pendigits.target)
_images/plotting_10_1.png

或者,您可能有一些额外的数据是连续的而不是分类的。在这种情况下,您会希望使用连续的颜色映射来为数据着色。同样,这也很简单——使用values关键字传入连续数据,数据将根据连续颜色映射进行着色。

此外,如果你不喜欢默认的颜色选择,umap.plot.points 函数提供了许多“主题”,这些主题提供了预定义的颜色选择。主题包括:

  • 范例

  • 地狱

  • 蓝色

  • 红色

  • 绿色

  • 深蓝色

  • 深红色

  • 深绿色

这里我们将使用‘fire’主题来演示改变美学是多么简单。

umap.plot.points(mapper, values=pendigits.data.mean(axis=1), theme='fire')
_images/plotting_12_1.png

如果您想要更大的控制权,可以指定确切的颜色映射和背景颜色。例如,这里我们希望通过标签为数据着色,但使用黑色背景,并使用‘Paired’颜色映射进行分类着色(作为color_key_cmap传递;cmap关键字定义了连续的颜色映射)。

umap.plot.points(mapper, labels=pendigits.target, color_key_cmap='Paired', background='black')
_images/plotting_14_1.png

还有许多其他选项可用,包括color_key用于指定离散标签到颜色的字典映射,cmap用于指定连续的颜色映射,或者指定生成图的宽度和高度。 再次强调,这并不提供对图表美学的全面控制,但这里的目标是提供一个简单易用的界面,而不是让用户能够微调所有方面——寻求这种控制的用户最好单独使用底层包(matplotlib、datashader和bokeh)。

绘制更大的数据集

一旦你有了大量数据,简单的散点图就更容易欺骗你。最明显的是过度绘图,其中点的标记重叠并堆积在一起,可能会让你误以为极其密集的簇可能只包含几个点。虽然有一些方法可以帮助解决这个问题,比如减小点的大小,或者添加一个alpha通道,但很少有方法足以确保图表不会以某种方式微妙地欺骗你。这篇文章在datashader文档中很好地描述了过度绘图的问题,为什么显而易见的解决方案并不完全足够,以及如何绕过这个问题。为了让用户更轻松,umap.plot包会在你的数据集变得足够大时自动切换到使用datashader进行渲染。这有助于确保你不会被过度绘图所欺骗。我们可以通过处理一个较大的数据集(如Fashion-MNIST)来看到这一点。

mapper = umap.UMAP().fit(fmnist.data)

在使用UMAP拟合数据后,我们可以像之前一样调用umap.plot.points,但这次,由于数据量足够大,可能会出现过度绘制的情况,因此将在后台使用datashader进行渲染。

umap.plot.points(mapper)
_images/plotting_19_2.png

所有之前的绘图选项仍然适用,因此我们可以按标签着色,并应用相同的主题,所有这些都将无缝地使用datashader进行实际渲染。因此,无论你有多少数据,umap.plot.points都会以透明的用户界面很好地渲染它。作为用户,你不需要担心切换到使用datashader进行绘图,或者如何将你的绘图转换为其略有不同的API——你可以继续使用相同的API并信任你得到的结果。

umap.plot.points(mapper, labels=fmnist.target, theme='fire')
_images/plotting_21_2.png

交互式绘图和悬停工具

渲染美观的静态图表很重要,但如果你想能够与数据互动——平移、放大集群以查看更精细的结构呢?如果你想用比仅仅颜色更复杂的标签来注释你的数据呢?能够悬停在数据点上并获取有关单个点的更多信息不是很好吗?由于这是一个非常常见的用例,umap.plot 试图使快速生成此类图表变得容易,并提供基本工具,以便你能够快速使用带注释的悬停工具。再次强调,目标不是提供一个可以完成所有任务的全面解决方案,而是一个简单易用且一致的API,让用户能够快速上手。

为了更好地说明这一点,让我们使用Fashion MNIST数据集的一个子集。我们可以快速在该数据集上训练一个新的映射器对象。

mapper = umap.UMAP().fit(fmnist.data[:30000])

目标是能够悬停在不同的点上,并查看与光标下的给定点(或点)相关联的数据。在这个简单的演示中,我们将仅使用点的目标信息。要创建悬停信息,您需要构建一个包含您希望在悬停中显示的所有数据的数据框。每一行应对应于数据点的一个来源(以相同的顺序出现),而列可以提供您希望在悬停工具提示中显示的任何额外数据。在这种情况下,我们需要一个可以包括点的索引、其目标编号以及该目标对应的时尚物品类型的实际名称的数据框。使用pandas可以快速轻松地完成此操作。

hover_data = pd.DataFrame({'index':np.arange(30000),
                           'label':fmnist.target[:30000]})
hover_data['item'] = hover_data.label.map(
    {
        '0':'T-shirt/top',
        '1':'Trouser',
        '2':'Pullover',
        '3':'Dress',
        '4':'Coat',
        '5':'Sandal',
        '6':'Shirt',
        '7':'Sneaker',
        '8':'Bag',
        '9':'Ankle Boot',
    }
)

对于交互式使用,umap.plot 包使用了 bokeh。Bokeh 有几种输出方法,但在我们的方法中,我们将在笔记本中内联输出。我们必须使用 output_notebook 函数来启用此功能。或者,我们可以使用 output_file 或其他类似选项——更多详情请参阅 bokeh 文档。

umap.plot.output_notebook()
正在加载 BokehJS ...

现在我们可以使用umap.plot.interactive来制作一个交互式图表。 这与umap.plot.points方法非常相似,但 还支持一个hover_data关键字,如果传递一个合适的 数据框,将在交互式图表中提供悬停提示。由于 bokeh允许不同的输出,要在笔记本中显示它,我们 必须额外调用show来显示结果。

p = umap.plot.interactive(mapper, labels=fmnist.target[:30000], hover_data=hover_data, point_size=2)
umap.plot.show(p)
Bokeh Plot

绘制连接性

UMAP通过构建数据可能从中采样的近似流形的中间拓扑表示来工作。实际上,这种结构可以简化为一个加权图。有时,查看该图(表示流形中的连接性)相对于生成的嵌入的外观可能是有益的。它可以用于更好地理解嵌入,并用于诊断目的。要查看连接性,您可以使用umap.plot.connectivity函数。它的工作方式与umap.plot.points函数非常相似,并且可以选择是否显示嵌入点或仅显示连接性。首先,让我们做一个简单的图来显示这些点:

umap.plot.connectivity(mapper, show_points=True)
_images/plotting_32_2.png

umap.plot.points一样,有一些选项可以控制基本的美学效果,包括主题选项和一个edge_cmap关键字参数,用于指定显示边缘时使用的颜色映射。

umap.plot.connectivity(mapper, edge_bundling='hammer')
_images/plotting_34_2.png

诊断绘图

绘制连接性至少提供了一个基本的诊断视图,帮助用户理解嵌入的情况。当然,更多的数据视图更好,因此umap.plot包含了一个umap.plot.diagnostic函数,可以提供各种诊断图。我们将在这里查看其中的一些。为此,我们将使用完整的MNIST数字数据集。

mapper = umap.UMAP().fit(mnist.data)

第一种诊断类型是基于主成分分析的诊断,您可以通过diagnostic_type='pca'来选择。该方法的本质是我们可以使用PCA(保留全局结构)将数据降维到三维。如果我们将结果缩放到适合3D立方体的范围内,我们可以将每个点的3D PCA坐标转换为颜色的RGB描述。然后通过用PCA诱导的颜色为UMAP嵌入中的点着色,可以了解一些更大规模的全局结构在嵌入中的表示情况。

umap.plot.diagnostic(mapper, diagnostic_type='pca')
_images/plotting_38_1.png

我们在这里寻找的是颜色的一般平滑过渡,以及总体上尊重颜色过渡的布局。在这种情况下,最左边有一个底部集群,从底部的深绿色过渡到顶部的蓝色,这与右上角的集群很好地匹配,后者在底部有类似的蓝色阴影,然后过渡到更多的青色和蓝色。相比之下,在图的右侧,下部的集群从上到下从紫粉色过渡到绿色,而其上方的集群的底部边缘更偏向紫色而不是绿色,这表明在优化过程中,这些集群中的一个或另一个可能被垂直翻转,而这从未得到完全纠正。

另一种类似的方法是使用向量量化作为生成3D嵌入以生成颜色的方法。向量量化有效地找到数据的3个代表性中心,然后根据每个数据点到这些中心的距离来描述数据点。显然,这再次捕捉到了数据的广泛全局结构。

umap.plot.diagnostic(mapper, diagnostic_type='vq')
_images/plotting_40_1.png

我们再次寻找大致平滑的过渡,以及相关颜色在集群之间的匹配。这个视图支持嵌入左侧效果良好的事实,但观察右侧,很明显是上方的两个集群无意中被垂直翻转了。通过对比这样的视图,可以更好地了解嵌入的效果如何。

为了从不同的角度,我们可以查看每个数据点周围的局部维度的近似值。理想情况下,局部维度应与嵌入维度相匹配(尽管这通常是一个很高的期望。在实践中,当局部维度较高时,这表示UMAP在嵌入这些点(或空间的区域)时会遇到更大的困难。因此,在局部维度较低的区域,可以更信任嵌入的准确性。

local_dims = umap.plot.diagnostic(mapper, diagnostic_type='local_dim')
_images/plotting_42_0.png

正如你所见,数据的局部维度在整个数据中变化很大。特别是左下角的簇具有最低的局部维度——这实际上并不令人惊讶,因为这是对应于数字1的簇:人们绘制数字1的方式相对较少,因此产生的局部维度较低。相比之下,中间的簇具有更高的局部维度。我们应该预期在这些区域中嵌入的准确性会稍低一些:在仅有两个维度的情况下很难很好地表示七维数据,因此需要做出一些妥协。

我们将查看的最终诊断是局部邻域的保持情况。我们可以通过比较高维空间中的局部邻域与嵌入中的等效邻域的Jaccard指数来衡量这一点。Jaccard指数本质上是两个邻域共有的邻居数量与两个邻域中唯一邻居总数的比率。较高的值意味着局部邻域被更准确地保持了。

umap.plot.diagnostic(mapper, diagnostic_type='neighborhood')
_images/plotting_44_1.png

正如人们可能预期的那样,局部邻域保持对于局部维度较低的点(如最后一图所示)往往要好得多。此外,集群的边缘(有明确的边界需要遵循)比局部维度较高的集群中心更有可能保持邻域。这再次提供了一个视角,让你可以更信任嵌入的哪些区域,以及哪些区域在嵌入到二维时必须做出妥协。