Wasserstein unmixing with PyTorch
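In this example, we estimate the mixing parameters of a mixture of distributions by minimizing the Wasserstein distance. In other words, we suppose that a target distribution \(\mu^t\) can be expressed as a weighted sum of source distributions \(\mu^s_k\) with the following model: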
\[\mu^t = \sum_{k=1}^K w_k\mu^s_k\]
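where \(\mathbf{w}\) is a vector of size \(K\) that belongs to the probability simplex \(\Delta_K\), i.e. \(w_k \ge 0\) for all \(k\) and \(\sum_{k=1}^K w_k = 1\).
To estimate this weight vector, we propose to minimize the Wasserstein distance between the model and the observed \(\mu^t\) with respect to \(\mathbf{w}\). This leads to the following optimization problem: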
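For empirical measures, one way to picture this model is to pool the samples of all sources into one support and store each uniform source \(\mu^s_k\) as a column of a reweighting matrix. Any \(\mathbf{w} \in \Delta_K\) then yields a valid probability vector on the pooled samples. The sketch below is a minimal NumPy illustration under that assumption. The helper mixture_weights and the source sizes used are hypothetical.

```python
import numpy as np


def mixture_weights(n_per_source, w):
    """Sample weights of sum_k w_k mu^s_k on the pooled source samples.

    Assumes (hypothetically) that each source k is a uniform empirical measure
    on its own n_per_source[k] samples, stacked one source after the other.
    """
    w = np.asarray(w, dtype=float)
    H = np.zeros((sum(n_per_source), len(n_per_source)))
    start = 0
    for k, n_k in enumerate(n_per_source):
        H[start:start + n_k, k] = 1.0 / n_k  # column k: uniform weights of source k
        start += n_k
    return H @ w  # nonnegative and summing to 1 whenever w is on the simplex


a = mixture_weights([3, 2], [0.25, 0.75])
print(a, a.sum())
```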
\[\min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)\]
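In this example, the minimization is carried out with a simple projected gradient descent in PyTorch. We rely on POT's automatic backend, which makes the Wasserstein distance computed by ot.emd2 a differentiable loss with respect to its inputs.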
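The projection step keeps the iterate on \(\Delta_K\) after every gradient step. Below is a minimal NumPy sketch of a Euclidean projection onto the simplex using the classical sort-and-threshold method. It is used inside projected gradient descent on a surrogate squared-L2 loss, not the Wasserstein distance, so that it runs without POT or PyTorch. That ot.utils.proj_simplex computes the same Euclidean projection is an assumption here, and proj_simplex_np is a hypothetical name.

```python
import numpy as np


def proj_simplex_np(v):
    """Euclidean projection of v onto the probability simplex (sort-based method)."""
    v = np.asarray(v, dtype=float)
    u = np.sort(v)[::-1]                # entries in decreasing order
    css = np.cumsum(u) - 1.0            # cumulative sums minus the target total 1
    ind = np.arange(1, v.size + 1)
    rho = ind[u - css / ind > 0][-1]    # number of entries that stay positive
    theta = css[rho - 1] / rho          # shift that makes the result sum to 1
    return np.maximum(v - theta, 0.0)


# Surrogate problem (NOT a Wasserstein loss): fit t with a mixture of two
# histograms on 3 bins by minimizing 0.5 * ||S w - t||^2 over the simplex.
S = np.array([[1.0, 0.0],
              [0.0, 0.5],
              [0.0, 0.5]])              # columns are the two source histograms
t = 0.2 * S[:, 0] + 0.8 * S[:, 1]       # target built with weights (0.2, 0.8)

w = np.array([0.5, 0.5])                # uniform initialization, as in the example
lr = 0.5
for _ in range(200):
    grad = S.T @ (S @ w - t)            # gradient of the surrogate loss
    w = proj_simplex_np(w - lr * grad)  # gradient step, then projection
print("Recovered weights:", w)
```

The projection only clips and shifts the iterate, so with a small enough step the fixed point of the loop is the constrained minimizer, here the generating weights (0.2, 0.8).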
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
import torch
Generate data
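nt = 100  # number of target samples
nt1 = 10  # target samples placed like the first source (the rest like the second)
ns1 = 50  # samples per source distribution
ns = 2 * ns1  # total number of source samples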
rng = np.random.RandomState(2)
xt = rng.randn(nt, 2) * 0.2
xt[:nt1, 0] += 1
xt[nt1:, 1] += 1
xs1 = rng.randn(ns1, 2) * 0.2
xs1[:, 0] += 1
xs2 = rng.randn(ns1, 2) * 0.2
xs2[:, 1] += 1
xs = np.concatenate((xs1, xs2))
# Sample reweighting matrix H
H = np.zeros((ns, 2))
H[:ns1, 0] = 1 / ns1
H[ns1:, 1] = 1 / ns1
# each column sums to 1 and has weights only for samples from the
# corresponding source distribution
M = ot.dist(xs, xt)  # cost matrix (squared Euclidean distance by default)
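ot.dist uses the squared Euclidean distance by default, so \(M_{ij} = \|x^s_i - x^t_j\|^2\) and ot.emd2 then returns a squared 2-Wasserstein distance. The NumPy sketch below builds the same kind of matrix by broadcasting. The name sqeuclidean_cost and the random inputs are hypothetical, for illustration only.

```python
import numpy as np


def sqeuclidean_cost(x1, x2):
    """Pairwise squared Euclidean costs M[i, j] = ||x1[i] - x2[j]||^2."""
    diff = x1[:, None, :] - x2[None, :, :]  # shape (n1, n2, d)
    return np.sum(diff ** 2, axis=-1)


rng = np.random.RandomState(0)
xs = rng.randn(4, 2)  # 4 hypothetical source samples in 2D
xt = rng.randn(3, 2)  # 3 hypothetical target samples in 2D
M = sqeuclidean_cost(xs, xt)
print(M.shape)  # → (4, 3)
```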
Plot data

Optimization of the model with respect to the Wasserstein distance
# convert numpy arrays to torch tensors
H2 = torch.tensor(H)
M2 = torch.tensor(M)
# weights for the source distributions
w = torch.tensor(ot.unif(2), requires_grad=True)
# uniform weights for target
b = torch.tensor(ot.unif(nt))
lr = 2e-3 # learning rate
niter = 500 # number of iterations
losses = [] # loss along the iterations
# loss for the minimal Wasserstein estimator
def get_loss(w):
    a = torch.mv(H2, w)  # distribution reweighting
    return ot.emd2(a, b, M2)  # squared Wasserstein 2


for i in range(niter):
    loss = get_loss(w)
    losses.append(float(loss))
    loss.backward()
    with torch.no_grad():
        w -= lr * w.grad  # gradient step
        w[:] = ot.utils.proj_simplex(w)  # projection on the simplex
    w.grad.zero_()
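As a sanity check of the estimator (not the method used in this example), the same unmixing problem can be solved in 1D. There, the squared 2-Wasserstein distance between weighted point clouds has an exact formula in terms of quantile functions, and \(\Delta_2\) is the segment \(\{(w_1, 1 - w_1)\}\) that a grid search can scan. The sketch below swaps ot.emd2 for that 1D formula and projected gradient descent for a grid search. The data (two sources around 0 and 3, and a target with 30 and 70 samples near each) and the names w2_squared_1d and grid_losses are hypothetical.

```python
import numpy as np


def w2_squared_1d(x, a, y, b):
    """Exact squared 2-Wasserstein distance between two weighted 1D point clouds.

    Uses W_2^2 = integral over t in (0, 1) of (F^{-1}(t) - G^{-1}(t))^2, where
    F^{-1} and G^{-1} are the quantile functions of the two measures.
    """
    x, a = np.asarray(x, dtype=float), np.asarray(a, dtype=float)
    y, b = np.asarray(y, dtype=float), np.asarray(b, dtype=float)
    ix, iy = np.argsort(x), np.argsort(y)
    x, a, y, b = x[ix], a[ix], y[iy], b[iy]
    ca = np.cumsum(a)
    ca = ca / ca[-1]  # CDF levels of the first measure, last one exactly 1
    cb = np.cumsum(b)
    cb = cb / cb[-1]  # CDF levels of the second measure
    # Both quantile functions are constant between consecutive merged levels
    t = np.unique(np.concatenate(([0.0], ca, cb)))
    mid = 0.5 * (t[:-1] + t[1:])
    qx = x[np.minimum(np.searchsorted(ca, mid), x.size - 1)]
    qy = y[np.minimum(np.searchsorted(cb, mid), y.size - 1)]
    return float(np.sum(np.diff(t) * (qx - qy) ** 2))


rng = np.random.RandomState(0)
ns1 = 50
# Two 1D sources: ns1 samples around 0 and ns1 samples around 3
xs = np.concatenate((rng.randn(ns1) * 0.2, rng.randn(ns1) * 0.2 + 3.0))
H = np.zeros((2 * ns1, 2))
H[:ns1, 0] = 1 / ns1
H[ns1:, 1] = 1 / ns1
# Target: 30 samples around 0 and 70 around 3, i.e. generating weights (0.3, 0.7)
xt = np.concatenate((rng.randn(30) * 0.2, rng.randn(70) * 0.2 + 3.0))
b = np.full(xt.size, 1 / xt.size)

# On the 2-simplex w = (w1, 1 - w1), so scanning w1 in [0, 1] covers it
grid = np.linspace(0.0, 1.0, 101)
grid_losses = [w2_squared_1d(xs, H @ np.array([w1, 1 - w1]), xt, b) for w1 in grid]
w1_best = grid[int(np.argmin(grid_losses))]
print("Estimated weights:", w1_best, 1 - w1_best)
```

With well-separated sources, any mismatch in the weights forces mass to travel between clusters, which is expensive, so the grid minimizer should land at or very near the generating weight 0.3.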
Estimated weights and convergence of the objective

Estimated mixture: [0.09980706 0.90019294]
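By construction of the data, nt1 = 10 of the nt = 100 target samples are shifted like the first source and the remaining 90 like the second, so the generating mixture is \(\mathbf{w} = (0.1, 0.9)\). A small check comparing it with the printed estimate, reusing only the values already shown in this example:

```python
import numpy as np

nt, nt1 = 100, 10  # same values as in the data-generation step
# nt1 target samples are shifted like source 1, the other nt - nt1 like source 2
w_true = np.array([nt1 / nt, (nt - nt1) / nt])
w_est = np.array([0.09980706, 0.90019294])  # estimate printed by the run above
print("Generating weights:", w_true)
print("Max. absolute error:", np.abs(w_est - w_true).max())
```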
Plot the reweighted source distributions
pl.figure(3)
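# estimated mixture weights as a NumPy array
we = w.detach().numpy()
# compute source weights
ws = H.dot(we)
pl.scatter(xt[:, 0], xt[:, 1], label=r"Target $\mu^t$", alpha=0.5)
pl.scatter(
    xs[:, 0],
    xs[:, 1],
    color="C3",
    s=ws * 20 * ns,  # marker size proportional to the estimated sample weight
    label=r"Weighted sources $\sum_{k} w_k\mu^s_k$",
    alpha=0.5,
)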
pl.title("Target and reweighted source distributions")
pl.legend()

Total running time of the script: (0 minutes 1.351 seconds)