Shortcuts

示例

torchtune.generation.sample(logits: Tensor, *, temperature: float = 1.0, top_k: Optional[int] = None, q: Optional[Tensor] = None) Tensor[source]

从概率分布中抽取的通用样本。包括对Top-K采样和温度的支持。

Parameters:
  • logits (torch.Tensor) – 从中采样的logits

  • 温度 (float) – 用于缩放预测的logits的值,默认值为1.0。

  • top_k (可选[int]) – 如果指定,我们将采样修剪为仅包含在top_k概率内的token id

  • q (可选[torch.Tensor]) – 用于softmax采样技巧的随机采样张量。如果为None,我们使用默认的softmax采样技巧。默认值为None。

示例

>>> from torchtune.generation import sample
>>> logits = torch.empty(3, 3).uniform_(0, 1)
>>> sample(logits)
tensor([[1],
        [2],
        [0]], dtype=torch.int32)
Returns:

采样的令牌ID

Return type:

torch.Tensor