Wasserstein 1D (flow and barycenter) with PyTorch

In this small example, we consider the following minimization problem:

\[\mu^* = \operatorname*{arg\,min}_\mu \; W(\mu,\nu)\]

where \(\nu\) is a reference 1D measure. The problem is handled by projected gradient descent, where the gradients are computed with PyTorch automatic differentiation. Projecting onto the simplex ensures that the iterates remain on the probability simplex.

This example illustrates both the wasserstein_1d function and the use of backends within the POT framework.
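Concretely, writing \(\gamma\) for the step size (a symbol introduced here only for exposition), each iteration of the scheme implemented below performs

\[a^{(k+1)} = \mathrm{proj}_{\Delta}\!\left(a^{(k)} - \gamma \, \nabla_a W\bigl(a^{(k)}, b\bigr)\right),\]

where \(\mathrm{proj}_{\Delta}\) denotes the Euclidean projection onto the probability simplex \(\Delta\) (computed by proj_simplex in the code).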

# Author: Nicolas Courty <ncourty@irisa.fr>
#         Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import matplotlib as mpl
import torch

from ot.lp import wasserstein_1d
from ot.datasets import make_1D_gauss as gauss
from ot.utils import proj_simplex

red = np.array(mpl.colors.to_rgb("red"))
blue = np.array(mpl.colors.to_rgb("blue"))
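As a quick illustration of the projection step used later (a minimal sketch, not part of the original script; it reuses the imports above), proj_simplex maps any vector to the closest point of the probability simplex, so the result is nonnegative and sums to one:

# hedged sanity check of proj_simplex (illustrative only)
v = torch.tensor([0.5, -0.2, 1.0])
p = proj_simplex(v)  # Euclidean projection onto {p >= 0, sum(p) = 1}
print(p, p.sum())  # nonnegative entries summing to one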


n = 100  # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a = gauss(n, m=20, s=5)  # m = mean, s = std
b = gauss(n, m=60, s=10)

# enforce sum to one on the support
a = a / a.sum()
b = b / b.sum()
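Because wasserstein_1d dispatches on the type of its arguments, the same call also works directly on the NumPy arrays; as a small added check (not in the original script), we can print the initial value of the loss we are about to minimize:

# hedged check with the NumPy backend: loss between a and b on the grid x
print(wasserstein_1d(x, x, a, b, p=2))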


device = "cuda" if torch.cuda.is_available() else "cpu"

# use PyTorch for our data
x_torch = torch.tensor(x).to(device=device)
a_torch = torch.tensor(a).to(device=device).requires_grad_(True)
b_torch = torch.tensor(b).to(device=device)

lr = 1e-6
nb_iter_max = 800

loss_iter = []

pl.figure(1, figsize=(8, 4))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")

for i in range(nb_iter_max):
    # Compute the Wasserstein 1D with torch backend
    loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2)
    # record the corresponding loss value
    loss_iter.append(loss.clone().detach().cpu().numpy())
    loss.backward()

    # perform a step of projected gradient descent
    with torch.no_grad():
        a_torch -= a_torch.grad * lr  # gradient step
        a_torch.grad.zero_()
        a_torch.data = proj_simplex(a_torch)  # projection onto the simplex

    # plot one curve every 10 iterations
    if i % 10 == 0:
        mix = float(i) / nb_iter_max
        pl.plot(
            x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red
        )

pl.legend()
pl.title("Distribution along the iterations of the projected gradient descent")
pl.show()

pl.figure(2)
pl.plot(range(nb_iter_max), loss_iter, lw=3)
pl.title("Evolution of the loss along iterations", fontsize=16)
pl.show()
[Output figures: "Distribution along the iterations of the projected gradient descent" and "Evolution of the loss along iterations"]

Wasserstein barycenter

In this example, we consider the following Wasserstein barycenter problem

\[\eta^* = \operatorname*{arg\,min}_\eta \;\; (1-t) W(\mu,\eta) + t W(\eta,\nu)\]

where \(\mu\) and \(\nu\) are reference 1D measures, and \(t\) is a parameter in \([0,1]\). The problem is handled by projected gradient descent, where the gradients are computed with PyTorch automatic differentiation. Projecting onto the simplex ensures that the iterates remain on the probability simplex.

This example again illustrates the wasserstein_1d function and the use of backends within the POT framework.
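As a point of reference (a classical fact, not part of the original script): in 1D with the quadratic cost, the Wasserstein barycenter has quantile function equal to the weighted average of the input quantile functions,

\[F_{\eta^*}^{-1} = (1-t)\, F_\mu^{-1} + t\, F_\nu^{-1},\]

so the gradient descent below should converge towards this interpolation.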

device = "cuda" if torch.cuda.is_available() else "cpu"

# use PyTorch for our data
x_torch = torch.tensor(x).to(device=device)
a_torch = torch.tensor(a).to(device=device)
b_torch = torch.tensor(b).to(device=device)
# initialize the barycenter with the average of a and b (already sums to one)
bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True)


lr = 1e-6
nb_iter_max = 2000

loss_iter = []

# interpolation parameter t in [0, 1]
t = 0.5

for i in range(nb_iter_max):
    # Compute the Wasserstein 1D with torch backend
    loss = (1 - t) * wasserstein_1d(
        x_torch, x_torch, a_torch.detach(), bary_torch, p=2
    ) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2)
    # record the corresponding loss value
    loss_iter.append(loss.clone().detach().cpu().numpy())
    loss.backward()

    # perform a step of projected gradient descent
    with torch.no_grad():
        bary_torch -= bary_torch.grad * lr  # gradient step
        bary_torch.grad.zero_()
        bary_torch.data = proj_simplex(bary_torch)  # projection onto the simplex

pl.figure(3, figsize=(8, 4))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c="green", label="W barycenter")
pl.legend()
pl.title("Wasserstein barycenter computed by gradient descent")
pl.show()

pl.figure(4)
pl.plot(range(nb_iter_max), loss_iter, lw=3)
pl.title("Evolution of the loss along iterations", fontsize=16)
pl.show()
[Output figures: "Wasserstein barycenter computed by gradient descent" and "Evolution of the loss along iterations"]
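As a rough numerical check (a hedged sketch based on the closed form stated above, not part of the original example; it approximates quantile functions with np.interp on the cumulative sums):

# hypothetical verification: compare GD barycenter quantiles with the
# quantile average (1 - t) * F_a^{-1} + t * F_b^{-1}
q = np.linspace(0.01, 0.99, 99)
qa = np.interp(q, np.cumsum(a), x)  # crude quantile function of a
qb = np.interp(q, np.cumsum(b), x)  # crude quantile function of b
q_avg = (1 - t) * qa + t * qb  # quantiles of the exact W2 barycenter

bary = bary_torch.detach().cpu().numpy()
q_gd = np.interp(q, np.cumsum(bary), x)  # quantiles of the GD solution
print(np.abs(q_gd - q_avg).max())  # should be small after convergence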

Total running time of the script: (0 minutes 3.134 seconds)

Gallery generated by Sphinx-Gallery