Shortcuts

ForwardKLWithChunkedOutputLoss

class torchtune.modules.loss.ForwardKLWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = - 100)[source]

使用分块输出的前向KL,通过一次只上转换一个块来节省内存。

由于模型是用bf16训练的,在计算KL散度之前,我们必须将其转换为fp32以获得更好的准确性和稳定性。当进行转换时,内存使用量会翻倍。像llama3这样的模型具有较大的词汇量,因此输出结果也较大(bsz, num_tokens, vocab_size)。如果我们在token级别进行分块,你仍然可以正常计算交叉熵,但一次只转换一个分块可以节省大量内存。

Parameters:
  • num_output_chunks (int) – 将输出分成的块数。每个块的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)。 默认值:8

  • ignore_index (int) – 指定一个被忽略的目标值,该值不会对输入梯度产生影响。 损失会在非忽略的目标上进行平均。 默认值:-100

forward(student_logits: List[Tensor], teacher_logits: List[Tensor], labels: Tensor) Tensor[source]
Parameters:
  • student_logits (List[torch.Tensor]) – 来自学生模型的分块logits列表,长度为 self.num_output_chunks,其中每个块的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)。

  • teacher_logits (List[torch.Tensor]) – 来自教师模型的分块logits列表,长度为 self.num_output_chunks,其中每个块的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)。

  • labels (torch.Tensor) – 形状为 (batch_size, num_tokens) 的真实标签。

Returns:

形状为 (1,) 的KL散度损失。

Return type:

torch.Tensor

示例

>>> loss_fn = ForwardKLWithChunkedOutputLoss()
>>>
>>> h = torch.tensor([bsz, num_tokens, dim])
>>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>> teacher_chunks = [teacher_model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>> labels = torch.tensor([bsz, num_tokens])
>>> loss = loss_fn(output_chunks, teacher_chunks, labels)