注意
跳转到末尾 以下载完整示例代码。
OTDA 无监督与半监督设置
本示例介绍了在二维环境中的半监督领域适应。它明确了半监督领域适应的问题,并引入了一些最优传输方法来解决它。
为了直观地理解传输方法的作用,本示例绘制了最优耦合、较大的耦合系数以及传输后的样本等量。
# Authors: Remi Flamary <remi.flamary@unice.fr>
# Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 3
import matplotlib.pylab as pl
import ot
生成数据
# Both domains are sampled with the same number of points.
n_samples_source = n_samples_target = 150

# Labeled 2D samples for each domain: the "3gauss" dataset is used for the
# source and the "3gauss2" dataset for the target.
Xs, ys = ot.datasets.make_data_classif("3gauss", n_samples_source)
Xt, yt = ot.datasets.make_data_classif("3gauss2", n_samples_target)
将源样本传输到目标样本
# Unsupervised domain adaptation: only the sample positions are given to the
# transport, no labels.
ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=0.1)
ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)

# Semi-supervised domain adaptation: the available target labels are used to
# modify the cost matrix of the OT problem. Transporting a source sample of
# class A onto a target sample of class B != A gets an infinite (or very
# large) cost.
#
# Here every target sample is considered labeled. In practice some target
# samples may have no label; the matching entries of yt should then be
# filled with -1.
# Warning: as a consequence, -1 cannot be used as a class label.
ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=0.1)
ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)
图 1:绘制源样本和目标样本 + 成对距离矩阵
# Figure 1 is a 2x2 grid: the labeled samples of each domain on the top row,
# the cost matrices of the two fitted transports on the bottom row.
pl.figure(1, figsize=(10, 10))

# Top row: (subplot index, points, labels, marker, caption). The caption is
# used both as the legend entry and as the panel title.
sample_panels = (
    (1, Xs, ys, "+", "Source samples"),
    (2, Xt, yt, "o", "Target samples"),
)
for position, points, labels, marker, caption in sample_panels:
    pl.subplot(2, 2, position)
    pl.scatter(points[:, 0], points[:, 1], c=labels, marker=marker, label=caption)
    pl.xticks([])
    pl.yticks([])
    pl.legend(loc=0)
    pl.title(caption)

# Bottom row: (subplot index, cost matrix, caption).
cost_panels = (
    (3, ot_sinkhorn_un.cost_, "Cost matrix - unsupervised DA"),
    (4, ot_sinkhorn_semi.cost_, "Cost matrix - semi-supervised DA"),
)
for position, cost_matrix, caption in cost_panels:
    pl.subplot(2, 2, position)
    pl.imshow(cost_matrix, interpolation="nearest")
    pl.xticks([])
    pl.yticks([])
    pl.title(caption)

pl.tight_layout()
# The optimal coupling in the semi-supervised DA case will exhibit a shape
# similar to its cost matrix (a block-diagonal matrix).

图 2:绘制不同方法的最优耦合
# Figure 2 shows the coupling matrices of the two fitted transports side by
# side, unsupervised on the left and semi-supervised on the right.
pl.figure(2, figsize=(8, 4))

coupling_panels = (
    (ot_sinkhorn_un.coupling_, "Optimal coupling\nUnsupervised DA"),
    (ot_sinkhorn_semi.coupling_, "Optimal coupling\nSemi-supervised DA"),
)
for position, (coupling, caption) in enumerate(coupling_panels, start=1):
    pl.subplot(1, 2, position)
    pl.imshow(coupling, interpolation="nearest")
    pl.xticks([])
    pl.yticks([])
    pl.title(caption)

pl.tight_layout()

图 3:绘制传输后的样本
# Display the transported source samples on top of the target samples.
# Both panels use a SinkhornTransport; they differ only in whether the
# transport was fitted without labels (left) or with labels (right), so the
# titles name the setting rather than the transport class.
pl.figure(3, figsize=(8, 4))

# Left panel: source samples transported by the unsupervised transport.
pl.subplot(1, 2, 1)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
pl.scatter(
    transp_Xs_sinkhorn_un[:, 0],
    transp_Xs_sinkhorn_un[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.title("Transported samples\nUnsupervised DA")
pl.legend(loc=0)
pl.xticks([])
pl.yticks([])

# Right panel: source samples transported by the semi-supervised transport.
pl.subplot(1, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
pl.scatter(
    transp_Xs_sinkhorn_semi[:, 0],
    transp_Xs_sinkhorn_semi[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.title("Transported samples\nSemi-supervised DA")
pl.xticks([])
pl.yticks([])

pl.tight_layout()
pl.show()

脚本的总运行时间: (0分钟 1.388秒)