Optimal transport with factored couplings

Illustration of factored coupling OT between 2D empirical distributions.
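In factored coupling OT, mass is not moved directly from source to target. A first coupling `Ga` carries the source distribution onto a small template of r points, and a second coupling `Gb` carries the template onto the target. The short NumPy sketch below checks that the composed plan Ga diag(1/w) Gb has the source and target weights as marginals and rank at most r. It assumes uniform template weights w = 1/r, and its couplings are hypothetical hard assignments, not output of POT.

```python
import numpy as np

n, r = 6, 2  # 6 source/target samples, 2 template points
a = np.full(n, 1.0 / n)  # uniform source weights
b = np.full(n, 1.0 / n)  # uniform target weights
w = np.full(r, 1.0 / r)  # uniform template weights (assumption)

# Hypothetical hard assignments: first half of the samples -> template 0,
# second half -> template 1.
assign = np.repeat(np.arange(r), n // r)
Ga = np.zeros((n, r))
Ga[np.arange(n), assign] = a  # source -> template coupling
Gb = np.zeros((r, n))
Gb[assign, np.arange(n)] = b  # template -> target coupling

# Compose the two couplings through the template, dividing by the template mass.
G = Ga @ np.diag(1.0 / w) @ Gb

print(np.allclose(G.sum(axis=1), a), np.allclose(G.sum(axis=0), b))  # True True
print(np.linalg.matrix_rank(G))  # 2
```

The rank bound holds for any such pair: `G` is a product through an r-dimensional inner dimension, so it can never have rank above r.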

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot

Generate data and plot it

# parameters and data generation

np.random.seed(42)

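n = 100  # number of samples

xs = np.random.rand(n, 2) - 0.5  # source: uniform in [-0.5, 0.5)^2

xs = xs + np.sign(xs)  # push each coordinate 1 unit away from 0 (four clusters)

xt = np.random.rand(n, 2) - 0.5  # target: uniform in [-0.5, 0.5)^2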

a, b = ot.unif(n), ot.unif(n)  # uniform distribution on samples
pl.figure(1)
pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples")
pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples")
pl.legend(loc=0)
pl.title("Source and target distributions")
Figure: Source and target distributions
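The line `xs = xs + np.sign(xs)` adds one unit to the absolute value of every source coordinate. This splits the uniform square into four separated clusters, one per quadrant, with every coordinate's absolute value in (1, 1.5]. The target stays a single square around the origin. A quick NumPy check of that geometry, reusing the seed and sample size of the example:

```python
import numpy as np

np.random.seed(42)
n = 100
xs = np.random.rand(n, 2) - 0.5  # uniform in [-0.5, 0.5)^2
xs = xs + np.sign(xs)  # |x| -> |x| + 1 for every nonzero coordinate

# Every coordinate now has absolute value in (1, 1.5].
print(np.all(np.abs(xs) > 1), np.all(np.abs(xs) <= 1.5))

# Label each sample by quadrant (0..3), i.e. by cluster, and count them.
quadrant = (xs[:, 0] > 0).astype(int) * 2 + (xs[:, 1] > 0).astype(int)
counts = np.bincount(quadrant, minlength=4)
print(counts)
```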

Compute factored OT and exact OT solutions

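M = ot.dist(xs, xt)  # squared Euclidean cost matrix
G0 = ot.emd(a, b, M)  # exact OT plan

# factored OT through r=4 template points: Ga (source -> template),
# Gb (template -> target) and the template positions xb
Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4)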
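`ot.emd` solves the exact OT problem for the cost matrix `M`. As we understand POT's `ot.factored_optimal_transport`, the factored solution instead alternates between two steps. First it solves two OT problems against a template of r points. Then it moves each template point to the average of its barycentric projections onto the source and the target. The following is a minimal NumPy sketch of that alternating scheme. It assumes uniform template weights and a random Gaussian initial template. It uses entropic OT (plain Sinkhorn iterations) in place of exact OT so that no LP solver is needed. `sqdist`, `sinkhorn_plan` and `factored_ot_sketch` are hypothetical helpers, not POT functions.

```python
import numpy as np


def sqdist(X, Y):
    # Pairwise squared Euclidean distances, shape (len(X), len(Y)).
    return ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)


def sinkhorn_plan(p, q, M, reg=0.1, n_iter=1000):
    # Entropic OT plan between weights p and q (swapped in for exact OT).
    K = np.exp(-M / (reg * M.max()))  # cost rescaled to [0, 1] to avoid underflow
    u = np.ones_like(p)
    for _ in range(n_iter):
        v = q / (K.T @ u)
        u = p / (K @ v)  # last update makes the row marginals exactly p
    return u[:, None] * K * v[None, :]


def factored_ot_sketch(Xa, Xb, r=4, n_outer=20, seed=0):
    # Hypothetical helper: alternate two OT problems against an r-point template.
    rng = np.random.default_rng(seed)
    a = np.full(len(Xa), 1.0 / len(Xa))
    b = np.full(len(Xb), 1.0 / len(Xb))
    w = np.full(r, 1.0 / r)  # uniform weights on the template (assumption)
    X = rng.standard_normal((r, Xa.shape[1]))  # random initial template
    for _ in range(n_outer):
        Ga = sinkhorn_plan(a, w, sqdist(Xa, X))  # source -> template
        Gb = sinkhorn_plan(w, b, sqdist(X, Xb))  # template -> target
        # Move each template point to the mean of its two barycentric projections.
        X = 0.5 * (Ga.T @ Xa + Gb @ Xb) / w[:, None]
    return Ga, Gb, X


# Usage on data generated as in the example.
np.random.seed(42)
xs = np.random.rand(100, 2) - 0.5
xs = xs + np.sign(xs)
xt = np.random.rand(100, 2) - 0.5
Ga, Gb, xb = factored_ot_sketch(xs, xt, r=4)
```

Expect the entropic plans to be blurrier than exact OT plans. Lowering `reg` sharpens them but slows Sinkhorn convergence.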

Plot factored OT and exact OT solutions

pl.figure(2, (14, 4))

pl.subplot(1, 3, 1)
ot.plot.plot2D_samples_mat(xs, xt, G0, c=[0.2, 0.2, 0.2], alpha=0.1)
pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples")
pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples")
pl.title("Exact OT with samples")

pl.subplot(1, 3, 2)
ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[0.6, 0.6, 0.9], alpha=0.5)
ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[0.9, 0.6, 0.6], alpha=0.5)
pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples")
pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples")
pl.plot(xb[:, 0], xb[:, 1], "og", label="Template samples")
pl.title("Factored OT with template samples")

pl.subplot(1, 3, 3)
ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[0.2, 0.2, 0.2], alpha=0.1)
pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples")
pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples")
pl.title("Factored OT low rank OT plan")
Figure: Exact OT with samples; Factored OT with template samples; Factored OT low rank OT plan
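The third panel draws `Ga.dot(Gb)`. It is a product of an n×r matrix and an r×n matrix, so its rank is at most r, which is why it is called a low-rank OT plan. With uniform template weights 1/r, `Ga.dot(Gb)` omits the division by the template mass. Its marginals are therefore a/r and b/r rather than a and b. This uniform rescaling changes the total mass but not which pairs are linked or their relative weights. A NumPy check using hypothetical balanced hard-assignment couplings (stand-ins, not POT output):

```python
import numpy as np

rng = np.random.default_rng(0)
n, r = 100, 4
a = np.full(n, 1.0 / n)
b = np.full(n, 1.0 / n)
w = np.full(r, 1.0 / r)  # uniform template weights (assumption)

# Hypothetical couplings: 25 source and 25 target samples per template point.
src_assign = rng.permutation(np.repeat(np.arange(r), n // r))
tgt_assign = rng.permutation(np.repeat(np.arange(r), n // r))
Ga = np.zeros((n, r))
Ga[np.arange(n), src_assign] = a
Gb = np.zeros((r, n))
Gb[tgt_assign, np.arange(n)] = b

P_plot = Ga.dot(Gb)  # what the third panel draws
P = Ga @ np.diag(1.0 / w) @ Gb  # coupling with marginals a and b

print(np.linalg.matrix_rank(P_plot))  # 4
print(np.allclose(P, r * P_plot))  # True
print(np.allclose(P_plot.sum(axis=1), a / r), np.allclose(P.sum(axis=1), a))  # True True
```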

Total running time of the script: (0 minutes 2.367 seconds)
