jax.experimental.multihost_utils.global_array_to_host_local_array

jax.experimental.multihost_utils.global_array_to_host_local_array#

jax.experimental.multihost_utils.global_array_to_host_local_array(global_inputs, global_mesh, pspecs)[源代码][源代码]#

将全局 jax.Array 转换为主机本地的 jax.Array

您可以使用此功能转换到 jax.Array。使用 jax.Array 与 pjit 具有与使用 GDA 与 pjit 相同的语义,即所有传递给 pjit 的 jax.Array 输入都应具有全局形状,pjit 的输出也将是具有全局形状的 jax.Array。

您可以使用此函数将全局形状的 jax.Array 输出从 pjit 转换回主机本地值,以便向 jax.Array 的过渡可以是一个机械性的更改。

示例用法:

>>> from jax.experimental import multihost_utils 
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) 
>>>
>>> with mesh: 
...   global_out = pjitted_fun(global_inputs) 
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) 
参数:
  • global_inputs (Any) – 全局 jax.Array 的 Pytree。

  • global_mesh (jax.sharding.Mesh) – 一个 jax.sharding.Mesh 对象。网格必须是连续的,这意味着主机的所有本地设备必须形成一个子立方体。

  • pspecs (Any) – 一个由 jax.sharding.PartitionSpec 对象组成的 Pytree。

返回:

主机本地数组的Pytree。