jax.纯回调#
- jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=False, **kwargs)[源代码][源代码]#
调用一个纯Python回调。在
jit()/vmap()/等情况下工作。更多解释,请参见 外部回调。
pure_callback允许在即时编译的 JAX 函数中调用 Python 函数。输入的callback将被传递放置在本地 CPU 上的 JAX 数组,并且它也应该返回 CPU 上的 JAX 数组。回调被视为功能上纯的,这意味着它没有副作用,并且其输出值仅取决于其参数值。因此,它可以安全地被多次调用(例如,当通过
vmap()或pmap()转换时),或者在例如 jit 装饰的函数的输出对其值没有数据依赖时根本不被调用。如果数据依赖允许,纯回调也可能被重新排序。当 vmap 被应用时,行为将取决于
vectorized关键字参数的值。当vectorized为True时,假定回调函数遵守jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])。因此,回调函数将直接在批量输入上调用(其中批量轴是前导维度)。此外,回调函数应返回具有相应前导批量轴的输出。如果未向量化,callback将按顺序映射到批量轴上。例如,如果callback = lambda x, y: np.matmul(x, y),那么我们可以自由地设置vectorized=True,因为np.matmul函数处理任意前导批量维度。- 参数:
callback (Callable[..., Any]) – 在主机上执行的函数。假定回调是一个纯函数(即没有副作用的函数):如果传递了一个非纯函数,它可能会以意想不到的方式运行,特别是在转换过程中。可调用对象将被传递数组的PyTrees作为参数,并且应该返回与``result_shape_dtypes``匹配的数组的PyTree。
result_shape_dtypes (Any) – pytree,其叶子具有
shape和dtype属性,其结构与回调函数在运行时的预期输出相匹配。jax.ShapeDtypeStruct常用于定义叶子值。*args (Any) – 传递给回调函数的参数
sharding (SingleDeviceSharding | None) – 可选的分片,指定应从哪个设备调用回调。
vectorized (bool) – 布尔值,指定回调函数是否可以以矢量化方式操作。
**kwargs (Any) – 传递给回调函数的键值参数
- 返回:
一个
jax.Array对象的 pytree,其结构与result_shape_dtypes相匹配。- 返回类型:
result
参见
jax.experimental.io_callback(): 为不纯函数设计的回调。jax.debug.callback(): 设计用于通用调试的回调。jax.debug.print(): 用于打印的回调设计。