• Docs >
  • Distilling Llama3.1 8B into Llama3.2 1B using Knowledge Distillation
Shortcuts

使用知识蒸馏将Llama3.1 8B提炼为Llama3.2 1B

本指南将教你关于知识蒸馏(KD)的知识,并展示如何使用torchtune将Llama3.1 8B模型蒸馏为Llama3.2 1B模型。 如果你已经了解知识蒸馏是什么,并想直接开始在torchtune中运行你自己的蒸馏, 你可以跳转到torchtune中的KD配方教程。

What you will learn
  • 什么是KD以及它如何帮助提高模型性能

  • torchtune 中 KD 组件的概述

  • 如何使用torchtune从教师模型蒸馏到学生模型

  • 如何尝试不同的KD配置

Prerequisites

什么是知识蒸馏?

Knowledge Distillation 是一种广泛使用的压缩技术,它将知识从较大的(教师)模型转移到较小的(学生)模型。较大的模型具有更多的参数和知识容量,然而,这种较大的容量在部署时也更具计算成本。知识蒸馏可以用于将较大模型的知识压缩到较小的模型中。其思想是通过学习较大模型的输出,可以提高较小模型的性能。

知识蒸馏是如何工作的?

知识通过在一个转移集上训练学生模型来从教师模型传递给学生模型,其中学生被训练以模仿教师的标记级概率分布。下图是知识蒸馏(KD)工作原理的简化表示。

../_images/kd-simplified.png

总损失可以通过多种方式配置。torchtune 中的默认 KD 配置将交叉熵(CE)损失与前向 Kullback-Leibler (KL) 散度损失结合使用,这是标准 KD 方法中使用的。前向 KL 散度旨在通过强制学生的分布与教师的所有分布对齐来最小化差异。然而,将学生分布与整个教师分布对齐可能并不有效,并且有多篇论文,如 MiniLLMDistiLLMGeneralized KD,引入了新的 KD 损失来解决这些限制。在本教程中,让我们来看看前向 KL 散度损失的实现。

import torch
import torch.nn.functional as F

class ForwardKLLoss(torch.nn.Module):
  def __init__(self, ignore_index: int = -100)
    super().__init__()
    self.ignore_index = ignore_index

  def forward(self, student_logits, teacher_logits, labels) -> torch.Tensor:
    # Implementation from https://github.com/jongwooko/distillm
    # Computes the softmax of the teacher logits
    teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    # Computes the student log softmax probabilities
    student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    # Computes the forward KL divergence
    prod_probs = teacher_prob * student_logprob
    # Compute the sum
    x = torch.sum(prod_probs, dim=-1).view(-1)
    # We don't want to include the ignore labels in the average
    mask = (labels != self.ignore_index).int()
    # Loss is averaged over non-ignored targets
    return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

为了简化计算,这里省略了一些细节,但如果您想了解更多,可以在ForwardKLLoss中查看实现。默认情况下,KD配置使用ForwardKLWithChunkedOutputLoss来减少内存。当前的实现仅支持具有相同输出logit形状和相同分词器的学生和教师模型。

torchtune中的KD配方

使用torchtune,我们可以轻松地将知识蒸馏应用于Llama3以及其他LLM模型系列。 让我们来看看如何使用torchtune的KD配方来蒸馏模型。

首先,确保您已下载所有模型权重。对于此示例,我们将使用Llama3.1-8B作为教师模型,Llama3.2-1B作为学生模型。

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

然后,我们将使用LoRA对教师模型进行微调。根据我们的实验和之前的工作,我们发现当教师模型已经在目标数据集上进行了微调时,知识蒸馏(KD)表现更好。

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

最后,我们可以在单个GPU上运行以下命令,将微调的8B模型蒸馏为1B模型。

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

消融研究

在前面的例子中,我们使用了LoRA微调的8B教师模型和基线1B学生模型, 但我们可能希望尝试不同的配置和超参数。 在本教程中,我们将在alpaca_cleaned_dataset上进行微调, 并通过EleutherAI的LM评估工具truthfulqa_mc2hellaswagcommonsense_qa任务上评估模型。 让我们来看看以下因素的影响:

  1. 使用微调的教师模型

  2. 使用微调的学生模型

  3. kd_ratio和学习率的超参数调优

  4. 教师和学生模型的参数数量更接近

使用微调的教师模型

配置中的默认设置使用了微调的教师模型。现在,让我们先看看不先微调教师模型的效果。要更改教师模型,您可以修改配置中的teacher_checkpointer

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
      model-00001-of-00004.safetensors,
      model-00002-of-00004.safetensors,
      model-00003-of-00004.safetensors,
      model-00004-of-00004.safetensors
  ]

在下表中,我们可以看到,1B模型的标准微调比基线1B模型实现了更好的准确性。通过使用微调的8B教师模型,我们看到truthfulqa的结果相当,hellaswag和commonsense有所改进。当使用基线8B作为教师时,我们看到所有指标都有所改进,但低于其他配置。

../_images/kd-finetune-teacher.png

观察损失情况,使用基线8B作为教师模型导致的损失高于使用微调后的教师模型。KD损失也保持相对稳定,这表明教师模型应与迁移数据集具有相同的分布。

使用微调的学生模型

对于这些实验,让我们看看当学生模型已经微调时,KD的效果。在这些实验中,我们观察了基线模型和微调的8B和1B模型的不同组合。要更改学生模型,您可以首先微调1B模型,然后在配置中修改学生模型的检查点:

checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
   checkpoint_files: [
     hf_model_0001_0.pt
   ]

使用微调后的学生模型进一步提高了truthfulqa的准确性,但hellaswag和commonsense的准确性有所下降。使用微调的教师模型和基线学生模型在hellaswag和commonsense数据集上取得了最佳结果。基于这些发现,最佳配置将根据您优化的评估数据集和指标而变化。

../_images/kd-finetune-student.png

根据损失图,使用微调的教师模型会导致较低的损失,无论学生模型是否经过微调。同样有趣的是,当使用微调的学生模型时,类别损失开始增加。

超参数调优:学习率

默认情况下,配置中的学习率为\(3e^{-4}\),与LoRA配置相同。在这些实验中,我们将学习率从高达\(1e^{-3}\)调整到低至\(1e^{-5}\)。要更改学习率,您可以简单地覆盖学习率参数,使用:

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=1e-3

根据结果,最佳学习率会根据您优化的指标而变化。

../_images/kd-hyperparam-lr.png

根据损失图,除了\(1e^{-5}\)之外,所有学习率都导致相似的损失,而\(1e^{-5}\)具有更高的KD和类别损失。

超参数调优:KD 比率

在配置中,我们将kd_ratio设置为0.5,这为类别损失和KD损失赋予了相同的权重。在这些实验中,我们观察了不同KD比率的影响,其中0仅使用类别损失,1仅使用KD损失。类似于改变学习率,KD比率可以通过以下方式调整:

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_ratio=0.25

总体而言,评估结果在较高的KD比率下略好。

../_images/kd-hyperparam-kd-ratio.png

Qwen2 1.5B 到 0.5B

KD配方也可以应用于不同的模型系列。在这里,我们研究了当教师模型和学生模型之间的参数数量更接近时KD的效果。在这个实验中,我们使用了Qwen2 1.5B和Qwen2 0.5B,其配置可以在qwen2/knowledge_distillation_single_device配置中找到。在这里,我们看到在alpaca清理数据集上的训练仅提高了truthful_qa的性能,并降低了其他评估任务的指标。对于truthful_qa,KD将学生模型的性能提高了5.8%,而微调将性能提高了1.3%。

../_images/kd-qwen2-res.png