快速入门:蒸馏
ModelOpt的蒸馏是一组包装器和实用工具,用于在教师模型和学生模型之间轻松执行知识蒸馏。给定一个预训练的教师模型,Distillation有潜力训练一个较小的学生模型,使其比学生模型自己训练时更快和/或具有更高的准确性。
本快速入门指南展示了将Distillation集成到您的训练管道中的必要步骤。
设置您的基础模型
首先获取一个预训练模型作为教师模型,以及一个(通常较小的)模型作为学生模型。
from torchvision.models import resnet50, resnet18
# Define student
student_model = resnet18()
# Define callable which returns teacher
def teacher_factory():
teacher_model = resnet50()
teacher_model.load_state_dict(pretrained_weights)
return teacher_model
设置元模型
由于知识蒸馏涉及(至少)两个模型,ModelOpt通过将学生和教师模型包装成一个元模型来简化集成过程。
请参见下面的蒸馏设置示例。此示例假设teacher_model和student_model的输出为logits。
import modelopt.torch.distill as mtd
distillation_config = {
"teacher_model": teacher_factory, # model initializer
"criterion": mtd.LogitsDistillationLoss(), # callable receiving student and teacher outputs, in order
"loss_balancer": mtd.StaticLossBalancer(), # combines multiple losses; omit if only one distillation loss used
}
distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
teacher_model 可以是一个返回 nn.Module 的可调用对象,也可以是一个 (model_cls, args, kwargs) 的元组。
criterion 是用于学生和教师张量之间的蒸馏损失。
loss_balancer 决定了原始损失和蒸馏损失如何结合(如果需要)。
查看蒸馏获取更多信息。
训练期间的蒸馏
要从教师模型蒸馏到学生模型,只需在常规训练循环中使用元模型,同时除了原始的用户损失外,还使用元模型的.compute_kd_loss()方法来计算蒸馏损失。
下面给出了蒸馏训练的一个示例:
# Setup the data loaders. As example:
train_loader = get_train_loader()
# Define user loss function. As example:
loss_fn = get_user_loss_fn()
for input, labels in train_dataloader:
distillation_model.zero_grad()
# Forward through the wrapped models
out = distillation_model(input)
# Same loss as originally present
loss = loss_fn(out, labels)
# Combine distillation and user losses
loss_total = distillation_model.compute_kd_loss(student_loss=loss)
loss_total.backward()
注意
DataParallel 可能会破坏 ModelOpt 的蒸馏功能。请注意,HuggingFace Trainer 默认使用 DataParallel。
导出训练好的模型
模型可以轻松恢复到其原始类以供进一步使用(例如部署),而无需附加任何ModelOpt修改。
model = mtd.export(distillation_model)