get_unmasked_sequence_lengths¶
- torchtune.training.get_unmasked_sequence_lengths(mask: Tensor) Tensor[source]¶
返回每个批次元素的序列长度,不包括被屏蔽的标记。
- Parameters:
mask (torch.Tensor) – 布尔掩码,形状为 [b x s],其中 True 表示需要被掩码的值 这通常是一个用于填充标记的掩码,其中 True 表示填充标记。
- Returns:
序列索引的logits,形状为[b]
- Return type:
张量
- Shape notation:
b = 批量大小
s = 序列长度
示例
>>> input_ids = torch.tensor([ ... [2, 4, 0, 0], ... [2, 4, 6, 0], ... [2, 4, 6, 9] ... ]) >>> mask = input_ids == 0 >>> mask tensor([[False, False, True, True], [False, False, False, True], [False, False, False, False]]) >>> get_unmasked_sequence_lengths(mask) tensor([1, 2, 3])