示例¶
- torchtune.generation.sample(logits: Tensor, *, temperature: float = 1.0, top_k: Optional[int] = None, q: Optional[Tensor] = None) Tensor[source]¶
从概率分布中抽取的通用样本。包括对Top-K采样和温度的支持。
- Parameters:
logits (torch.Tensor) – 从中采样的logits
温度 (float) – 用于缩放预测的logits的值,默认值为1.0。
top_k (可选[int]) – 如果指定,我们将采样修剪为仅包含在top_k概率内的token id
q (可选[torch.Tensor]) – 用于softmax采样技巧的随机采样张量。如果为None,我们使用默认的softmax采样技巧。默认值为None。
示例
>>> from torchtune.generation import sample >>> logits = torch.empty(3, 3).uniform_(0, 1) >>> sample(logits) tensor([[1], [2], [0]], dtype=torch.int32)
- Returns:
采样的令牌ID
- Return type: