注意
点击 这里 下载完整示例代码
随机示例
这个示例旨在展示如何使用POT库的离散和半连续测量的随机优化算法。
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. 大规模最优运输的随机优化。 神经信息处理系统进展 (2016).
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. 大规模最优运输与映射估计。 国际表征学习会议(2018)
# Author: Kilian Fatras <kilian.fatras@gmail.com>
#
# License: MIT License
import matplotlib.pylab as pl
import numpy as np
import ot
import ot.plot
计算半对偶问题的运输矩阵
离散情况
对离散情况的两个离散度量进行采样并计算它们的成本矩阵 c。
调用“SAG”方法以找到离散情况下的运输矩阵
method = "SAG"
sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
numItermax)
print(sag_pi)
[[2.55553509e-02 9.96395660e-02 1.76579142e-02 4.31178196e-06]
[1.21640234e-01 1.25357448e-02 1.30225078e-03 7.37891338e-03]
[3.56123975e-03 7.61451746e-02 6.31505947e-02 1.33831456e-07]
[2.61515202e-02 3.34246014e-02 8.28734709e-02 4.07550428e-04]
[9.85500870e-03 7.52288517e-04 1.08262628e-02 1.21423583e-01]
[2.16904253e-02 9.03825797e-04 1.87178503e-03 1.18391107e-01]
[4.15462212e-02 2.65987989e-02 7.23177216e-02 2.39440107e-03]]
半连续案例
样本一一般测量 a,一个离散测量 b 适用于半连续情况,定义源测量和目标测量的点并计算成本矩阵。
调用“ASGD”方法以找到半连续情况下的运输矩阵。
[3.89210786 7.62897384 3.89245014 2.61724317 1.51339313 3.34708637
2.73931688] [-2.47771832 -2.44147638 -0.84136916 5.76056385]
[[2.56007346e-02 9.81885744e-02 1.90636347e-02 4.19914973e-06]
[1.21903709e-01 1.23580049e-02 1.40646856e-03 7.18896015e-03]
[3.47217135e-03 7.30299279e-02 6.63549167e-02 1.26850485e-07]
[2.51172810e-02 3.15791525e-02 8.57801775e-02 3.80531864e-04]
[1.00343023e-02 7.53482461e-04 1.18796723e-02 1.20189686e-01]
[2.21820738e-02 9.09237539e-04 2.06293608e-03 1.17702895e-01]
[4.01092095e-02 2.52599884e-02 7.52407360e-02 2.24720898e-03]]
将结果与Sinkhorn算法进行比较
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
[1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
[3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
[2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
[9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
[2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
[4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]
绘制运输矩阵
针对SAG
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
pl.show()

对于ASGD
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
pl.show()

对于Sinkhorn
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
pl.show()

计算对偶问题的运输矩阵
半连续案例
样本一个一般测量a,一个离散测量b用于半连续情况,并计算成本矩阵c。
n_source = 7
n_target = 4
reg = 1
numItermax = 100000
lr = 0.1
batch_size = 3
log = True
a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)
rng = np.random.RandomState(0)
X_source = rng.randn(n_source, 2)
Y_target = rng.randn(n_target, 2)
M = ot.dist(X_source, Y_target)
调用“SGD”双重方法以寻找半连续情况下的运输矩阵
sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
batch_size, numItermax,
lr, log=log)
print(log_sgd['alpha'], log_sgd['beta'])
print(sgd_dual_pi)
[0.92355578 2.77992798 1.07804471 0.02016319 0.60713159 1.81182019
0.11197572] [0.34336982 0.4710476 1.5721106 4.94609115]
[[2.20907087e-02 9.28385311e-02 1.09431124e-02 9.55464538e-08]
[1.60430180e-02 1.78208165e-03 1.23133751e-04 2.49477404e-05]
[3.49545252e-03 8.05588419e-02 4.44378765e-02 3.36736644e-09]
[3.14237585e-02 4.32908444e-02 7.13921359e-02 1.25537224e-05]
[6.80903869e-02 5.60249986e-03 5.36266927e-02 2.15061075e-02]
[8.02467634e-02 3.60423269e-03 4.96465510e-03 1.12281580e-02]
[4.86841219e-02 3.35959147e-02 6.07539060e-02 7.19254669e-05]]
将结果与Sinkhorn算法进行比较
从POT调用Sinkhorn算法
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
[1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
[3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
[2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
[9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
[2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
[4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]
绘制运输矩阵
对于SGD
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
pl.show()

对于Sinkhorn
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
pl.show()

脚本的总运行时间: ( 0 分钟 7.635 秒)