快速入门:蒸馏

ModelOpt的蒸馏是一组包装器和实用工具,用于在教师模型和学生模型之间轻松执行知识蒸馏。给定一个预训练的教师模型,Distillation有潜力训练一个较小的学生模型,使其比学生模型自己训练时更快和/或具有更高的准确性。

本快速入门指南展示了将Distillation集成到您的训练管道中的必要步骤。

设置您的基础模型

首先获取一个预训练模型作为教师模型,以及一个(通常较小的)模型作为学生模型。

from torchvision.models import resnet50, resnet18

# Define student
student_model = resnet18()


# Define callable which returns teacher
def teacher_factory():
    teacher_model = resnet50()
    teacher_model.load_state_dict(pretrained_weights)
    return teacher_model

设置元模型

由于知识蒸馏涉及(至少)两个模型,ModelOpt通过将学生和教师模型包装成一个元模型来简化集成过程。

请参见下面的蒸馏设置示例。此示例假设teacher_modelstudent_model的输出为logits。

import modelopt.torch.distill as mtd

distillation_config = {
    "teacher_model": teacher_factory,  # model initializer
    "criterion": mtd.LogitsDistillationLoss(),  # callable receiving student and teacher outputs, in order
    "loss_balancer": mtd.StaticLossBalancer(),  # combines multiple losses; omit if only one distillation loss used
}

distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])

teacher_model 可以是一个返回 nn.Module 的可调用对象,也可以是一个 (model_cls, args, kwargs) 的元组。 criterion 是用于学生和教师张量之间的蒸馏损失。 loss_balancer 决定了原始损失和蒸馏损失如何结合(如果需要)。

查看蒸馏获取更多信息。

训练期间的蒸馏

要从教师模型蒸馏到学生模型,只需在常规训练循环中使用元模型,同时除了原始的用户损失外,还使用元模型的.compute_kd_loss()方法来计算蒸馏损失。

下面给出了蒸馏训练的一个示例:

# Setup the data loaders. As example:
train_loader = get_train_loader()

# Define user loss function. As example:
loss_fn = get_user_loss_fn()

for input, labels in train_dataloader:
    distillation_model.zero_grad()
    # Forward through the wrapped models
    out = distillation_model(input)
    # Same loss as originally present
    loss = loss_fn(out, labels)
    # Combine distillation and user losses
    loss_total = distillation_model.compute_kd_loss(student_loss=loss)
    loss_total.backward()

注意

DataParallel 可能会破坏 ModelOpt 的蒸馏功能。请注意,HuggingFace Trainer 默认使用 DataParallel。

导出训练好的模型

模型可以轻松恢复到其原始类以供进一步使用(例如部署),而无需附加任何ModelOpt修改。

model = mtd.export(distillation_model)

Next steps
  • 了解更多关于蒸馏的信息。

  • 查看 ModelOpt 的 API 文档 以获取详细的功能和使用信息。