checkify 转换#
总结: Checkify 允许你为 JAX 代码添加可 jit 的运行时错误检查(例如越界索引)。使用 checkify.checkify 变换与类似断言的 checkify.check 函数来为 JAX 代码添加运行时检查:
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
你也可以使用 checkify 来自动添加常见检查:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
err, z = checked_f(jnp.array([5, 1]), 0)
err.throw() # if no error occurred, throw does nothing!
功能化检查#
类似断言的检查API本身在功能上并不纯粹:它可以像assert一样引发Python异常作为副作用。因此,它不能与jit、pmap、pjit或scan一起分阶段执行:
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
但是,checkify 变换将这些效果功能化(或释放)。一个经过 checkify 变换的函数会返回一个错误 值 作为新的输出,并且保持功能上的纯度。这种功能化意味着经过 checkify 变换的函数可以按照我们的喜好与分阶段/变换组合在一起:
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
.. at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
为什么JAX需要checkify?#
在某些 JAX 变换下,你可以用普通的 Python 断言来表达运行时错误检查,例如仅在使用 jax.grad 和 jax.numpy 时:
def f(x):
assert x > 0., "must be positive!"
return jnp.log(x)
jax.grad(f)(0.)
# ValueError: "must be positive!"
但普通的断言在 jit、pmap、pjit 或 scan 内部不起作用。在这些情况下,数值计算会被分阶段执行,而不是在Python执行期间急切地评估,因此数值结果不可用:
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
JAX 变换语义依赖于函数纯度,特别是在组合多个变换时,因此我们如何在不影响所有这些的情况下提供错误机制?除了需要一个新的 API 之外,情况仍然更加复杂:XLA HLO 不支持断言或抛出错误,因此即使我们有一个能够输出断言的 JAX API,我们如何将这些断言降低到 XLA?
你可以想象手动为你的函数添加运行时检查,并导出代表错误的值:
def f_checked(x):
error = x <= 0.
result = jnp.log(x)
return error, result
err, y = jax.jit(f_checked)(0.)
if err:
raise ValueError("must be positive!")
# ValueError: "must be positive!"
错误是函数计算出的常规值,并且错误在 f_checked 外部引发。f_checked 在功能上是纯的,因此我们知道通过构造它已经可以与 jit、pmap、pjit、scan 以及 JAX 的所有变换一起工作。唯一的问题是这种管道可能会很麻烦!
checkify 为你完成这个重写:这包括在函数中传递错误值,将检查重写为布尔操作,并将结果与跟踪的错误值合并,最后将最终的错误值作为checkified函数的输出返回:
def f(x):
checkify.check(x > 0., "{} must be positive!", x) # convenient but effectful API
return jnp.log(x)
f_checked = checkify(f)
err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))
我们称这种操作为功能化或释放由调用check引入的效果。(在上面的“手动”示例中,错误值只是一个布尔值。checkify的错误值在概念上类似,但也跟踪错误消息并公开throw和get方法;参见jax.experimental.checkify)。checkify.check还允许你通过将它们作为格式参数提供给错误消息,将运行时值添加到你的错误消息中。
你现在可以用运行时检查手动检测你的代码,但 checkify 也可以自动为常见错误添加检查!考虑这些错误情况:
jnp.arange(3)[5] # out of bounds
jnp.sin(jnp.inf) # NaN generated
jnp.ones((5,)) / jnp.arange(5) # division by zero
默认情况下,checkify 只会处理 checkify.check,并不会对上述错误进行捕捉。但如果你要求,checkify 也会自动在你的代码中添加检查。
def f(x, i):
y = x[i] # i could be out of bounds.
z = jnp.sin(y) # z could become NaN
return z
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
选择启用哪些自动检查的API基于集合(Sets)。更多详情请参见 jax.experimental.checkify。
checkify 在 JAX 变换下。#
如上例所示,一个经过检查的函数可以愉快地进行即时编译(jit)。以下是一些使用 checkify 与其他 JAX 变换的更多示例。请注意,经过检查的函数在功能上是纯的,并且应该与所有 JAX 变换轻松组合!
jit#
你可以安全地将 jax.jit 添加到一个已检查的函数中,或者将一个已编译的函数进行检查,两者都可以正常工作。
def f(x, i):
return x[i]
checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ = checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)
vmap/pmap#
你可以对 vmap 和 pmap 检查过的函数(或 checkify 映射的函数)进行映射。映射一个检查过的函数将给你一个映射的错误,它可以包含映射维度中每个元素的不同错误。
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""
然而,一个 checkify-of-vmap 将产生一个单一的(未映射的)错误!
@jax.vmap
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
pjit#
pjit 一个检查过的函数 只是工作 ,你只需要为错误值输出指定一个额外的 out_axis_resources 为 None。
def f(x):
return x / x
f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
f,
in_shardings=PartitionSpec('x', None),
out_shardings=(None, PartitionSpec('x', None)))
with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)
grad#
如果你使用 checkify-of-grad,你的梯度计算也会被检测。
def f(x):
return x / (1 + jnp.sqrt(x))
grad_f = jax.grad(f)
err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)
注意,f 中没有乘法,但在其梯度计算中有一个乘法(这就是产生 NaN 的地方!)。因此,使用 checkify-of-grad 为前向和后向传递操作添加自动检查。
checkify.check 只会应用于你函数的原始值。如果你想对梯度值使用 check,请使用 custom_vjp:
@jax.custom_vjp
def assert_gradient_negative(x):
return x
def fwd(x):
return assert_gradient_negative(x), None
def bwd(_, grad):
checkify.check(grad < 0, "gradient needs to be negative!")
return (grad,)
assert_gradient_negative.defvjp(fwd, bwd)
jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!
jax.experimental.checkify 的优缺点#
优势#
你可以在任何地方使用它(错误只是“值”,并且在像其他值一样的转换下表现得直观)
自动检测:您不需要对代码进行本地修改。相反,
checkify可以检测所有代码!
限制#
添加大量运行时检查可能会很昂贵(例如,在每个原始数据上添加NaN检查会增加许多操作到你的计算中)
需要从函数中提取线程错误值并手动抛出错误。如果错误未被显式抛出,您可能会错过错误!
抛出一个错误值将在主机上实现该错误值,这意味着它是一个阻塞操作,这违背了 JAX 的异步提前运行。