快速入门:剪枝
提示
查看CIFAR-10 上的 ResNet20 笔记本和 HF BERT 剪枝、蒸馏与量化 以获取一个端到端的剪枝示例。
ModelOpt的剪枝库提供了许多轻量级的剪枝方法,如Minitron、FastNAS和GradNAS,这些方法可以在任何用户提供的模型上运行。查看此文档以获取有关这些剪枝方法的更多详细信息以及何时使用哪种剪枝方法的建议。
修剪预训练模型涉及三个步骤,分别是设置模型、设置搜索,最后运行搜索(修剪)。
设置您的模型
要设置您的模型进行剪枝,只需初始化模型并加载预训练的检查点。 或者,您也可以从头开始训练模型。
设置搜索
设置搜索以进行剪枝涉及设置训练和验证数据加载器,并可选择定义评分函数(FastNAS)或损失函数(GradNAS)并指定所需的剪枝约束。最常见的评分函数是模型的验证准确率,用于对从搜索空间采样的子网络进行排名。损失函数用于在模型上运行一些前向和后向传递以获取模型的梯度。
请参见下面FastNAS剪枝设置的示例。
import modelopt.torch.prune as mtp
from torchvision.models import resnet50
# User-defined model
model = resnet50()
# Load pretrained weights here
model.load_state_dict(pretrained_weights)
# Wrap your original validation function to only take the model as input.
# This function acts as the score function to rank models.
def score_func(model):
return validate(model, val_loader, ...)
# Define a dummy input with similar shape as that of your input data
dummy_input = torch.randn(1, 3, 224, 244)
# Prune the model to at most 60% of the original FLOPs
prune_constraints = {"flops": "60%"}
修剪模型
要修剪你的模型,你可以简单地调用mtp.prune
API并使用mto.save保存修剪后的模型。
下面给出了一个FastNAS剪枝的示例:
import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp
# prune_res (dict) contains state_dict / stats of the pruner/searcher.
pruned_model, prune_res = mtp.prune(
model=model,
mode="fastnas",
constraints=prune_constraints,
dummy_input=dummy_input,
config={
"data_loader": train_loader, # training data is used for calibrating BN layers
"score_func": score_func, # validation score is used to rank the subnets
# checkpoint to store the search state and resume or re-run the search with different constraint
"checkpoint": "modelopt_fastnas_search_checkpoint.pth",
},
)
# Save the pruned model.
mto.save(pruned_model, "modelopt_pruned_model.pth")
注意
剪枝后需要进行微调以恢复准确性。 请参阅剪枝微调了解更多详情。