Shortcuts

get_unmasked_sequence_lengths

torchtune.training.get_unmasked_sequence_lengths(mask: Tensor) Tensor[source]

返回每个批次元素的序列长度,不包括被屏蔽的标记。

Parameters:

mask (torch.Tensor) – 布尔掩码,形状为 [b x s],其中 True 表示需要被掩码的值 这通常是一个用于填充标记的掩码,其中 True 表示填充标记。

Returns:

序列索引的logits,形状为[b]

Return type:

张量

Shape notation:
  • b = 批量大小

  • s = 序列长度

示例

>>> input_ids = torch.tensor([
...        [2, 4, 0, 0],
...        [2, 4, 6, 0],
...        [2, 4, 6, 9]
...    ])
>>> mask = input_ids == 0
>>> mask
tensor([[False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
>>> get_unmasked_sequence_lengths(mask)
tensor([1, 2, 3])