2D examples of exact and entropic unbalanced optimal transport

This example shows how to compute unbalanced optimal transport (UOT) and partial OT in POT.

UOT aims to solve the following optimization problem:

\[ \begin{align}\begin{aligned}W = \min_{\gamma} \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})\\ \text{s.t.} \quad \gamma \geq 0\end{aligned}\end{align} \]

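where \(\mathrm{div}\) is a divergence. For entropic UOT, \(\mathrm{reg}>0\) and \(\mathrm{div}\) should be the Kullback-Leibler divergence. For exact UOT, \(\mathrm{reg}=0\) and \(\mathrm{div}\) can be either the Kullback-Leibler or the quadratic divergence. Using the \(\ell_1\) norm gives the so-called partial OT.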
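To make the objective above concrete, here is a minimal NumPy sketch that evaluates it for a given plan in the exact case (\(\mathrm{reg}=0\)) with a Kullback-Leibler marginal penalty. The helper names `kl_div` and `uot_objective` are hypothetical and not part of POT, and the generalized KL form used here (sum of p log(p/q) - p + q) is an assumption about the convention.

```python
import numpy as np


def kl_div(p, q):
    """Generalized KL divergence sum(p*log(p/q) - p + q), with 0*log(0) = 0.

    Assumed convention; q must be strictly positive.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # skip zero entries so that 0*log(0) contributes nothing
    return np.sum(p[mask] * np.log(p[mask] / q[mask])) - p.sum() + q.sum()


def uot_objective(G, M, a, b, reg_m):
    """Exact (reg = 0) UOT objective with KL penalties on both marginals."""
    transport_cost = np.sum(G * M)  # Frobenius inner product <G, M>_F
    row_penalty = kl_div(G.sum(axis=1), a)  # div(G 1, a)
    col_penalty = kl_div(G.sum(axis=0), b)  # div(G^T 1, b)
    return transport_cost + reg_m * (row_penalty + col_penalty)


# Toy problem: two points on each side, zero cost on the diagonal
a_demo = np.array([0.5, 0.5])
b_demo = np.array([0.5, 0.5])
M_demo = np.array([[0.0, 1.0], [1.0, 0.0]])

G_match = np.diag([0.5, 0.5])  # moves all mass at zero cost
G_empty = np.zeros((2, 2))  # moves nothing, pays only the marginal penalty

obj_match = uot_objective(G_match, M_demo, a_demo, b_demo, reg_m=0.05)
obj_empty = uot_objective(G_empty, M_demo, a_demo, b_demo, reg_m=0.05)
print(obj_match, obj_empty)
```

On this toy problem the matching plan has zero objective, while the empty plan pays \(\mathrm{reg_m}\) times the total mass of \(\mathbf{a}\) and \(\mathbf{b}\). A small \(\mathrm{reg_m}\) therefore makes it cheap to leave mass untransported.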

# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot

Generate data

n = 40  # nb samples

mu_s = np.array([-1, -1])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4])
cov_t = np.array([[1, -0.8], [-0.8, 1]])

np.random.seed(0)
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)

n_noise = 10

xs = np.concatenate((xs, (np.random.rand(n_noise, 2) - 4)), axis=0)
xt = np.concatenate((xt, (np.random.rand(n_noise, 2) + 6)), axis=0)

n = n + n_noise

a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform distribution on samples

# loss matrix
M = ot.dist(xs, xt)
M /= M.max()

Compute entropic kl-regularized UOT, kl- and l2-regularized UOT

Plot the results

pl.figure(2)
transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot]
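# Backslashes in the LaTeX parts are doubled so that "\e" and "\m" are not
# treated as (invalid) escape sequences; "\n" remains a real newline.
title = [
    "partial OT \n m=" + str(mass),
    "$\\ell_2$-UOT \n $\\mathrm{reg_m}$=" + str(reg_m_l2),
    "kl-UOT \n $\\mathrm{reg_m}$=" + str(reg_m_kl),
    "entropic kl-UOT \n $\\mathrm{reg_m}$=" + str(reg_m_kl),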
]

for p in range(4):
    pl.subplot(2, 4, p + 1)
    P = transp[p]
    if P.sum() > 0:
        P = P / P.max()
    for i in range(n):
        for j in range(n):
            if P[i, j] > 0:
                pl.plot(
                    [xs[i, 0], xt[j, 0]],
                    [xs[i, 1], xt[j, 1]],
                    color="C2",
                    alpha=P[i, j] * 0.3,
                )
    pl.scatter(xs[:, 0], xs[:, 1], c="C0", alpha=0.2)
    pl.scatter(xt[:, 0], xt[:, 1], c="C1", alpha=0.2)
    pl.scatter(xs[:, 0], xs[:, 1], c="C0", s=P.sum(1).ravel() * (1 + p) * 2)
    pl.scatter(xt[:, 0], xt[:, 1], c="C1", s=P.sum(0).ravel() * (1 + p) * 2)
    pl.title(title[p])
    pl.yticks(())
    pl.xticks(())
    if p < 1:
        pl.ylabel("mappings")
    pl.subplot(2, 4, p + 5)
    pl.imshow(P, cmap="jet")
    pl.yticks(())
    pl.xticks(())
    if p < 1:
        pl.ylabel("transport plans")
pl.show()
Figure: mappings (top row) and transport plans (bottom row) for partial OT (m=0.7), $\ell_2$-UOT ($\mathrm{reg_m}$=5), kl-UOT ($\mathrm{reg_m}$=0.05), and entropic kl-UOT ($\mathrm{reg_m}$=0.05).

Total running time of the script: (0 minutes 3.601 seconds)
