jax.lax.ppermute

目录

jax.lax.ppermute#

jax.lax.ppermute(x, axis_name, perm)[源代码][源代码]#

根据排列 perm 执行集体排列。

如果 x 是一个 pytree,那么结果等同于将此函数映射到树中的每个叶子。

此函数是 CollectivePermute HLO 的模拟。

参数:
  • x – 带有映射轴名为 axis_name 的数组。

  • axis_name – 用于命名 pmapped 轴的可哈希 Python 对象(更多详情请参阅 jax.pmap() 文档)。

  • perm – 整数对的列表,表示 (源索引, 目标索引) 对,这些对编码了名为 axis_name 的映射轴应如何打乱。整数值被视为映射轴 axis_name 的索引。任何两对不应具有相同的源索引或相同的目标索引。对于轴 axis_name 的每个索引,如果在 perm 中没有对应的目标索引,则结果中相应的值将填充适当类型的零。

返回:

x 形状相同的数组(s),沿着轴 axis_namex 中收集切片,根据排列 perm