Shortcuts

CEWithChunkedOutputLoss

class torchtune.modules.loss.CEWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = - 100)[source]

使用分块输出的交叉熵,通过一次只提升一个块来节省内存。

每当模型使用bf16进行训练时,在运行CE之前,我们必须将其转换为fp32以提高准确性和稳定性。当进行转换时,内存使用量会翻倍。像llama3这样的模型具有较大的词汇量,因此具有形状为(bsz, num_tokens, vocab_size)的较大输出张量。如果我们在令牌级别进行分块,您仍然可以正常计算交叉熵,但一次只转换一个分块可以节省大量内存。

为了获得更好的性能,CE和上转型必须一起编译。 使用此类时,我们建议仅在方法compute_cross_entropy上使用torch.compile()。 如果编译整个类,分块带来的收益将无法实现。

更多详情,请参考:https://github.com/pytorch/torchtune/pull/1390

compute_cross_entropy(logits: Tensor, labels: Tensor, normalize: bool = True) Tensor[source]

将logits上转换为fp32并计算交叉熵损失。

forward(logits: List[Tensor], labels: Tensor) Tensor[source]
Parameters:
  • logits (列表[torch.Tensor]) – 长度为self.num_output_chunks的分块logits列表,其中每个块的形状为(batch_size, num_tokens / num_output_chunks, vocab_size)

  • labels (torch.Tensor) – 形状为 (batch_size, num_tokens) 的真实标签。

Returns:

形状为 (1,) 的交叉熵损失。

Return type:

torch.Tensor

示例

>>> loss_fn = ChunkedCrossEntropyLoss()
>>>
>>> h = torch.tensor([bsz, num_tokens, dim])
>>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>>
>>> labels = torch.tensor([bsz, num_tokens])
>>> loss = loss_fn(output_chunks, labels)