旋转位置嵌入¶
- class torchtune.modules.RotaryPositionalEmbeddings(dim: int, max_seq_len: int = 4096, base: int = 10000)[source]¶
该类实现了在https://arxiv.org/abs/2104.09864中提出的旋转位置嵌入(RoPE)。
参考实现(用于正确性验证)可以在这里找到: https://github.com/meta-llama/llama/blob/main/llama/model.py#L80
在这个实现中,我们通过在初始化期间计算来缓存每个位置的嵌入,直到
max_seq_len。- Parameters:
- forward(x: Tensor, *, input_pos: Optional[Tensor] = None) Tensor[source]¶
- Parameters:
x (torch.Tensor) – 输入张量,形状为
[b, s, n_h, h_d]input_pos (可选[torch.Tensor]) – 可选的张量,包含每个标记的位置ID。在训练期间,这用于指示每个标记相对于其样本的位置,形状为 [b, s]。在推理期间,这表示当前标记的位置。如果未提供,则假定标记的索引为其位置ID。默认值为 None。
- Returns:
输出张量的形状为
[b, s, n_h, h_d]- Return type:
- Notation used for tensor shapes:
b: 批量大小
s: 序列长度
n_h: 头数
h_d: 头部维度