jax.experimental.pallas.程序_ID

目录

jax.experimental.pallas.程序_ID#

jax.experimental.pallas.program_id(axis)[源代码][源代码]#

返回网格沿给定轴的内核执行位置。

例如,在核函数执行中,如果2D 网格 对应于网格坐标 (1, 2),则 program_id(axis=0) 返回 1,而 program_id(axis=1) 返回 2

参数:

axis (int) – 沿着网格的轴线计算程序。

返回类型:

jax.Array