mlx.nn.正弦位置编码

mlx.nn.正弦位置编码#

正弦位置编码(维度 int, 最小频率 浮点数 = 0.0001, 最大频率 浮点数 = 1, 缩放 浮点数 | = , 余弦初值 布尔值 = , 完整转数 布尔值 = )#

实现正弦位置编码。

更多详情请参阅论文Attention Is All You Need

Parameters
  • dims (int) – 结果位置嵌入的维度。

  • min_freq (float, 可选) – 预期的最小频率。默认值: 0.0001

  • max_freq (float, 可选) – 预期的最大频率。默认值: 1

  • scale (float, 可选) – 嵌入的乘法比例。 默认值: sqrt(2/dims).

  • cos_first (bool, 可选) – 如果 True 使用 [cos(x); sin(x)] 进行嵌入 而不是相反的顺序。默认值:False

  • full_turns (bool, 可选) – 如果 True,则将频率乘以 \(2\pi\)。默认值:False

方法