剪枝
提示
查看CIFAR-10 上的 ResNet20 笔记本和 HF BERT 剪枝、蒸馏与量化 以获取一个端到端的剪枝示例。
ModelOpt 提供了三种主要的剪枝方法(也称为 mode)—— Minitron、FastNAS 和 GradNAS —— 通过统一的 API
mtp.prune。给定一个模型,
这些方法从您提供的基础模型中找到一个满足给定部署约束(例如 FLOPs、参数)的子网络,
并且几乎没有精度损失(取决于剪枝的激进程度)。
这些剪枝方法支持剪枝卷积层和线性层,以及模型的注意力头。关于这些剪枝模式的更多细节如下:
fastnas: 一种推荐用于计算机视觉模型的剪枝方法。给定一个预训练模型,FastNAS 找到在满足给定约束的同时最大化评分函数的子网络。mcore_gpt_minitron: 由NVIDIA Research开发的一种剪枝方法,用于在NVIDIA NeMo或Megatron-LM框架中使用管道并行性的GPT风格模型(例如Llama 3)。它使用激活幅度来剪枝模型的mlp、注意力头、GQA查询组、嵌入隐藏大小和层数。查看论文中算法的更多详细信息。gradnas: 一种轻量级的剪枝方法,推荐用于像Hugging Face BERT和GPT-J这样的语言模型。 它使用梯度信息来剪枝模型的线性层和注意力头,以满足给定的约束条件。
按照以下步骤使用mtp获取满足您要求的最佳模型:
训练: 只需使用现有的训练流程训练您的模型,或为您的模型加载预训练的检查点。
剪枝: 使用我们提供的
mtp.pruneAPI 对模型进行剪枝,并获取描述剪枝后网络架构的最优子网。微调: 微调生成的子网络以恢复准确性。
要了解更多关于NAS和剪枝背后的概念,请参考 NAS 概念。
培训
要进行剪枝,您可以使用通过转换预训练检查点模型获得的模型,或者从头开始训练模型。
只需初始化您的模型并在开始使用ModelOpt之前加载检查点。
你可以简单地使用现有的训练管道来训练模型,无需进一步修改。
剪枝和子网搜索
剪枝的下一步是对潜在的子网架构进行搜索,即剪枝网络,以找到满足您部署约束的最佳子网。
先决条件
要对训练好的模型进行剪枝(
mtp.prune()),你需要设置数据加载器,提供搜索constraints和一个dummy_input(用于测量你的部署约束)。你可以为
flops或params提供一个搜索约束,通过指定绝对数值(3e-6)或百分比("60%")的上限。您还应该指定您希望使用的剪枝算法(
mode)。根据模式的不同,您需要提供额外的config参数,如score_func(fastnas模式)或loss_func(gradnas模式)、dataloader、checkpoint等。最常见的评分函数是模型的验证准确率,用于对从搜索空间中采样的子网络进行排名。损失函数用于在训练数据加载器上运行一些前向和后向传递以获取梯度。请参阅
mtp.prune()的API参考以获取更多详细信息。
下面我们展示一个使用"fastnas"的示例。
执行剪枝
import modelopt.torch.prune as mtp
from torchvision.models import resnet50
# User-defined model
model = resnet50()
# Load pretrained weights here
model.load_state_dict(pretrained_weights)
# Wrap your original validation function to only take the model as input.
# This function acts as the score function to rank models.
def score_func(model):
return validate(model, val_loader, ...)
# Define a dummy input with similar shape as that of your input data
dummy_input = torch.randn(1, 3, 224, 244)
# Prune to a model with less than or equal to 60% of original FLOPs
prune_constraints = {"flops": "60%"}
# prune_res (dict) contains state_dict / stats of the pruner/searcher.
pruned_model, prune_res = mtp.prune(
model=model,
mode="fastnas",
constraints=prune_constraints,
dummy_input=dummy_input,
config={
"data_loader": train_loader, # training data is used for calibrating BN layers
"score_func": score_func, # validation score is used to rank the subnets
# checkpoint to store the search state and resume or re-run the search with different constraint
"checkpoint": "modelopt_fastnas_search_checkpoint.pth",
},
)
请注意,在剪枝过程中,我们首先将模型转换为包含不同可能网络配置的搜索空间,然后搜索最佳配置。
提示
如果评分函数的运行时间超过几分钟,考虑对评分函数中使用的数据集进行子采样。可以使用 torch.utils.data.Subset 对PyTorch数据集进行子采样,如下所示:
subset_dataset = torch.utils.data.Subset(dataset, indices)
注意
剪枝将就地修改模型。
注意
mtp.prune() 支持通过 PyTorch 中的 DistributedDataParallel 实现分布式数据并行。
目前,API 不支持修剪 PyTorch 完全分片数据并行(FSDP)模型,因此您需要在 CPU 上运行修剪,然后使用 FSDP 进行微调。请注意,GradNAS 比 FastNAS 快得多(因此在 CPU 上也是可行的),并且推荐用于像 BERT、GPT-J 6B 和其他 LLMs 这样的语言模型。
存储修剪结果
为了存储修剪后的模型以供将来使用,您可以使用
mto.save():
import modelopt.torch.opt as mto
mto.save(pruned_model, "modelopt_pruned_model.pth")
注意
请参阅保存和恢复ModelOpt修改的模型以了解所有可用的保存和恢复选项。
自定义剪枝配置
在上面的例子中,我们使用了mtp.prune()的默认模式配置。你可以通过mtp.config.FastNASConfig()查看它。
你也可以指定自定义的模式配置以获得不同的搜索空间。更多信息请参阅mtp.prune()文档。下面显示了一个示例配置:
import modelopt.torch.prune as mtp
# config to restrict the search space to have a Conv2d out channels as multiple of 64
ss_config = mtp.config.FastNASConfig()
ss_config["nn.Conv2d"]["*"]["channel_divisor"] = 64
# run pruning as shown above
mtp.prune(model, mode=[("fastnas", ss_config)], ...)
分析搜索空间并选择约束条件
搜索空间描述了潜在剪枝子网的候选者。您可以在mtp.prune() API中获取有关搜索空间整体统计信息的信息。在剪枝过程开始之前,将打印以下信息:
Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------
Profiling Results
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Constraint ┃ min ┃ centroid ┃ max ┃ max/min ratio ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops │ 274.34M │ 1.28G │ 4.59G │ 16.73 │
│ params │ 2.70M │ 9.75M │ 25.50M │ 9.43 │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
Constraints Evaluation
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ ┃ ┃ Satisfiable ┃
┃ Constraint ┃ Upper Bound ┃ Upper Bound ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│ flops │ 2.75G │ True │
└──────────────┴──────────────┴──────────────┘
Search Space Summary:
----------------------------------------------------------------------------------------------------
* conv1.out_channels [32, 64]
conv1.in_channels [3]
bn1.num_features [32, 64]
* layer1.0.conv1.out_channels [32, 64]
layer1.0.conv1.in_channels [32, 64]
layer1.0.bn1.num_features [32, 64]
* layer1.0.conv2.out_channels [32, 64]
...
...
...
* layer4.2.conv1.out_channels [32, 64, 96, 128, ..., 416, 448, 480, 512]
layer4.2.conv1.in_channels [2048]
layer4.2.bn1.num_features [32, 64, 96, 128, ..., 416, 448, 480, 512]
* layer4.2.conv2.out_channels [32, 64, 96, 128, ..., 416, 448, 480, 512]
layer4.2.conv2.in_channels [32, 64, 96, 128, ..., 416, 448, 480, 512]
layer4.2.bn2.num_features [32, 64, 96, 128, ..., 416, 448, 480, 512]
layer4.2.conv3.out_channels [2048]
layer4.2.conv3.in_channels [32, 64, 96, 128, ..., 416, 448, 480, 512]
----------------------------------------------------------------------------------------------------
Number of configurable hparams: 36
Total size of the search space: 2.48e+28
Note: all constraints can be satisfied within the search space!
分析结果将帮助您理解搜索空间,并提出一个可以迭代的潜在搜索约束。
提示
通常,最大/最小比率大于3的搜索空间是一个良好的搜索空间,具有许多可剪枝的层。 最大/最小比率越高,模型的可剪枝性越强(可能会使FastNAS变慢但效果更好)。
搜索约束的一个良好起点是搜索空间的重心。如果您使用的是flops/params约束,我们强烈建议您首先获取一个剪枝模型,并在目标部署上测量其延迟,然后再微调剪枝模型。根据延迟情况,您可以相应地调整约束。一旦您有一个符合延迟约束的模型,您可以对其进行微调以恢复准确性。如果您无法恢复准确性(可能是因为剪枝过于激进),您可以尝试增加约束并重复此过程。
注意
如果在搜索空间内无法满足约束条件,剪枝将被中断并会引发错误。
微调
架构搜索的最后一步是在您的数据集上微调修剪后的模型。这样您可以确保为修剪后的模型获得最佳性能。
先决条件
要进行微调,您需要一个修剪过的子网,如前一节所述。
您可以重用现有的训练管道。我们建议使用原始的训练计划进行微调:
1x 训练周期(或 1x 下游任务微调),
相同或更小(0.5x-1x)的学习率。
加载修剪后的模型
你可以简单地使用mto.restore()来恢复你修剪过的模型(权重和架构):
import modelopt.torch.opt as mto
from torchvision.models import resnet50
# Build original model
model = resnet50()
# Restore the pruned architecture and weights
pruned_model = mto.restore(model, "modelopt_pruned_model.pth")
运行微调
现在,请继续使用您标准的训练流程和预配置的超参数对剪枝后的子网进行微调。通常一个好的微调计划是重复预训练计划,初始学习率为0.5x-1x。
不要忘记使用mto.save()保存模型。
train(pruned_model)
mto.save(pruned_model, "modelopt_pruned_finetuned_model.pth")
部署
经过修剪和微调的模型现在已准备好用于部署等下游任务。你现在手中的模型应该是满足部署感知搜索约束的最佳神经网络。
import modelopt.torch.opt as mto
from torchvision.models import resnet50
# Build original model
model = resnet50()
model = mto.restore(model, "modelopt_pruned_finetuned_model.pth")
# Continue with downstream tasks like deployment (e.g. TensorRT or TensorRT-LLM)
...
剪枝概念
剪枝是从神经网络中移除冗余组件的过程,针对特定任务。从概念上讲,剪枝与NAS(神经架构搜索)类似,但与NAS相比,剪枝的计算开销较小,但代价是可能找到的架构不如NAS找到的最优架构。大多数API基于相应的NAS API,但进行了调整以反映更简单的工作流程。
具体来说,对于剪枝,我们并不专门训练搜索空间及其所有子网。 相反,我们使用预训练的检查点来近似搜索空间。因此,我们可以跳过 (可能昂贵的)搜索空间训练步骤,直接 搜索子网架构,然后再对 生成的子网进行微调。
注意
如果你想了解更多关于NAS和剪枝背后的概念,请查看 NAS概念,包括NAS和剪枝之间更详细的比较。