使用Aim跟踪和比较GAN模型

概述

Generative Adversarial Networks,简称GANs,是基于深度学习的生成模型。

生成式建模是机器学习中的一种无监督学习任务,它涉及自动发现并学习输入数据的模式,使得模型能够生成看起来可能源自原始数据集的新样本。

在本指南中,我们将向您展示如何将Aim与您的GAN以及带EMA的GAN集成,通过比较两个实验生成的图像来评估它们的性能表现。

实验

我们将训练并比较常规GAN与采用EMA技术的GAN。EMA是一种在GAN训练中用于参数平均化的技术,它计算权重的指数折扣总和。

我们将使用由lucidrains实现的lightweight-gan模型,并以MetFaces Dataset作为训练数据集。

为了能够分析结果,我们将固定64个随机点,并在训练过程中同时跟踪常规GAN和带EMA的GAN。

使用Aim追踪图像

  1. 在训练器类中初始化一个新的运行以收集和存储图像序列:

class Trainer():
    def __init__(
        self,
        name = 'default',
        results_dir = 'results',
        models_dir = 'models',
                ...
        ):
    ...
    self.run = aim.Run()           # Initialize aim.Run
    self.run['hparams'] = hparams  # Log hyperparams
    ...

代码托管在 GitHub

  1. 追踪常规GAN生成的图像:

# Regular GAN

# Get generated images
generated_images = self.generate_(self.GAN.G, latents)

aim_images = []
for idx, image in enumerate(generated_images):
    ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = PIL.Image.fromarray(ndarr)
    aim_images.append(aim.Image(im, caption=f'#{idx}'))

# Store with Aim (name="generated" and context.ema=0)
self.run.track(value=aim_images, name='generated', step=self.steps, context={'ema': False})

代码托管在 GitHub

  1. 追踪由启用EMA的GAN生成的图像:

# GAN with moving averages

# Get generated images
generated_images = self.generate_(self.GAN.GE, latents)

aim_images = []
for idx, image in enumerate(generated_images):
    ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = PIL.Image.fromarray(ndarr)
    aim_images.append(aim.Image(im, caption=f'EMA #{idx}'))

# Store with Aim (name="generated" and context.ema=1)
self.run.track(value=aim_images, name='generated', step=self.steps, context={'ema': True})

代码托管在 GitHub

通过Aim UI探索结果

  1. 可视化由常规GAN生成的图像:

  1. 可视化由GAN与EMA生成的图像:

您可能会注意到,使用EMA的GAN以指数方式收敛,最终效果更佳。

  1. 让我们将两种方法的最后一步进行并排比较:

结论

如您所见,采用EMA的GAN表现明显优于常规版本。

使用Aim可以轻松比较不同运行中跟踪图像的不同组别。

按运行哈希值分组,其他参数可用于切片和切块,观察不同运行之间的差异。