注意
跳转到末尾 以下载完整示例代码。
基于经验分布的最优传输(OT)领域自适应
本例介绍二维场景下的领域自适应。它阐明了领域自适应问题,并引入了几种最优传输方法来解决该问题。
为了直观地展示各传输方法的作用,本例绘制了最优耦合、较大的耦合系数以及传输后的样本等量。
# Authors: Remi Flamary <remi.flamary@unice.fr>
# Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
import matplotlib.pylab as pl
import ot
import ot.plot
生成数据
# Both domains are sampled with the same number of points.
n_samples_source = n_samples_target = 150

# Labelled samples for each domain: the "3gauss" dataset is used as the source
# and "3gauss2" as the target.
Xs, ys = ot.datasets.make_data_classif("3gauss", n_samples_source)
Xt, yt = ot.datasets.make_data_classif("3gauss2", n_samples_target)

# Pairwise squared-Euclidean cost between every source and target sample.
# In this script the matrix is only used for display in figure 1; the
# transport estimators below are fitted directly on Xs / Xt.
M = ot.dist(Xs, Xt, metric="sqeuclidean")
实例化不同的传输算法并进行拟合
# Instantiate the three transport estimators compared in this example:
# exact EMD, entropic Sinkhorn, and Sinkhorn with group-lasso regularization.
ot_emd = ot.da.EMDTransport()
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)

# Fit each estimator on the source/target samples. Only the group-lasso
# variant is given the source labels `ys`.
# NOTE(review): the captured output of this example reports
# "Sinkhorn did not converge" warnings; it is not visible from here which of
# the two Sinkhorn-based fits triggers them -- consider raising the iteration
# budget or the regularization after confirming.
ot_emd.fit(Xs=Xs, Xt=Xt)
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)

# Transport the source samples onto the target domain with every fitted
# estimator (evaluated in the same order as they were fitted).
transp_Xs_emd, transp_Xs_sinkhorn, transp_Xs_lpl1 = (
    mapper.transform(Xs=Xs) for mapper in (ot_emd, ot_sinkhorn, ot_lpl1)
)
/home/circleci/project/ot/bregman/_sinkhorn.py:903: UserWarning: Sinkhorn did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
warnings.warn(
/home/circleci/project/ot/bregman/_sinkhorn.py:667: UserWarning: Sinkhorn did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
warnings.warn(
图 1:绘制源样本和目标样本 + 成对距离矩阵
# Figure 1: the two sample clouds side by side, then the cost matrix below.
pl.figure(1, figsize=(10, 10))

# (subplot slot, samples, class labels, marker, legend label / title)
scatter_panels = (
    (1, Xs, ys, "+", "Source samples"),
    (2, Xt, yt, "o", "Target samples"),
)
for slot, points, classes, glyph, caption in scatter_panels:
    pl.subplot(2, 2, slot)
    pl.scatter(points[:, 0], points[:, 1], c=classes, marker=glyph, label=caption)
    # Axis ticks are hidden on every panel of this example.
    pl.xticks([])
    pl.yticks([])
    pl.legend(loc=0)
    pl.title(caption)

# Third slot: image of the pairwise cost matrix M.
pl.subplot(2, 2, 3)
pl.imshow(M, interpolation="nearest")
pl.xticks([])
pl.yticks([])
pl.title("Matrix of pairwise distances")

pl.tight_layout()

图 2:绘制不同方法的最优耦合
# Figure 2: one column per transport method. The top row shows the coupling
# matrix as an image, the bottom row draws the coupling on top of the samples.
pl.figure(2, figsize=(10, 6))

fitted_methods = (
    ("EMDTransport", ot_emd),
    ("SinkhornTransport", ot_sinkhorn),
    ("SinkhornLpl1Transport", ot_lpl1),
)

# Top row (slots 1-3): coupling matrices.
for slot, (method, mapper) in enumerate(fitted_methods, start=1):
    pl.subplot(2, 3, slot)
    pl.imshow(mapper.coupling_, interpolation="nearest")
    pl.xticks([])
    pl.yticks([])
    pl.title(f"Optimal coupling\n{method}")

# Bottom row (slots 4-6): coupling drawn between source and target samples,
# with both sample clouds scattered on top.
for slot, (method, mapper) in enumerate(fitted_methods, start=4):
    pl.subplot(2, 3, slot)
    ot.plot.plot2D_samples_mat(Xs, Xt, mapper.coupling_, c=[0.5, 0.5, 1])
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
    pl.xticks([])
    pl.yticks([])
    pl.title(f"Main coupling coefficients\n{method}")

pl.tight_layout()

图 3:绘制传输后的样本
# Figure 3: transported source samples overlaid on the target samples, one
# panel per transport method.
# The figure is numbered 3 (it was created as figure 4) so that it follows
# figures 1 and 2 and matches the "Fig 3" section heading.
pl.figure(3, figsize=(10, 4))

# Panel titles use the estimator class names; the first one is spelled
# "EMDTransport" (not "EmdTransport") to agree with the class and with the
# titles used in figure 2.
transported_sets = (
    ("EMDTransport", transp_Xs_emd),
    ("SinkhornTransport", transp_Xs_sinkhorn),
    ("SinkhornLpl1Transport", transp_Xs_lpl1),
)

for slot, (method, moved) in enumerate(transported_sets, start=1):
    pl.subplot(1, 3, slot)
    # Target samples are drawn semi-transparent so the transported source
    # samples stay visible on top of them.
    pl.scatter(
        Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5
    )
    pl.scatter(
        moved[:, 0],
        moved[:, 1],
        c=ys,
        marker="+",
        label="Transp samples",
        s=30,
    )
    pl.title(f"Transported samples\n{method}")
    # As in the original layout, only the first panel carries a legend.
    if slot == 1:
        pl.legend(loc=0)
    pl.xticks([])
    pl.yticks([])

pl.tight_layout()
pl.show()

脚本的总运行时间: (0分钟 10.160秒)