蒸馏

介绍

ModelOpt的蒸馏API(modelopt.torch.distill)允许您通过最少的脚本修改启用知识蒸馏训练管道。

按照以下步骤操作,使用modelopt.torch.distill获取通过从更强大的教师模型直接转移知识训练的模型:

  1. 通过 mtd.convert 转换您的模型: 将教师模型和学生模型包装成一个更大的元模型,该元模型抽象了两者之间的交互。

  2. 蒸馏训练:无缝使用元模型代替原始模型,并只需添加一行代码来计算损失即可运行原始脚本。

  3. 检查点并重新加载:通过mto.save保存模型,并通过mto.restore恢复。查看保存和恢复以了解更多信息。

要了解更多关于蒸馏及相关概念的信息,请参考以下部分 蒸馏概念

转换和集成

你可以使用mtd.convert()将你的模型转换为DistillationModel

示例用法:

import modelopt.torch.distill as mtd
from torchvision.models import resnet50

# User-defined model (student)
model = resnet50()

# Configure and convert for distillation
distillation_config = {
    # `teacher_model` is a model class or callable, or a tuple.
    # If a tuple, it must be of the form (model_cls_or_callable,) or
    # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
    "teacher_model": teacher_model,
    "criterion": mtd.LogitsDistillationLoss(),
    "loss_balancer": mtd.StaticLossBalancer(),
}
distillation_model = mtd.convert(model, mode=[("kd_loss", distillation_config)])

# Export model in original class form
model_exported = mtd.export(distillation_model)

注意

配置需要一个(非lambda)可调用对象来返回一个教师模型,而不是模型本身。这是为了避免在保存蒸馏元模型时重新保存教师状态字典。因此,在通过mto.restore工具恢复时,命名空间中必须提供相同的可调用对象。

注意

由于模型不再属于同一类,转换后在模型上调用type()将不会按预期工作。 尽管isinstance()仍然有效,因为模型动态地成为原始类的子类。

蒸馏概念

下面,我们将概述ModelOpt的蒸馏功能及其基本概念和术语。

概述

Glossary

知识蒸馏

将可学习的特征信息从教师模型传递到学生模型。

Student

要训练的模型(可以从头开始或使用预训练模型)。

教师

固定的、预训练的模型用作学生将“学习”的示例。

蒸馏损失

用于在学生和教师的特征之间执行知识蒸馏的损失函数,与学生的原始任务损失分开。

损失平衡器

一个实用程序的实现,用于确定如何将蒸馏损失和原始学生任务损失结合成一个单一的标量。

Soft-label Distillation

在教师模型和学生模型的输出logits之间执行知识蒸馏的具体过程。

概念

知识蒸馏

蒸馏可以是一个更广泛的术语,用于定义模型之间压缩的任何类型的信息, 但在这种情况下,我们指的是基本的师生知识蒸馏。该过程在已经训练好的模型(教师) 和未训练的模型(学生)之间创建一个辅助损失(或可以替代原始损失), 希望学生能够学习教师已经掌握的信息(即特征图或逻辑)。这可以服务于多种目的:

A. 模型大小缩减:一个更小、更高效的学生模型(可能是经过剪枝的教师模型)达到或超过更大、更慢的教师模型的准确率。(参见彩票假设了解其背后的原因,这也适用于剪枝)

B. 纯训练的替代方案:从现有模型中蒸馏出一个模型(然后进行微调)通常比从头开始训练更快。

C. 模块替换:可以在模型中将单个模块替换为更高效的模块,并对其原始输出使用蒸馏方法,以有效地将其重新整合到整个模型中。

学生

这是我们最终希望训练和使用的模型。它理想地满足了所需的架构和计算要求,但要么未经训练,要么需要提高准确性。

教师

这是从中学习到的特征/信息用于创建学生损失的模型。 通常它比期望的要大和/或慢,但具有令人满意的准确性。

蒸馏损失

为了真正“传递”知识从教师到学生,我们需要在学生原有的损失函数中添加(或替换)一个优化目标。这可以简单到在教师和学生之间对两个相同大小的激活张量实施均方误差(MSE),假设教师学习到的特征是高质量的,应尽可能模仿。

ModelOpt 支持为每个层输出对指定不同的损失函数,并包含一些预定义的函数供使用,尽管用户通常需要定义自己的函数。模块对到损失函数的映射通过配置字典的 criterion 键指定 - 分别是学生和教师 - 并且损失函数本身也应接受相同顺序的输出:

# Example using pairwise-mapped criterion.
# Will perform the loss on the output of ``student_model.classifier`` and ``teacher_model.layers.18``
distillation_config = {
    "teacher_model": teacher_model,
    "criterion": {("classifier", "layers.18"): mtd.LogitsDistillationLoss()},
}
distillation_model = atd.convert(student_model, mode=[("kd_loss", distillation_config)])

损失的中间输出由 DistillationModel捕获,然后使用 DistillationModel.compute_kd_loss()调用损失。 如果存在,原始学生的非蒸馏损失将作为参数传递。

编写自定义损失函数通常是必要的,特别是为了处理需要处理以获得logits和激活的输出。

损失平衡器

由于蒸馏损失可能应用于多对层,损失以字典的形式返回,应将其减少为标量值以进行反向传播。损失平衡器(其接口由DistillationLossBalancer定义)用于实现此目的。

如果蒸馏损失仅应用于单层输出对,并且没有学生损失可用,则不应提供损失平衡器。

ModelOpt 提供了一个简单的 Balancer 实现,上述接口可用于创建自定义的 Balancer。

软标签蒸馏

仅对学生/教师分类模型的输出logits进行蒸馏的场景被称为软标签蒸馏。在这种情况下,如果教师的输出完全优于任何真实标签,甚至可以完全省略学生的原始分类损失。