This tutorial is available as a Jupyter notebook! Download the notebook from here.

ResNet20 on CIFAR-10: Pruning

Tip

This notebook takes about 3 hours on Google Colab, but can be finished in under 1 hour with a better GPU. Depending on the system you run this notebook on, you may get slightly different accuracies than those reported below. The purpose of this notebook is to demonstrate the pruning workflow with Model Optimizer, not to achieve the best possible accuracy.

In this tutorial, we will use Model Optimizer to make a ResNet model faster for our target deployment constraints using pruning, without sacrificing much accuracy!

At the end of this tutorial, you will:

  • Understand how to use Model Optimizer to prune a user-provided model and find the best subnet architecture that fits your target deployment constraints.

  • Save and restore your pruned model for downstream tasks such as fine-tuning and inference.

All of this with just a few lines of code! Yes, it's that easy!

First, install Model Optimizer following the installation steps.

[ ]:
%pip install "nvidia-modelopt[torch]" --extra-index-url https://pypi.nvidia.com
[2]:
import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models.resnet import BasicBlock
from tqdm.auto import tqdm

seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda")

CIFAR-10 Image Classification Dataset

In this tutorial, we will use the well-known CIFAR-10 dataset for image classification. The dataset consists of 60k 32x32 images from 10 classes, split into 50k training images and 10k test images. We will also randomly sample 5k images from the training set to serve as our validation set.

[3]:
def get_cifar10_dataloaders(train_batch_size: int):
    """Return Train-Val-Test data loaders for the CIFAR-10 dataset."""
    np.random.seed(seed)

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])

    # Split Train dataset into Train-Val datasets
    train_dataset = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                normalize,
            ]
        ),
        download=True,
    )

    n_trainval = len(train_dataset)
    n_train = int(n_trainval * 0.9)
    ids = np.arange(n_trainval)
    np.random.shuffle(ids)
    train_ids, val_ids = ids[:n_train], ids[n_train:]

    train_dataset.data = train_dataset.data[train_ids]
    train_dataset.targets = np.array(train_dataset.targets)[train_ids]

    val_dataset = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
        download=True,
    )
    val_dataset.data = val_dataset.data[val_ids]
    val_dataset.targets = np.array(val_dataset.targets)[val_ids]

    test_dataset = torchvision.datasets.CIFAR10(
        root="./data",
        train=False,
        transform=val_dataset.transform,
        download=True,
    )

    num_workers = min(8, os.cpu_count())
    train_loader = torch.utils.data.DataLoader(
        train_dataset, train_batch_size, num_workers=num_workers, pin_memory=True, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1024, num_workers=num_workers, pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1024, num_workers=num_workers, pin_memory=True
    )
    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    return train_loader, val_loader, test_loader
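
As a quick sanity check (an illustrative snippet that is not part of the original workflow), you can fetch a single batch to confirm the shapes the loaders produce:

[ ]:
# Illustrative only: inspect one batch from the training loader.
# The loaders used for training are re-created later with the actual batch size.
_train_loader, _, _ = get_cifar10_dataloaders(train_batch_size=128)
imgs, labels = next(iter(_train_loader))
print(imgs.shape, labels.shape)  # expected: torch.Size([128, 3, 32, 32]) torch.Size([128])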

ResNets for the CIFAR dataset

We will use the ResNet variants designed for the CIFAR dataset, namely ResNet-20 and ResNet-32, since these are very small models to train. You can find more details about these models in this paper. Below is a regular PyTorch model with nothing new added yet.

Setting up the model

We first set up the model and add some helper functions for training.

[4]:
def _weights_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        torch.nn.init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class ResNet(nn.Module):
    def __init__(self, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 16
        self.layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            self._make_layer(16, num_blocks, stride=1),
            self._make_layer(32, num_blocks, stride=2),
            self._make_layer(64, num_blocks, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, num_classes),
        )
        self.apply(_weights_init)

    def _make_layer(self, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            downsample = None
            if stride != 1 or self.in_planes != planes:
                downsample = LambdaLayer(
                    lambda x: F.pad(
                        x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0
                    )
                )
            layers.append(BasicBlock(self.in_planes, planes, stride, downsample))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


def resnet20(ckpt=None):
    model = ResNet(num_blocks=3).to(device)
    if ckpt is not None:
        model.load_state_dict(torch.load(ckpt, device))
    return model


def resnet32(ckpt=None):
    model = ResNet(num_blocks=5).to(device)
    if ckpt is not None:
        model.load_state_dict(torch.load(ckpt, device))
    return model
[5]:
class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        decay_steps: int,
        warmup_lr: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        self.warmup_steps = warmup_steps
        self.warmup_lr = warmup_lr
        self.decay_steps = decay_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return [
                (base_lr - self.warmup_lr) * self.last_epoch / self.warmup_steps + self.warmup_lr
                for base_lr in self.base_lrs
            ]
        else:
            current_steps = self.last_epoch - self.warmup_steps
            return [
                0.5 * base_lr * (1 + math.cos(math.pi * current_steps / self.decay_steps))
                for base_lr in self.base_lrs
            ]


def get_optimizer_scheduler(model, lr, weight_decay, warmup_steps, decay_steps):
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr,
        momentum=0.9,
        weight_decay=weight_decay,
    )
    lr_scheduler = CosineLRwithWarmup(optimizer, warmup_steps, decay_steps)
    return optimizer, lr_scheduler


def train_one_epoch(model, train_loader, loss_fn, optimizer, lr_scheduler):
    """Train the given model for 1 epoch."""
    model.train()
    epoch_loss = 0.0
    for imgs, labels in train_loader:
        output = model(imgs.to(device))
        loss = loss_fn(model, output, labels.to(device))
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    epoch_loss /= len(train_loader.dataset)
    return epoch_loss


@torch.no_grad()
def evaluate(model, test_loader):
    """Evaluate the model on the given test_loader and return accuracy percentage."""
    model.eval()
    correct = total = 0.0
    for imgs, labels in test_loader:
        output = model(imgs.to(device))
        predicted = output.argmax(dim=1).detach().cpu()
        correct += torch.sum(labels == predicted).item()
        total += len(labels)

    accuracy = 100 * correct / total
    return accuracy


def loss_fn_default(model, output, labels):
    return F.cross_entropy(output, labels)


def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    num_epochs,
    loss_fn=loss_fn_default,
    print_freq=25,
    ckpt_path="temp_saved_model.pth",
):
    """Train the given model with provided parameters.

    loss_fn: function that takes model, output, labels and returns loss. This allows us to obtain the loss
        from the model as well if needed.
    """
    best_val_acc, best_ep = 0.0, 0
    print(f"Training the model for {num_epochs} epochs...")
    for ep in tqdm(range(1, num_epochs + 1)):
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, lr_scheduler)

        val_acc = evaluate(model, val_loader)
        if val_acc >= best_val_acc:
            best_val_acc, best_ep = val_acc, ep
            torch.save(model.state_dict(), ckpt_path)

        if ep == 1 or ep % print_freq == 0 or ep == num_epochs:
            print(f"Epoch {ep:3d}\t Training loss: {train_loss:.4f}\t Val Accuracy: {val_acc:.2f}%")

    print(
        f"Model Trained! Restoring to parameters that gave best Val Accuracy ({best_val_acc:.2f}% at Epoch {best_ep})."
    )
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
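
Before training, it can be reassuring to see the warmup-plus-cosine schedule behave as expected. Below is a minimal sketch (illustrative only, using a dummy parameter; it is not part of the training pipeline) that prints the learning rate at a few milestones:

[ ]:
# Illustrative check of CosineLRwithWarmup: the lr ramps up linearly for
# 5 steps, peaks at the base lr, then decays along a cosine curve.
_p = nn.Parameter(torch.zeros(1))
_opt = torch.optim.SGD([_p], lr=1.0)
_sched = CosineLRwithWarmup(_opt, warmup_steps=5, decay_steps=20)
for step in range(25):
    _opt.step()
    _sched.step()
    if step in (0, 4, 12, 24):
        print(f"step {step:2d}: lr = {_sched.get_last_lr()[0]:.3f}")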

You can uncomment the print statement below to see the details of the ResNet20 model.

[6]:
# print(resnet20())
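
You can also count the trainable parameters directly. The one-liner below (an added sketch, not from the original notebook) should be in the same ballpark as the params number reported by the profiler later on:

[ ]:
# Illustrative: count the trainable parameters of the baseline ResNet20.
n_params = sum(p.numel() for p in resnet20().parameters() if p.requires_grad)
print(f"ResNet20 trainable parameters: {n_params / 1e3:.2f}K")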

Training the baseline model

Depending on your GPU and CPU, training will take roughly 10-30 minutes. Compared to the original setup described in the paper, we use slightly different training hyperparameters to make training faster for this tutorial.

You can also reduce the num_epochs parameter below to make the whole notebook run faster at the cost of accuracy.

[7]:
batch_size = 512
num_epochs = 120
learning_rate = 0.1 * batch_size / 128
weight_decay = 1e-4

train_loader, val_loader, test_loader = get_cifar10_dataloaders(batch_size)

batch_per_epoch = len(train_loader)
warmup_steps = 5 * batch_per_epoch
decay_steps = num_epochs * batch_per_epoch
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Train: 45000, Val: 5000, Test: 10000
[8]:
resnet20_model = resnet20()
optimizer, lr_scheduler = get_optimizer_scheduler(
    resnet20_model, learning_rate, weight_decay, warmup_steps, decay_steps
)
train_model(
    resnet20_model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    num_epochs,
    ckpt_path="resnet20.pth",
)
print(f"Test Accuracy of ResNet20: {evaluate(resnet20_model, test_loader)}")
Training the model for 120 epochs...
Epoch   1        Training loss: 0.0049   Val Accuracy: 22.82%
Epoch  25        Training loss: 0.0006   Val Accuracy: 78.84%
Epoch  50        Training loss: 0.0004   Val Accuracy: 85.06%
Epoch  75        Training loss: 0.0002   Val Accuracy: 88.12%
Epoch 100        Training loss: 0.0001   Val Accuracy: 90.34%
Epoch 120        Training loss: 0.0000   Val Accuracy: 90.80%
Model Trained! Restoring to parameters that gave best Val Accuracy (90.92% at Epoch 119).
Test Accuracy of ResNet20: 90.97

We now have a baseline model and accuracy to compare against using Model Optimizer.

So far, we have seen an ordinary PyTorch model trained without any new techniques. Now, let's use Model Optimizer to optimize the model for our target constraints!

FastNAS Pruning with Model Optimizer

Model Optimizer's modelopt.torch.prune module provides advanced, state-of-the-art pruning algorithms that enable you to search for the best subnet architecture from your provided base model.

Model Optimizer can be used in one of the following complementary modes to create a search space for optimizing the model:

  1. fastnas: A pruning method recommended for computer vision models. Given a pretrained model, FastNAS finds the subnet that maximizes the score function while meeting the given constraints.

  2. mcore_gpt_minitron: A pruning method developed by NVIDIA Research for pruning GPT-style models in the NVIDIA NeMo or Megatron-LM framework that use pipeline parallelism. It uses activation magnitudes to prune the MLP, attention heads, and GQA query groups.

  3. gradnas: A lightweight pruning method recommended for language models such as Hugging Face BERT and GPT-J. It uses gradient information to prune the model's linear layers and attention heads to meet the given constraints.

In this example, we will use the fastnas mode to prune a ResNet20 model for the CIFAR-10 dataset. Check out the Model Optimizer GitHub repository for more examples.

First, we use the FastNAS mode to convert the ResNet model and reduce its FLOPs, parameter count, and latency.

[9]:
import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp

Prune the base model and store the pruned network

With mtp.prune, you can:

  • generate a search space for pruning from your base model;

  • prune the model;

  • obtain a valid PyTorch model that can be used for fine-tuning.

Suppose we have the ResNet20 model as our base model for pruning and we are looking for a model with at most 30M FLOPs. We can provide an upper-bound search constraint for flops and/or params. These values can be absolute numbers (e.g. 30e6) or string percentages (e.g. "75%"). In addition, we should provide our training data loader to mtp.prune; it will be used to calibrate the normalization layers in the model. We will also specify a custom config to configure the pruning search space for a more fine-grained choice of pruned networks.

Finally, we can use mto.save to store the pruned architecture and weights.

Note

We are optimizing a relatively small model here, for which a more fine-grained search can be more effective. This is why we specify a custom config. In general, however, it is recommended to convert the model using the default config.
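
For reference, here is a hedged sketch of the same call with a relative constraint instead of an absolute one (kept commented out since it is not executed in this tutorial; all other arguments are as in the cell below):

[ ]:
# Illustrative variant only: prune to at most 75% of the base model's FLOPs.
# pruned_model, _ = mtp.prune(
#     model=resnet20(ckpt="resnet20.pth"),
#     mode=[("fastnas", config)],
#     constraints={"flops": "75%"},  # relative upper bound instead of 30e6
#     dummy_input=dummy_input,
#     config={"data_loader": train_loader, "score_func": score_func},
# )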

[10]:
# config with more fine-grained channel choices for fastnas
config = mtp.config.FastNASConfig()
config["nn.Conv2d"]["*"]["channel_divisor"] = 16
config["nn.BatchNorm2d"]["*"]["feature_divisor"] = 16

# A single 32x32 image for computing FLOPs
dummy_input = torch.randn(1, 3, 32, 32, device=device)


# Wrap your original validation function to only take the model as input.
# This function acts as the score function to rank models.
def score_func(model):
    return evaluate(model, val_loader)


# prune the model
pruned_model, _ = mtp.prune(
    model=resnet20(ckpt="resnet20.pth"),
    mode=[("fastnas", config)],
    constraints={"flops": 30e6},
    dummy_input=dummy_input,
    config={
        "data_loader": train_loader,
        "score_func": score_func,
        "checkpoint": "modelopt_seaarch_checkpoint_fastnas.pth",
    },
)

# save the pruned model for future use
mto.save(pruned_model, "modelopt_pruned_model_fastnas.pth")

# evaluate the pruned model
print(f"Test Accuracy of Pruned ResNet20: {evaluate(pruned_model, test_loader)}")

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------

                              Profiling Results
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Constraint   ┃ min          ┃ centroid     ┃ max          ┃ max/min ratio ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 24.33M       │ 27.57M       │ 40.55M       │ 1.67          │
│ params       │ 90.94K       │ 141.63K      │ 268.35K      │ 2.95          │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘

            Constraints Evaluation
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃              ┃              ┃ Satisfiable  ┃
┃ Constraint   ┃ Upper Bound  ┃ Upper Bound  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│ flops        │ 30.00M       │ True         │
└──────────────┴──────────────┴──────────────┘


Search Space Summary:
----------------------------------------------------------------------------------------------------
  layers.depth                                                                     [9]
  layers.0.out_channels                                                            [16]
  layers.0.in_channels                                                             [3]
  layers.3.depth                                                                   [3]
  layers.3.0.conv1.out_channels                                                    [16]
  layers.3.0.conv1.in_channels                                                     [16]
  layers.3.0.bn1.num_features                                                      [16]
  layers.3.0.conv2.out_channels                                                    [16]
  layers.3.0.conv2.in_channels                                                     [16]
  layers.3.1.conv1.out_channels                                                    [16]
  layers.3.1.conv1.in_channels                                                     [16]
  layers.3.1.bn1.num_features                                                      [16]
  layers.3.1.conv2.out_channels                                                    [16]
  layers.3.1.conv2.in_channels                                                     [16]
  layers.3.2.conv1.out_channels                                                    [16]
  layers.3.2.conv1.in_channels                                                     [16]
  layers.3.2.bn1.num_features                                                      [16]
  layers.3.2.conv2.out_channels                                                    [16]
  layers.3.2.conv2.in_channels                                                     [16]
  layers.4.depth                                                                   [3]
* layers.4.0.conv1.out_channels                                                    [16, 32]
  layers.4.0.conv1.in_channels                                                     [16]
  layers.4.0.bn1.num_features                                                      [16, 32]
  layers.4.0.conv2.out_channels                                                    [32]
  layers.4.0.conv2.in_channels                                                     [16, 32]
* layers.4.1.conv1.out_channels                                                    [16, 32]
  layers.4.1.conv1.in_channels                                                     [32]
  layers.4.1.bn1.num_features                                                      [16, 32]
  layers.4.1.conv2.out_channels                                                    [32]
  layers.4.1.conv2.in_channels                                                     [16, 32]
* layers.4.2.conv1.out_channels                                                    [16, 32]
  layers.4.2.conv1.in_channels                                                     [32]
  layers.4.2.bn1.num_features                                                      [16, 32]
  layers.4.2.conv2.out_channels                                                    [32]
  layers.4.2.conv2.in_channels                                                     [16, 32]
  layers.5.depth                                                                   [3]
* layers.5.0.conv1.out_channels                                                    [16, 32, 48, 64]
  layers.5.0.conv1.in_channels                                                     [32]
  layers.5.0.bn1.num_features                                                      [16, 32, 48, 64]
  layers.5.0.conv2.out_channels                                                    [64]
  layers.5.0.conv2.in_channels                                                     [16, 32, 48, 64]
* layers.5.1.conv1.out_channels                                                    [16, 32, 48, 64]
  layers.5.1.conv1.in_channels                                                     [64]
  layers.5.1.bn1.num_features                                                      [16, 32, 48, 64]
  layers.5.1.conv2.out_channels                                                    [64]
  layers.5.1.conv2.in_channels                                                     [16, 32, 48, 64]
* layers.5.2.conv1.out_channels                                                    [16, 32, 48, 64]
  layers.5.2.conv1.in_channels                                                     [64]
  layers.5.2.bn1.num_features                                                      [16, 32, 48, 64]
  layers.5.2.conv2.out_channels                                                    [64]
  layers.5.2.conv2.in_channels                                                     [16, 32, 48, 64]
----------------------------------------------------------------------------------------------------
Number of configurable hparams: 6
Total size of the search space: 5.12e+02
Note: all constraints can be satisfied within the search space!


Beginning pre-search estimation. If the runtime of score function is longer than a few minutes, consider subsampling the dataset used in score function.
A PyTorch dataset can be subsampled using torch.utils.data.Subset (https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) as following:
 subset_dataset = torch.utils.data.Subset(dataset, indices)
Collecting pre-search statistics: 100%|██████████| 18/18 [00:10<00:00,  1.76it/s, cur=layers.5.2.conv1.out_channels(64/64): 0.00]
[num_satisfied] = 11:   0%|          | 20/10000 [00:02<17:43,  9.39it/s]
[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}

Test Accuracy of Pruned ResNet20: 60.37

As we can see, the best subnet (29.6M FLOPs) meets our 30M FLOPs constraint. We can also see that the validation accuracy of the searched model has dropped. This is very common after pruning, which is why the model needs to be fine-tuned.

Restore the pruned subnet using mto.restore

[11]:
pruned_model = mto.restore(resnet20(), "modelopt_pruned_model_fastnas.pth")
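
As a quick round-trip check (an added snippet, not from the original notebook), the restored model should reproduce the pre-fine-tuning test accuracy reported above:

[ ]:
# Illustrative: the restored subnet should match the earlier result (~60% here).
print(f"Test Accuracy of restored pruned model: {evaluate(pruned_model, test_loader)}")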

Fine-tuning

To fine-tune the subnet, you can simply repeat the training pipeline of the original model (1x training time, 0.5x to 1x the original learning rate). The fine-tuned model is the final model that provides the best trade-off between accuracy and your provided constraints, and it is the one used for deployment.

Note that depending on your GPU and CPU, training may take around 5 to 15 minutes.

[12]:
optimizer, lr_scheduler = get_optimizer_scheduler(
    pruned_model, 0.5 * learning_rate, weight_decay, warmup_steps, decay_steps
)
train_model(
    pruned_model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    num_epochs,
)
# store final model
mto.save(pruned_model, "modelopt_pruned_model_fastnas_finetuned.pth")
Training the model for 120 epochs...
Epoch   1        Training loss: 0.0011   Val Accuracy: 79.90%
Epoch  25        Training loss: 0.0004   Val Accuracy: 82.82%
Epoch  50        Training loss: 0.0003   Val Accuracy: 87.00%
Epoch  75        Training loss: 0.0002   Val Accuracy: 88.62%
Epoch 100        Training loss: 0.0000   Val Accuracy: 90.62%
Epoch 120        Training loss: 0.0000   Val Accuracy: 90.58%
Model Trained! Restoring to parameters that gave best Val Accuracy (90.70% at Epoch 101).

Evaluate the searched subnet

[13]:
# you can restore the fine-tuned model from the vanilla model
optimized_model = mto.restore(resnet20(), "modelopt_pruned_model_fastnas_finetuned.pth")

# test the accuracy
print(f"Test Accuracy of the fine-tuned pruned net: {evaluate(optimized_model, test_loader)}")
Test Accuracy of the fine-tuned pruned net: 90.28

Conclusion

The comparison can be summarized as follows:

Model            FLOPs    Params   Test Accuracy
ResNet20         40.6M    268K     90.9%
FastNAS subnet   29.6M    174K     90.3%

As we can see here, we have reduced the FLOPs and parameter count, which will also improve latency, with very little loss in accuracy. Well done!
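
If you want to verify the latency improvement on your own hardware, a rough timing sketch like the one below can be used (an added example; for rigorous measurements, prefer torch.cuda.Event timers or a deployment runtime such as TensorRT):

[ ]:
import time

@torch.no_grad()
def measure_latency_ms(model, batch_size=32, n_warmup=10, n_iters=100):
    """Rough average GPU latency per batch in milliseconds."""
    model.eval()
    x = torch.randn(batch_size, 3, 32, 32, device=device)
    for _ in range(n_warmup):  # warm up CUDA kernels / allocator
        model(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        model(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters * 1e3

print(f"Baseline ResNet20: {measure_latency_ms(resnet20_model):.2f} ms/batch")
print(f"Pruned subnet:     {measure_latency_ms(optimized_model):.2f} ms/batch")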

Next steps: check out the Model Optimizer GitHub repository for more examples.