jax.disable_jit#
- jax.disable_jit(disable=True)[源代码][源代码]#
在其动态上下文中禁用
jit()行为的上下文管理器。对于调试来说,在动态上下文中有一个禁用
jit()的机制是很有用的。请注意,这不仅会禁用用户对jit()的显式使用,还会移除 JAX 库使用的任何隐式 JIT 编译:这包括传递给更高级别原语(如scan()和while_loop())的 body 和 cond 函数的隐式 JIT 计算,jax.numpy函数实现中使用的 JIT,以及 API 实现中使用jit()的任何其他情况。但请注意,即使在 disable_jit 下,单个原语操作仍将由 XLA 编译,就像在正常的急切逐操作执行中一样。在 jitted 函数参数上具有数据依赖性的值会被跟踪和抽象化。例如,抽象值可能是一个
ShapedArray实例,表示具有给定形状和数据类型的所有可能数组的集合,但不表示具有特定值的具体数组。如果你在 jitted 函数中使用了一个良性的副作用操作,比如打印,你可能会注意到这些。>>> import jax >>> >>> @jax.jit ... def f(x): ... y = x * 2 ... print("Value of y is", y) ... return y + 3 ... >>> print(f(jax.numpy.array([1, 2, 3]))) Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace...> [5 7 9]
这里
y已经被jit()抽象为一个ShapedArray,它表示一个具有固定形状和类型但值任意的数组。y的值也被追踪。如果我们想在调试时看到具体的值,并且避免追踪器,我们可以使用disable_jit()上下文管理器:>>> import jax >>> >>> with jax.disable_jit(): ... print(f(jax.numpy.array([1, 2, 3]))) ... Value of y is [2 4 6] [5 7 9]
- 参数:
disable (bool)