jax.从进程本地数据创建数组#
- jax.make_array_from_process_local_data(sharding, local_data, global_shape=None)[源代码][源代码]#
使用进程中可用的数据创建分布式张量。
此函数是 make_array_from_callback 的一个常见特殊情况。它假设数据在进程中可用,并负责处理索引。
最常见的情况是分片在批次维度上进行,每个主机只加载其对应的子批次。此函数也支持更一般的情况,例如混合的多主机和多轴复制与分片,但您需要正确计算进程本地数据的大小和内容,以满足分片约束。
特别是,如果任何两台主机是副本,host_local_data 也应该相同。
global_shape 是可选的。如果没有提供,它将根据 local_data 和分片推断出来,假设每个主机只代表他们自己的数据用于均匀分片。如果分片是非均匀的(见下面的注释),将引发异常。
显式设置 global_shape 可以实现更精细的控制,并且适用于非均匀分片。global_shape 的每个维度必须与 host_local_data 匹配,或者与分片推断的全局形状匹配(在这种情况下,它相当于设置为 None,但更为明确)。
例如,如果维度 i 是完全分片的,那么这个大小将是 per_device_shape[i] * jax.local_device_count()。每个设备将被映射到 local_data 数组的本地切片中。例如,如果给定的进程地址切片是 (8, 12) 和 (24, 28),那么这些切片将被映射到 local_data 的 (0, 4) 和 (4, 8) 中。
对于每个全局形状与局部形状匹配的维度,每个设备将在 local_data 中查找切片。例如,如果 global_shape == local_data.shape,则假定局部数据是将被分片到设备的实际目标数组。
如果 global_shape 与 local_data.shape 相同,那么所有主机上的数据必须相同。
示例
>>> from jax.sharding import PartitionSpec as P >>> mesh_rows = 2 >>> mesh_cols = jax.device_count() // 2 ... >>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),)) >>> rows_per_device = 2 >>> feature_length = 32 >>> per_device_shape = (rows_per_device, feature_length) >>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length) >>> per_host_generator = lambda : np.arange(np.prod(per_host_shape)).reshape(per_host_shape) >>> per_host_data = per_host_generator() # replace with your own per-host data pipeline that outputs numpy arrays >>> global_shape = (rows_per_device * len(sharding.device_set), ) + per_device_shape[1:] >>> output_global_array = jax.make_array_from_process_local_data(sharding, per_host_data, global_shape) ... >>> assert output_global_array.addressable_data(0).shape == per_device_shape >>> assert output_global_array.shape == global_shape
注意:虽然大多数分片是均匀的,但可以设计一种异域分片网格,其中每个进程的设备在某些维度上以非网格状模式排列,或者索引以非平凡的方式重叠。这种分片在这些维度上被称为“非均匀”。在这种情况下,这些方向上的全局形状必须与局部形状匹配,因为没有有意义的方式以非重叠的方式表示所有需要的每进程数据。例如,对于全局形状 4x4,如果分片看起来像这样:
0123 2103 4675 4567
使用4个进程,分别包含设备 (0,1), (2, 3), (4, 5), (6, 7)。然后每个主机的数据看起来像
xx.. ..xx …. …. .xx. x..x …. …. …. …. x..x .xx. …. …. xx.. ..xx
分片在行上是均匀的(每个主机需要行1-2或行3-4),而在列上是非均匀的(主机需要重叠但不匹配的列集)。因此,本地数据必须具有2x4或4x4的形状,即使每个主机可能适合2x2的形状。在这种情况下,用户必须明确提供global_shape,并且对于local_shape=(2, 4),可能有效的全局形状是(2, 4)和(4, 4)。
另一方面,对于分片:
0213 x.x. .x.x. …. …. 0213 x.x. .x.x. …. …. 4657 …. …. .x.x x.x. 4657 …. …. .x.x x.x.
对于 local_shape=(2, 2),此函数可以接受 2x2、2x4、4x2 和 4x4 的全局形状选择。将 global_shape 设置为 None,在这种情况下等同于将其设置为 (4, 4)。
- 参数:
sharding (Sharding) – 全局张量的分片。
local_data (np.ndarray) – 主机上的数据应放置在本地设备上。每个维度应与 global_shape 匹配,或与 num_addressable_indices(dim) 匹配。
global_shape (Shape | None) – 全局张量的目标形状。如果为 None,将从 local_data 和 sharding 推断。
- 返回:
将具有 sharding=sharding 和形状为 global_shape 的张量。
- 返回类型:
ArrayImpl