Prototype Selection
In this notebook, we show an example of selecting prototypical examples from a source dataset that are representative of a target dataset. We experiment with the popular Digits dataset. Two partitions, src and tgt, are created, corresponding to the source set and the target set respectively. Our approach uses optimal transport theory to learn prototypes from src by matching the prototype distribution to the target distribution over tgt.
This notebook can be found in our examples folder on GitHub.
# install interpret if not already installed
try:
    import interpret
except ModuleNotFoundError:
    !pip install --quiet interpret numpy scikit-learn matplotlib
We load the required packages. The import specific to the prototype selection algorithm is SPOT_GreedySubsetSelection, which implements SPOTgreedy.
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import pairwise_distances
import numpy as np
import matplotlib.pyplot as plt
from interpret.utils import SPOT_GreedySubsetSelection # This loads the SPOT prototype selection algorithm.
We now load the Digits dataset and create the src and tgt sets by splitting the digits data into a 70/30 partition.
# Load the digits dataset
digits = load_digits()
# Flatten the images
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
# Split data into 70% src and 30% tgt subsets
X_src, X_tgt, y_src, y_tgt = train_test_split(
    data, digits.target, test_size=0.3, shuffle=False)
Pairwise distances/dissimilarities between the source and target points are required. The optimal transport framework allows any distance/dissimilarity measure to be used. In this example, we use the Euclidean distance metric.
# Compute the Euclidean distances between the X_src (source) and X_tgt (target) points.
C = pairwise_distances(X_src, X_tgt, metric='euclidean')
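Because the framework is agnostic to the particular dissimilarity, any other metric supported by pairwise_distances could be substituted. As an illustrative alternative (not used in the rest of this notebook), a cosine dissimilarity matrix would be computed the same way:
# Illustrative alternative (not used below): cosine dissimilarity between source and target points.
C_cosine = pairwise_distances(X_src, X_tgt, metric='cosine')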
targetmarginal is the empirical distribution over the target points. It is usually assumed to be uniform, i.e., every target point is given equal importance. We discuss two settings in the experiments. In the first setting, we assume targetmarginal is uniform. In the second setting, we skew targetmarginal towards the points of a particular class. The experiments show that in both settings the learned prototypes represent the target distribution targetmarginal well.
Setting 1: The target distribution is uniform
# Define a targetmarginal on the target set
# We define the uniform marginal
targetmarginal = np.ones(C.shape[1]) / C.shape[1]
# The number of prototypes to be computed
numprototypes = 20
# Run SPOTgreedy
# prototypeIndices represent the indices corresponding to the chosen prototypes.
# prototypeWeights represent the weights associated with each of the chosen prototypes. The weights sum to 1.
prototypeIndices, prototypeWeights = SPOT_GreedySubsetSelection(C, targetmarginal, numprototypes)
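As a quick sanity check (an addition for illustration, not part of the original example), we can verify that the returned weights form a distribution over the chosen source indices:
# Inspect the chosen prototype indices and confirm the weights sum to 1.
print("Selected prototype indices:", prototypeIndices)
print("Sum of prototype weights:", np.sum(prototypeWeights))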
# Plot the chosen prototypes
fig, axs = plt.subplots(nrows=5, ncols=4, figsize=(2, 2))
for idx, ax in enumerate(axs.ravel()):
    ax.imshow(data[prototypeIndices[idx]].reshape((8, 8)), cmap=plt.cm.binary)
    ax.axis("off")
_ = fig.suptitle("Top prototypes selected from the 64-dimensional digit dataset with uniform target distribution", fontsize=16)
Setting 2: The target distribution is skewed
In this setting, we skew the target marginal so that the examples in tgt with label 3 carry about 90% of the total weight. We expect most of the learned prototypes to also belong to label 3.
# Skew the target marginal to give weights to specific classes more
result = np.where(y_tgt == 3)  # find indices corresponding to label 3.
targetmarginal_skewed = np.ones(C.shape[1])
targetmarginal_skewed[result[0]] = 90  # Weigh the instances corresponding to label 3 more.
targetmarginal_skewed = targetmarginal_skewed / np.sum(targetmarginal_skewed)
# Run SPOTgreedy
prototypeIndices_skewed, prototypeWeights_skewed = SPOT_GreedySubsetSelection(C, targetmarginal_skewed, numprototypes)
# Plot the prototypes selected
fig, axs = plt.subplots(nrows=5, ncols=4, figsize=(2, 2))
for idx, ax in enumerate(axs.ravel()):
    ax.imshow(data[prototypeIndices_skewed[idx]].reshape((8, 8)), cmap=plt.cm.binary)
    ax.axis("off")
_ = fig.suptitle("Top prototypes selected from the 64-dimensional digit dataset with skewed target distribution", fontsize=16)