jax.experimental.checkify.checkify

目录

jax.experimental.checkify.checkify#

jax.experimental.checkify.checkify(f, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))[源代码][源代码]#

fun 中功能化 check 调用,并可选择添加运行时错误检查。

运行时错误要么是用户添加的 check() 断言,要么是根据 errors 参数自动添加的检查,如 NaN 检查。

返回的函数将返回一个错误对象 err 以及原始函数的输出。err.get() 将返回 None``(如果没有发生错误)或包含错误信息的字符串。此错误信息将对应于发生的第一个错误。``err.throw() 如果发生错误,将引发带有错误信息的 ValueError。

默认情况下,仅启用用户添加的 check() 断言。您可以通过 errors 参数启用自动检查。

可以启用的自动检查集,以及何时生成错误:
  • user_checks: 一个 check() 评估为 False。

  • nan_checks: 浮点运算生成了一个 NaN 值作为输出。

  • div_checks: 除以零。

  • index_checks: 索引超出边界。

可以通过传递一个错误 集合`(例如 ``errors=nan_checks`)来同时启用多个类别。多个集合可以重新组合(例如 errors=float_checks|user_checks

参数:
  • fun – 可调用对象,可以包含用户检查(参见 check())。

  • errors (frozenset[ErrorCategory]) – 一组 ErrorCategory 值,定义了启用的检查集合。默认情况下仅启用显式 checks``(``user_checks)。例如,您也可以通过传递 float_checks 集合来启用 NAN 和 DIV 错误,或者通过集合操作(例如 float_checks | user_checks)组合多个集合。

  • f (Callable[..., Out])

返回:

一个接受与 fun 相同参数的函数,并返回一个包含两个元素的输出:第一个元素是一个 Error 值,表示第一个失败的 check(),第二个元素是 fun 的原始输出。

返回类型:

Callable[…, tuple[Error, Out]]

例如:

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>>
>>> @jax.jit
... def f(x):
...   y = jnp.sin(x)
...   return x+y
>>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
>>> err.throw()  
Traceback (most recent call last):
  ...
jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin