圆上的OT距离

展示如何计算圆上的 Wasserstein 距离。

# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pylab as pl
import ot

from scipy.special import iv, ive

绘制数据

def pdf_von_Mises(theta, mu, kappa):
    """Evaluate the von Mises probability density on the circle.

    Mathematically the density is
    ``exp(kappa * cos(theta - mu)) / (2 * pi * I_0(kappa))`` where ``I_0`` is
    the modified Bessel function of the first kind of order 0. It is computed
    here with the exponentially scaled Bessel function
    ``ive(0, kappa) = I_0(kappa) * exp(-|kappa|)`` and a matching
    ``- |kappa|`` in the exponent. The two ``exp(|kappa|)`` factors cancel
    analytically, so the result is identical, but neither term overflows. The
    naive form evaluates to ``inf / inf = nan`` once ``kappa`` exceeds roughly
    700.

    Parameters
    ----------
    theta : float or array-like
        Angle(s), in radians, at which the density is evaluated.
    mu : float
        Mean direction, in radians.
    kappa : float
        Concentration parameter; ``kappa = 0`` gives the uniform density
        ``1 / (2 * pi)``.

    Returns
    -------
    pdf : float or numpy.ndarray
        Density value(s), broadcast like ``theta - mu``.
    """
    # Subtracting |kappa| mirrors the exp(-|kappa|) scaling applied by ive, so
    # the exponent stays <= 0 and the denominator stays O(1).
    log_unnormalized = kappa * np.cos(theta - mu) - np.abs(kappa)
    pdf = np.exp(log_unnormalized) / (2.0 * np.pi * ive(0, kappa))
    return pdf


t = np.linspace(0, 2 * np.pi, 1000, endpoint=False)

mu1 = 1
kappa1 = 20

mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10)


pdf1 = pdf_von_Mises(t, mu1, kappa1)


pl.figure(1)
for k, mu in enumerate(mu_targets):
    pdf_t = pdf_von_Mises(t, mu, kappa1)
    # Label only the first blue curve so the legend shows a single entry.
    label = "Source distributions" if k == 0 else None
    pl.plot(t / (2 * np.pi), pdf_t, c="b", label=label)

pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution")
pl.legend()

mu2 = 0
kappa2 = kappa1

# np.random.vonmises deals with data on [-pi, pi[: shift the mean by -pi and
# the samples by +pi so the samples lie on [0, 2pi[ and are centred at mu1
# (resp. mu2), exactly as done for the experiments below.
x1 = np.random.vonmises(mu1 - np.pi, kappa1, size=(10,)) + np.pi
x2 = np.random.vonmises(mu2 - np.pi, kappa2, size=(10,)) + np.pi

angles = np.linspace(0, 2 * np.pi, 150)

pl.figure(2)
pl.plot(np.cos(angles), np.sin(angles), c="k")
pl.xlim(-1.25, 1.25)
pl.ylim(-1.25, 1.25)
# Equal aspect ratio so the unit circle is not rendered as an ellipse.
pl.gca().set_aspect("equal")
pl.scatter(np.cos(x1), np.sin(x1), c="b")
pl.scatter(np.cos(x2), np.sin(x2), c="r")
  • plot compute wasserstein circle
  • plot compute wasserstein circle
<matplotlib.collections.PathCollection object at 0x7f590d52cca0>

比较欧几里得沃瑟斯坦距离与圆上的沃瑟斯坦距离

此示例说明了圆周上沃瑟斯坦距离的周期性。我们选择均值为 \(\mu_{\mathrm{target}}\)、\(\kappa=20\) 的冯-米塞斯分布作为目标分布,然后计算它与从参数为 \(\mu_{\mathrm{source}}\)、\(\kappa=20\) 的冯-米塞斯分布中抽取的样本之间的距离。与欧几里得版本不同,圆周上的沃瑟斯坦距离考虑了周期性,并在 \(\mu_{\mathrm{target}}+\pi\)(对跖点)处达到最大值。

n_targets = 200  # number of source means on the grid over [0, 2pi]
n_samples = 500  # number of samples drawn per distribution

mu_targets = np.linspace(0, 2 * np.pi, n_targets)
# np.random.vonmises deals with data on [-pi, pi[, hence the -pi / +pi shift
xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(n_samples,)) + np.pi

n_try = 5

# xts[i, k] holds n_samples draws centred at mu_targets[k] for trial i; the
# mean is broadcast over the trial and sample axes.
xts = (
    np.random.vonmises(
        mu_targets[None, :, None] - np.pi,
        kappa2,
        size=(n_try, n_targets, n_samples),
    )
    + np.pi
)

# Put data on S^1=[0,1[
xts2 = xts / (2 * np.pi)
# The same sample xs is compared against every distribution on the grid.
xs2 = np.tile(xs, (n_targets, 1)) / (2 * np.pi)

L_w2_circle = np.zeros((n_try, n_targets))
L_w2 = np.zeros((n_try, n_targets))

for i in range(n_try):
    # Transposed to (n_samples, n_targets): samples along axis 0, one
    # distribution per column, as POT's batched 1D solvers expect.
    L_w2_circle[i] = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2)
    L_w2[i] = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2)

m_w2_circle = np.mean(L_w2_circle, axis=0)
std_w2_circle = np.std(L_w2_circle, axis=0)

m_w2 = np.mean(L_w2, axis=0)
std_w2 = np.std(L_w2, axis=0)

# A new figure: figure 1 still holds the densities when run as a script.
pl.figure(3)
pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle")
pl.fill_between(
    mu_targets / (2 * np.pi),
    m_w2_circle - 2 * std_w2_circle,
    m_w2_circle + 2 * std_w2_circle,
    alpha=0.5,
)
pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein")
pl.fill_between(
    mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5
)
pl.vlines(
    x=[mu1 / (2 * np.pi)],
    ymin=0,
    # Span both plotted curves and their bands, not only the last trial.
    ymax=max(np.max(m_w2 + 2 * std_w2), np.max(m_w2_circle + 2 * std_w2_circle)),
    linestyle="--",
    color="k",
    label=r"$\mu_{\mathrm{target}}$",
)
pl.legend()
pl.xlabel(r"$\mu_{\mathrm{source}}$")
pl.show()
plot compute wasserstein circle
/home/circleci/project/ot/lp/solver_1d.py:796: RuntimeWarning: divide by zero encountered in divide
  (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp)
/home/circleci/project/ot/lp/solver_1d.py:796: RuntimeWarning: invalid value encountered in divide
  (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp)

不同 \(\kappa\) 下冯-米塞斯分布与均匀分布之间的 Wasserstein 距离

当 \(\kappa=0\) 时,冯-米塞斯分布即为 \(S^1\) 上的均匀分布。

kappas = np.logspace(-5, 2, 100)
n_try = 20
n_samples = 500  # number of samples drawn per (trial, kappa)

# xts[i, k] holds n_samples draws from vM(0, kappas[k]) for trial i.
# np.random.vonmises deals with data on [-pi, pi[; adding pi and dividing by
# 2*pi puts the samples on S^1 = [0, 1[.
xts = (
    np.random.vonmises(0, kappas[None, :, None], size=(n_try, kappas.size, n_samples))
    + np.pi
) / (2 * np.pi)

L_w2 = np.zeros((n_try, kappas.size))
for i in range(n_try):
    # Transposed to (n_samples, n_kappas): samples along axis 0, one kappa
    # per column.
    L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T)

m_w2 = np.mean(L_w2, axis=0)
std_w2 = np.std(L_w2, axis=0)

# A new figure: figure 1 may still hold earlier plots when run as a script.
pl.figure(4)
pl.plot(kappas, m_w2)
pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5)
# kappas are log-spaced over seven decades; a log axis spreads them evenly.
pl.xscale("log")
pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$")
pl.xlabel(r"$\kappa$")
pl.show()
Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$

脚本的总运行时间: (0分钟 3.494秒)

由 Sphinx-Gallery 生成的画廊