generate_next_token¶
- torchtune.generation.generate_next_token(model: TransformerDecoder, input_pos: Tensor, x: Tensor, q: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, temperature: float = 1.0, top_k: Optional[int] = None) Tuple[Tensor, Tensor][source]¶
根据提示生成下一个标记,并返回相应的对数概率。
- Parameters:
model (TransformerDecoder) – 用于生成的模型
input_pos (torch.Tensor) – 包含与给定提示相关联的位置编码的张量,形状为 [bsz x seq_length]。
x (torch.Tensor) – 包含与给定提示相关联的token ID的张量,形状为[bsz x seq_length]。
q (可选[torch.Tensor]) – 用于softmax采样技巧的随机采样张量。 参见 https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40
mask (可选[torch.Tensor]) – 注意力掩码,形状为 [bsz x seq_length x seq_length],默认值为 None。
温度 (float) – 用于缩放预测的logits的值,默认值为1.0。
top_k (可选[int]) – 用于采样的Top-k值,默认为None。
- Returns:
- 两个张量的元组:
- tokens (torch.Tensor): 包含生成标记的张量,
形状为 [bsz x 1]。
- logits (torch.Tensor): 包含与生成标记相关的logits的张量,
形状为 [bsz x seq_length x vocab_size]。
- Return type:
元组[torch.Tensor, torch.Tensor]