快速入门:剪枝

提示

查看CIFAR-10 上的 ResNet20 笔记本HF BERT 剪枝、蒸馏与量化 以获取一个端到端的剪枝示例。

ModelOpt的剪枝库提供了许多轻量级的剪枝方法,如Minitron、FastNAS和GradNAS,这些方法可以在任何用户提供的模型上运行。查看此文档以获取有关这些剪枝方法的更多详细信息以及何时使用哪种剪枝方法的建议。

修剪预训练模型涉及三个步骤,分别是设置模型、设置搜索,最后运行搜索(修剪)。

设置您的模型

要设置您的模型进行剪枝,只需初始化模型并加载预训练的检查点。 或者,您也可以从头开始训练模型。

修剪模型

要修剪你的模型,你可以简单地调用mtp.prune API并使用mto.save保存修剪后的模型。

下面给出了一个FastNAS剪枝的示例:

import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp

# prune_res (dict) contains state_dict / stats of the pruner/searcher.
pruned_model, prune_res = mtp.prune(
    model=model,
    mode="fastnas",
    constraints=prune_constraints,
    dummy_input=dummy_input,
    config={
        "data_loader": train_loader,  # training data is used for calibrating BN layers
        "score_func": score_func,  # validation score is used to rank the subnets
        # checkpoint to store the search state and resume or re-run the search with different constraint
        "checkpoint": "modelopt_fastnas_search_checkpoint.pth",
    },
)

# Save the pruned model.
mto.save(pruned_model, "modelopt_pruned_model.pth")

注意

剪枝后需要进行微调以恢复准确性。 请参阅剪枝微调了解更多详情。


Next steps
  • 了解更多关于剪枝 API和支持的算法/模型。

  • 了解更多关于NAS的信息,它是剪枝的泛化。

  • 查看 ModelOpt API 文档 以获取详细的功能和使用信息。