使用Aim跟踪和比较GAN模型
概述
Generative Adversarial Networks,简称GANs,是基于深度学习的生成模型。
生成式建模是机器学习中的一种无监督学习任务,它涉及自动发现并学习输入数据的模式,使得模型能够生成看起来可能源自原始数据集的新样本。
在本指南中,我们将向您展示如何将Aim与您的GAN以及带EMA的GAN集成,通过比较两个实验生成的图像来评估它们的性能表现。
实验
我们将训练并比较常规GAN与采用EMA技术的GAN。EMA是一种在GAN训练中用于参数平均化的技术,它计算权重的指数折扣总和。
我们将使用由lucidrains实现的lightweight-gan模型,并以MetFaces Dataset作为训练数据集。
为了能够分析结果,我们将固定64个随机点,并在训练过程中同时跟踪常规GAN和带EMA的GAN。
使用Aim追踪图像
在训练器类中初始化一个新的运行以收集和存储图像序列:
class Trainer():
def __init__(
self,
name = 'default',
results_dir = 'results',
models_dir = 'models',
...
):
...
self.run = aim.Run() # Initialize aim.Run
self.run['hparams'] = hparams # Log hyperparams
...
代码托管在 GitHub
追踪常规GAN生成的图像:
# Regular GAN
# Get generated images
generated_images = self.generate_(self.GAN.G, latents)
aim_images = []
for idx, image in enumerate(generated_images):
ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = PIL.Image.fromarray(ndarr)
aim_images.append(aim.Image(im, caption=f'#{idx}'))
# Store with Aim (name="generated" and context.ema=0)
self.run.track(value=aim_images, name='generated', step=self.steps, context={'ema': False})
代码托管在 GitHub
追踪由启用EMA的GAN生成的图像:
# GAN with moving averages
# Get generated images
generated_images = self.generate_(self.GAN.GE, latents)
aim_images = []
for idx, image in enumerate(generated_images):
ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = PIL.Image.fromarray(ndarr)
aim_images.append(aim.Image(im, caption=f'EMA #{idx}'))
# Store with Aim (name="generated" and context.ema=1)
self.run.track(value=aim_images, name='generated', step=self.steps, context={'ema': True})
代码托管在 GitHub
通过Aim UI探索结果
可视化由常规GAN生成的图像:


可视化由GAN与EMA生成的图像:


您可能会注意到,使用EMA的GAN以指数方式收敛,最终效果更佳。
让我们将两种方法的最后一步进行并排比较:



结论
如您所见,采用EMA的GAN表现明显优于常规版本。
使用Aim可以轻松比较不同运行中跟踪图像的不同组别。
按运行哈希值分组,其他参数可用于切片和切块,观察不同运行之间的差异。