OTDA unsupervised vs semi-supervised setting

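This example introduces semi-supervised domain adaptation in a 2D setting. It states the semi-supervised domain adaptation problem and introduces some optimal transport approaches to solve it.

Quantities such as optimal couplings, the larger coupling coefficients, and transported samples are plotted to give a visual understanding of what the transport methods are doing.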

# Authors: Remi Flamary <remi.flamary@unice.fr>
#          Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import matplotlib.pyplot as pl
import ot

Generate data
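The data-generation code is not shown in this section, yet the rest of the script relies on `Xs`, `ys`, `Xt` and `yt`. The sketch below is a stand-in written in plain NumPy, not the original code: the helper `make_three_class_data`, the blob centers, the noise level, the shift and the sample counts are all assumptions chosen for illustration. POT also ships `ot.datasets.make_data_classif`, which may be closer to what the original script used.

```python
import numpy as np


def make_three_class_data(n_samples, shift=(0.0, 0.0), seed=0):
    """Hypothetical stand-in: three 2D Gaussian blobs labeled 1, 2 and 3."""
    rng = np.random.default_rng(seed)
    # Assumed blob centers, optionally shifted to mimic a domain change.
    centers = np.array([[-1.0, -1.0], [1.0, -1.0], [0.0, 1.0]]) + np.asarray(shift)
    # Labels are sorted so that class blocks are visible in matrix plots.
    y = np.sort(np.arange(n_samples) % 3) + 1
    X = centers[y - 1] + 0.3 * rng.standard_normal((n_samples, 2))
    return X, y


# Assumed sample counts; the source domain is unshifted, the target is shifted.
n_samples_source = 150
n_samples_target = 150
Xs, ys = make_three_class_data(n_samples_source, seed=0)
Xt, yt = make_three_class_data(n_samples_target, shift=(1.0, 0.5), seed=1)
```

Labels start at 1 so that -1 stays free as the marker for unlabeled target samples.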

Transport source samples onto target samples

# unsupervised domain adaptation
ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)

# semi-supervised domain adaptation
ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)

# Semi-supervised DA uses the available labeled target samples to modify the
# cost matrix of the OT problem: the cost of transporting a source sample of
# class A onto a target sample of class B != A is set to infinity, or to a
# very large value.
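To make the cost modification described above concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the library's implementation: the function name `semi_supervised_cost`, the squared Euclidean base cost, and the `large_factor` parameter (the "very large value" expressed as a multiple of the largest base cost) are all hypothetical.

```python
import numpy as np


def semi_supervised_cost(Xs, Xt, ys, yt, large_factor=10.0):
    """Squared Euclidean cost with class-mismatched pairs made very expensive."""
    diff = Xs[:, None, :] - Xt[None, :, :]
    cost = np.sum(diff ** 2, axis=2)
    # Penalize a pair only when both samples are labeled (label != -1)
    # and their labels differ.
    mismatch = (
        (ys[:, None] != yt[None, :])
        & (ys[:, None] != -1)
        & (yt[None, :] != -1)
    )
    cost[mismatch] = large_factor * cost.max()
    return cost


# Tiny example: two source samples, three target samples (the last unlabeled).
Xs_demo = np.array([[0.0, 0.0], [1.0, 0.0]])
ys_demo = np.array([1, 2])
Xt_demo = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 2.0]])
yt_demo = np.array([1, 2, -1])
cost_demo = semi_supervised_cost(Xs_demo, Xt_demo, ys_demo, yt_demo)
```

In this toy case only the two cross-class pairs are penalized; the column of the unlabeled target sample keeps its plain distance-based cost.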

# Note that here all target samples are labeled. In practice, some target
# samples may have no label; the entries of yt corresponding to these samples
# should then be set to -1.

# Warning: because -1 marks unlabeled samples, it cannot be used as a class label.
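The -1 convention above can be sketched as follows for a partially labeled target set. The helper `mask_target_labels` and its `labeled_fraction` parameter are hypothetical names introduced for this illustration; only the use of -1 for unlabeled samples comes from the text.

```python
import numpy as np


def mask_target_labels(yt, labeled_fraction=0.3, seed=0):
    """Keep a random fraction of the labels and mark the others as -1."""
    yt = np.asarray(yt)
    if np.any(yt == -1):
        raise ValueError("-1 is reserved for unlabeled samples")
    rng = np.random.default_rng(seed)
    n_labeled = int(round(labeled_fraction * yt.size))
    keep = rng.choice(yt.size, size=n_labeled, replace=False)
    yt_partial = np.full_like(yt, -1)
    yt_partial[keep] = yt[keep]
    return yt_partial


yt_full = np.array([1, 1, 2, 2, 3, 3, 1, 2, 3, 1])
yt_partial = mask_target_labels(yt_full, labeled_fraction=0.3)
```

The resulting array could then be passed as `yt` to `fit` in place of the fully labeled vector.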

Fig 1 : plot source and target samples + matrix of pairwise distances

pl.figure(1, figsize=(10, 10))
pl.subplot(2, 2, 1)
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title("Source samples")

pl.subplot(2, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title("Target samples")

pl.subplot(2, 2, 3)
pl.imshow(ot_sinkhorn_un.cost_, interpolation="nearest")
pl.xticks([])
pl.yticks([])
pl.title("Cost matrix - unsupervised DA")

pl.subplot(2, 2, 4)
pl.imshow(ot_sinkhorn_semi.cost_, interpolation="nearest")
pl.xticks([])
pl.yticks([])
pl.title("Cost matrix - semi-supervised DA")

pl.tight_layout()

# the optimal coupling in the semi-supervised DA case has a shape similar to
# that of the cost matrix (block-diagonal when samples are sorted by class)
[Figure: Source samples, Target samples, Cost matrix - unsupervised DA, Cost matrix - semi-supervised DA]
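The comment above says that the semi-supervised coupling inherits the block structure of the cost matrix. The sketch below illustrates why, using a bare-bones Sinkhorn iteration in NumPy with uniform marginals. It is a simplified illustration, not POT's solver: the function `sinkhorn_coupling`, the toy cost values, and the fixed iteration count are assumptions.

```python
import numpy as np


def sinkhorn_coupling(cost, reg=1e-1, n_iter=500):
    """Entropic OT between uniform marginals via plain Sinkhorn iterations."""
    n_s, n_t = cost.shape
    a = np.full(n_s, 1.0 / n_s)  # uniform source weights
    b = np.full(n_t, 1.0 / n_t)  # uniform target weights
    K = np.exp(-cost / reg)      # Gibbs kernel
    u = np.ones(n_s)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]


# Toy problem: two classes, samples sorted by class, cross-class cost set large.
ys_toy = np.array([1, 1, 2, 2])
yt_toy = np.array([1, 1, 2, 2])
rng = np.random.default_rng(0)
cost_toy = rng.uniform(0.0, 0.3, size=(4, 4))
cross_class = ys_toy[:, None] != yt_toy[None, :]
cost_toy[cross_class] = 10.0

G_toy = sinkhorn_coupling(cost_toy)
off_block_mass = G_toy[cross_class].sum()
```

Because the kernel entry `exp(-cost / reg)` is vanishingly small wherever the cost is large, almost no mass is assigned to cross-class pairs, which is what produces the block-diagonal pattern.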

Fig 2 : plot optimal couplings for the different methods

pl.figure(2, figsize=(8, 4))

pl.subplot(1, 2, 1)
pl.imshow(ot_sinkhorn_un.coupling_, interpolation="nearest")
pl.xticks([])
pl.yticks([])
pl.title("Optimal coupling\nUnsupervised DA")

pl.subplot(1, 2, 2)
pl.imshow(ot_sinkhorn_semi.coupling_, interpolation="nearest")
pl.xticks([])
pl.yticks([])
pl.title("Optimal coupling\nSemi-supervised DA")

pl.tight_layout()
[Figure: Optimal coupling, Unsupervised DA (left) and Semi-supervised DA (right)]

Fig 3 : plot transported samples

# display transported samples
pl.figure(3, figsize=(8, 4))
pl.subplot(1, 2, 1)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
pl.scatter(
    transp_Xs_sinkhorn_un[:, 0],
    transp_Xs_sinkhorn_un[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.title("Transported samples\nUnsupervised DA")
pl.legend(loc=0)
pl.xticks([])
pl.yticks([])

pl.subplot(1, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
pl.scatter(
    transp_Xs_sinkhorn_semi[:, 0],
    transp_Xs_sinkhorn_semi[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.title("Transported samples\nSemi-supervised DA")
pl.xticks([])
pl.yticks([])

pl.tight_layout()
pl.show()
[Figure: Transported samples, unsupervised DA (left) and semi-supervised DA (right)]
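To connect the couplings of Fig 2 with the transported samples plotted here, one common choice is barycentric mapping: each source sample is sent to the coupling-weighted average of the target samples. The sketch below shows that idea in NumPy. Whether `transform` does exactly this is an assumption, and `barycentric_map` is a hypothetical helper name.

```python
import numpy as np


def barycentric_map(coupling, Xt):
    """Map each source sample to the coupling-weighted mean of the targets."""
    row_mass = coupling.sum(axis=1, keepdims=True)
    # Normalize each row to sum to 1; rows with zero mass are left at zero.
    weights = np.divide(
        coupling, row_mass, out=np.zeros_like(coupling), where=row_mass > 0
    )
    return weights @ Xt


coupling_demo = np.array([[0.5, 0.0], [0.25, 0.25]])
Xt_demo2 = np.array([[0.0, 0.0], [2.0, 4.0]])
transported_demo = barycentric_map(coupling_demo, Xt_demo2)
```

In this toy case the first source sample lands exactly on the first target, while the second lands halfway between the two targets.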

Total running time of the script: (0 minutes 0.716 seconds)
