随机示例

本例旨在演示如何使用POT库的离散和半连续测度的随机优化算法。

[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. 大规模最优运输的随机优化。 神经信息处理系统进展 (2016).

[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. 大规模最优运输与映射估计。 国际表征学习会议(2018)

# Author: Kilian Fatras <kilian.fatras@gmail.com>
#
# License: MIT License

import matplotlib.pylab as pl
import numpy as np
import ot
import ot.plot

计算半对偶问题的运输矩阵

离散情况

对离散情况的两个离散度量进行采样并计算它们的成本矩阵 c。

调用“SAG”方法以找到离散情况下的运输矩阵

[[2.55553509e-02 9.96395660e-02 1.76579142e-02 4.31178196e-06]
 [1.21640234e-01 1.25357448e-02 1.30225078e-03 7.37891338e-03]
 [3.56123975e-03 7.61451746e-02 6.31505947e-02 1.33831456e-07]
 [2.61515202e-02 3.34246014e-02 8.28734709e-02 4.07550428e-04]
 [9.85500870e-03 7.52288517e-04 1.08262628e-02 1.21423583e-01]
 [2.16904253e-02 9.03825797e-04 1.87178503e-03 1.18391107e-01]
 [4.15462212e-02 2.65987989e-02 7.23177216e-02 2.39440107e-03]]

半连续案例

示例一个一般测量 a,一个离散测量 b 用于半连续案例,定义源和目标测量的点并计算成本矩阵。

调用“ASGD”方法以在半连续情况下找到运输矩阵。

[3.76510592 7.64094845 3.78917596 2.57007572 1.65543745 3.4893295
 2.70623359] [-2.50319213 -2.25852474 -0.82688144  5.5885983 ]
[[2.19802712e-02 1.03838786e-01 1.70349712e-02 3.11402024e-06]
 [1.20269164e-01 1.50177118e-02 1.44418382e-03 6.12608330e-03]
 [3.05271739e-03 7.90868636e-02 6.07174656e-02 9.63289956e-08]
 [2.33574229e-02 3.61718564e-02 8.30222147e-02 3.05648858e-04]
 [1.12749105e-02 1.04283861e-03 1.38926617e-02 1.16646732e-01]
 [2.49295484e-02 1.25865775e-03 2.41297662e-03 1.14255960e-01]
 [3.78279732e-02 2.93440562e-02 7.38545201e-02 1.83059335e-03]]

将结果与Sinkhorn算法进行比较

[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
 [1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
 [3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
 [2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
 [9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
 [2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
 [4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]

绘制运输矩阵

针对SAG

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sag_pi, "semi-dual : OT matrix SAG")
pl.show()
plot stochastic

对于ASGD

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, asgd_pi, "semi-dual : OT matrix ASGD")
pl.show()
plot stochastic

对于Sinkhorn

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, "OT matrix Sinkhorn")
pl.show()
plot stochastic

计算对偶问题的运输矩阵

半连续案例

样本一个一般测量a,一个离散测量b用于半连续情况,并计算成本矩阵c。

调用“SGD”双重方法以寻找半连续情况下的运输矩阵

[0.91732819 2.7799397  1.07406199 0.01970121 0.60717156 1.80910257
 0.10902398] [0.34639291 0.47463643 1.57482501 4.92047485]
[[2.20200322e-02 9.25938748e-02 1.09047347e-02 9.25518158e-08]
 [1.60917795e-02 1.78850969e-03 1.23469888e-04 2.43170724e-05]
 [3.49209980e-03 8.05271170e-02 4.43815515e-02 3.26915633e-09]
 [3.15043415e-02 4.34264205e-02 7.15531236e-02 1.22305749e-05]
 [6.82992713e-02 5.62286712e-03 5.37746045e-02 2.09630346e-02]
 [8.02712798e-02 3.60737409e-03 4.96463916e-03 1.09144850e-02]
 [4.86875958e-02 3.36173252e-02 6.07394894e-02 6.98997703e-05]]

将结果与Sinkhorn算法进行比较

从POT调用Sinkhorn算法

[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
 [1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
 [3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
 [2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
 [9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
 [2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
 [4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]

绘制运输矩阵

对于SGD

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sgd_dual_pi, "dual : OT matrix SGD")
pl.show()
plot stochastic

对于Sinkhorn

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, "OT matrix Sinkhorn")
pl.show()
plot stochastic

脚本的总运行时间: (0分钟 7.676秒)

由 Sphinx-Gallery 生成的画廊