1. 其他教程
  2. 使用GAN创建你自己的朋友

使用GAN创建你自己的朋友

介绍

似乎加密货币、NFTs和web3运动最近都非常流行!数字资产在市场上以惊人的金额上市,几乎每个名人都在推出自己的NFT系列。虽然你的加密资产可能是应税的,比如在加拿大,但今天我们将探索一些有趣且免税的方法来生成你自己的一系列程序生成的CryptoPunks

生成对抗网络,通常简称为GANs,是一类特定的深度学习模型,旨在从输入数据集中学习,以创建(生成!)与原始训练集元素相似的新材料。著名的网站thispersondoesnotexist.com因使用名为StyleGAN2的模型生成逼真但合成的人物图像而走红。GANs在机器学习领域获得了广泛关注,现在被用于生成各种图像、文本,甚至音乐

今天我们将简要了解GANs背后的高级直觉,然后我们将围绕一个预训练的GAN构建一个小演示,看看这一切的喧嚣是什么。这里是我们将要组装的内容的预览

先决条件

确保你已经安装gradio Python包。要使用预训练模型,还需要安装torchtorchvision

GANs: 一个非常简短的介绍

最初由Goodfellow等人在2014年提出,GANs由神经网络组成,这些神经网络相互竞争,目的是智胜对方。一个网络,称为生成器,负责生成图像。另一个网络,判别器,每次从生成器接收一个图像以及来自训练数据集的真实图像。然后判别器必须猜测:哪个图像是假的?

生成器不断训练以创建对判别器来说更难识别的图像,而判别器每次正确检测到假图像时都会提高生成器的标准。随着网络参与这种竞争(对抗性!)关系,生成的图像会改进到人类眼睛无法区分的程度!

要更深入地了解GANs,您可以查看Analytics Vidhya上的这篇优秀文章或这个PyTorch教程。不过现在,我们将深入一个演示!

步骤1 — 创建生成器模型

要使用GAN生成新图像,您只需要生成器模型。生成器可以使用许多不同的架构,但在本演示中,我们将使用一个预训练的GAN生成器模型,其架构如下:

from torch import nn

class Generator(nn.Module):
    # Refer to the link below for explanations about nc, nz, and ngf
    # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs
    def __init__(self, nc=4, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        output = self.network(input)
        return output

我们正在从这个由@teddykoker提供的仓库中获取生成器,你也可以在那里看到原始的判别器模型结构。

在实例化模型之后,我们将从Hugging Face Hub加载权重,存储在nateraw/cryptopunks-gan中:

from huggingface_hub import hf_hub_download
import torch

model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available

步骤 2 — 定义一个 predict 函数

predict 函数是使 Gradio 工作的关键!无论我们通过 Gradio 界面选择什么输入,这些输入都将传递给我们的 predict 函数,该函数应对输入进行操作并生成我们可以使用 Gradio 输出组件显示的输出。对于 GANs,通常将随机噪声作为输入传递给我们的模型,因此我们将生成一个随机数张量并将其传递给模型。然后我们可以使用 torchvisionsave_image 函数将模型的输出保存为 png 文件,并返回文件名:

from torchvision.utils import save_image

def predict(seed):
    num_punks = 4
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

我们给predict函数添加了一个seed参数,这样我们就可以通过种子固定随机张量的生成。然后,如果我们想再次看到它们,通过传入相同的种子,我们将能够重现punks。

注意! 我们的模型需要一个维度为100x1x1的输入张量来进行单次推理,或者(BatchSize)x100x1x1来生成一批图像。在这个演示中,我们将从一次生成4个朋克开始。

步骤 3 — 创建 Gradio 界面

此时,你甚至可以运行你拥有的代码predict(),你会在文件系统中的./punks.png找到你刚刚生成的朋克。不过,为了制作一个真正互动的演示,我们将使用Gradio构建一个简单的界面。我们的目标是:

  • 设置一个滑块输入,以便用户可以选择“seed”值
  • 使用图像组件来展示生成的punks
  • 使用我们的 predict() 来获取种子并生成图像

使用gr.Interface(),我们可以通过一个函数调用来定义所有这些内容:

import gradio as gr

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
    ],
    outputs="image",
).launch()

步骤 4 — 更多的朋克!

一次生成4个朋克是一个好的开始,但也许我们想控制每次生成的数量。向我们的Gradio界面添加更多输入就像向传递给gr.Interfaceinputs列表中添加另一个项目一样简单:

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
        gr.Slider(4, 64, label='Number of Punks', step=1, default=10), # Adding another slider!
    ],
    outputs="image",
).launch()

新的输入将被传递到我们的predict()函数中,因此我们必须对该函数进行一些更改以接受一个新参数:

def predict(seed, num_punks):
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

当你重新启动界面时,你应该会看到第二个滑块,它可以让你控制punk的数量!

第5步 - 完善它

您的Gradio应用程序已经差不多准备好了,但您可以添加一些额外的东西,让它真正准备好迎接聚光灯 ✨

我们可以添加一些示例,用户可以通过将其添加到gr.Interface中来轻松尝试:

gr.Interface(
    # ...
    # keep everything as it is, and then add
    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
).launch(cache_examples=True) # cache_examples is optional

examples 参数接受一个列表的列表,其中子列表中的每个项目都按照我们列出的inputs的顺序排列。所以在我们的例子中,[seed, num_punks]。试试看吧!

你也可以尝试向gr.Interface添加titledescriptionarticle。这些参数都接受字符串,所以试试看会发生什么 👀 article 也接受 HTML,正如在之前的指南中探讨的

当你全部完成后,你可能会得到像this这样的东西。

作为参考,这是我们完整的代码:

import torch
from torch import nn
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image
import gradio as gr

class Generator(nn.Module):
    # Refer to the link below for explanations about nc, nz, and ngf
    # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs
    def __init__(self, nc=4, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        output = self.network(input)
        return output

model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available

def predict(seed, num_punks):
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
        gr.Slider(4, 64, label='Number of Punks', step=1, default=10),
    ],
    outputs="image",
    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
).launch(cache_examples=True)

恭喜!你已经构建了自己的基于GAN的CryptoPunks生成器,并配备了一个华丽的Gradio界面,使任何人都能轻松使用。现在你可以在Hub上搜索更多GAN(或训练你自己的),并继续制作更多令人惊叹的演示 🤗