Wasserstein 2 Minibatch GAN with PyTorch
In this example we train a Wasserstein GAN using the Wasserstein 2 distance on minibatches as a distribution fitting term.
We want to train a generator \(G_\theta\) that generates realistic data from random noise drawn from a Gaussian distribution \(\mu_n\), so that the generated data are indistinguishable from the true data distribution \(\mu_d\). To this end, the Wasserstein GAN [Arjovsky2017] aims at optimizing the parameters \(\theta\) of the generator through the following optimization problem:

\[\min_{\theta} \; W(\mu_d, G_\theta\#\mu_n)\]

where \(G_\theta\#\mu_n\) denotes the distribution of the generated samples, i.e. the push-forward of the noise distribution through the generator.
In practice we do not have access to the full distribution \(\mu_d\) but only to samples, and we cannot compute the Wasserstein distance on a large dataset. [Arjovsky2017] proposed to approximate the dual potential of the Wasserstein 1 distance with a neural network, recovering an optimization problem similar to a GAN. In this example we instead optimize, at each iteration, the expectation of the Wasserstein distance over minibatches, as proposed in [Genevay2018]. Minibatch optimization of the Wasserstein distance has been studied in [Fatras2019]. A minimal sketch of the minibatch loss is given after the references below.
[Arjovsky2017] Arjovsky, M., Chintala, S., & Bottou, L. (2017, July). Wasserstein generative adversarial networks. In International Conference on Machine Learning (pp. 214-223). PMLR.
[Genevay2018] Genevay, A., Peyré, G., & Cuturi, M. (2018). Learning generative models with Sinkhorn divergences. In International Conference on Artificial Intelligence and Statistics. PMLR.
[Fatras2019] Fatras, K., Zine, Y., Flamary, R., Gribonval, R., & Courty, N. (2020, June). Learning with minibatch Wasserstein: asymptotic and gradient properties. In the 23rd International Conference on Artificial Intelligence and Statistics (Vol. 108).
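As a minimal illustration of this minibatch loss (not part of the original example; the two minibatches below are stand-in tensors), the squared Wasserstein 2 distance between a generated batch and a data batch can be computed with POT's ot.dist and ot.emd2, exactly as the training loop below does:

# Minimal sketch (assumed stand-in data): squared W2 between two minibatches with POT
import torch
import ot

xg = torch.randn(64, 2)        # stand-in for generated samples G(noise)
xd = torch.randn(64, 2) + 1.0  # stand-in for a minibatch of real data
a = torch.ones(64) / 64        # uniform weights on the generated minibatch
b = torch.ones(64) / 64        # uniform weights on the data minibatch
M = ot.dist(xg, xd)            # pairwise squared Euclidean cost matrix
loss = ot.emd2(a, b, M)        # exact OT cost = squared W2 on the minibatch
print(loss)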
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 3
import numpy as np
import matplotlib.pyplot as pl
import matplotlib.animation as animation
import torch
from torch import nn
import ot
Data generation
torch.manual_seed(1)
sigma = 0.1
n_dims = 2
n_features = 2
def get_data(n_samples):
    # sample points on the unit circle and add Gaussian noise
    c = torch.rand(size=(n_samples, 1))
    angle = c * 2 * np.pi
    x = torch.cat((torch.cos(angle), torch.sin(angle)), 1)
    x += torch.randn(n_samples, 2) * sigma
    return x
Plot data
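The plotting code is not reproduced above; a minimal sketch of what it presumably looks like (figure number and labels assumed), using the get_data function defined above:

# Minimal sketch (assumed): scatter plot of samples from the data distribution
x = get_data(500)

pl.figure(1)
pl.scatter(x[:, 0], x[:, 1], label="Data samples from $\mu_d$", alpha=0.5)
pl.title("Data distribution")
pl.legend()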

<matplotlib.legend.Legend object at 0x76e47cb9c5b0>
Generator model
# define the MLP model
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(n_features, 200)
        self.fc2 = nn.Linear(200, 500)
        self.fc3 = nn.Linear(500, n_dims)
        self.relu = torch.nn.ReLU()  # instead of Heaviside step fn

    def forward(self, x):
        output = self.fc1(x)
        output = self.relu(output)  # instead of Heaviside step fn
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output
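As a quick sanity check (not part of the original example), the generator maps a batch of noise vectors of dimension n_features to points of dimension n_dims:

# Hypothetical sanity check: noise of shape (batch, n_features) -> points of shape (batch, n_dims)
G_check = Generator()
z = torch.randn(8, n_features)
print(G_check(z).shape)  # expected: torch.Size([8, 2])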
Training the model
G = Generator()
optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5)
# number of iteration and size of the batches
n_iter = 200 # set to 200 for doc build but 1000 is better ;)
size_batch = 500
# generate static samples to see their trajectory along training
n_visu = 100
xnvisu = torch.randn(n_visu, n_features)
xvisu = torch.zeros(n_iter, n_visu, n_dims)
ab = torch.ones(size_batch) / size_batch
losses = []
for i in range(n_iter):
    # generate noise samples
    xn = torch.randn(size_batch, n_features)

    # generate data samples
    xd = get_data(size_batch)

    # generate samples along iterations
    xvisu[i, :, :] = G(xnvisu).detach()

    # generate samples and compute distance matrix
    xg = G(xn)
    M = ot.dist(xg, xd)

    loss = ot.emd2(ab, ab, M)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    del M
pl.figure(2)
pl.semilogy(losses)
pl.grid()
pl.title("Wasserstein distance")
pl.xlabel("Iterations")

Iter: 0, loss=0.9009847640991211
Iter: 10, loss=0.10964284837245941
Iter: 20, loss=0.04564394801855087
Iter: 30, loss=0.03516071289777756
Iter: 40, loss=0.05013977363705635
Iter: 50, loss=0.058588504791259766
Iter: 60, loss=0.03730057179927826
Iter: 70, loss=0.04171676188707352
Iter: 80, loss=0.03168988972902298
Iter: 90, loss=0.031197285279631615
Iter: 100, loss=0.03596879169344902
Iter: 110, loss=0.03272819146513939
Iter: 120, loss=0.032379165291786194
Iter: 130, loss=0.03959248960018158
Iter: 140, loss=0.029337508603930473
Iter: 150, loss=0.05796702206134796
Iter: 160, loss=0.034939464181661606
Iter: 170, loss=0.022607704624533653
Iter: 180, loss=0.04347885772585869
Iter: 190, loss=0.1164197325706482
Text(0.5, 23.52222222222222, 'Iterations')
Plot trajectories of generated samples along iterations
pl.figure(3, (10, 10))

ivisu = [0, 10, 25, 50, 75, 125, 150, 175, 199]

for i in range(9):
    pl.subplot(3, 3, i + 1)
    pl.scatter(xd[:, 0], xd[:, 1], label="Data samples from $\mu_d$", alpha=0.1)
    pl.scatter(
        xvisu[ivisu[i], :, 0],
        xvisu[ivisu[i], :, 1],
        label="Data samples from $G\#\mu_n$",
        alpha=0.5,
    )
    pl.xticks(())
    pl.yticks(())
    pl.title("Iter. {}".format(ivisu[i]))
    if i == 0:
        pl.legend()

Animate trajectories of generated samples along iterations
pl.figure(4, (8, 8))


def _update_plot(i):
    pl.clf()
    pl.scatter(xd[:, 0], xd[:, 1], label="Data samples from $\mu_d$", alpha=0.1)
    pl.scatter(
        xvisu[i, :, 0], xvisu[i, :, 1], label="Data samples from $G\#\mu_n$", alpha=0.5
    )
    pl.xticks(())
    pl.yticks(())
    pl.xlim((-1.5, 1.5))
    pl.ylim((-1.5, 1.5))
    pl.title("Iter. {}".format(i))
    return 1


i = 0
pl.scatter(xd[:, 0], xd[:, 1], label="Data samples from $\mu_d$", alpha=0.1)
pl.scatter(
    xvisu[i, :, 0], xvisu[i, :, 1], label="Data samples from $G\#\mu_n$", alpha=0.5
)
pl.xticks(())
pl.yticks(())
pl.xlim((-1.5, 1.5))
pl.ylim((-1.5, 1.5))
pl.title("Iter. {}".format(ivisu[i]))

ani = animation.FuncAnimation(
    pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000
)