jax.linearize#
- jax.linearize(fun: Callable, *primals, has_aux: Literal[False] = False) tuple[Any, Callable][源代码][源代码]#
- jax.linearize(fun: Callable, *primals, has_aux: Literal[True]) tuple[Any, Callable, Any]
使用
jvp()和部分求值生成fun的线性近似。- 参数:
fun – 要微分的函数。其参数应为数组、标量或标准 Python 容器中的数组或标量。它应返回一个数组、标量或标准 Python 容器中的数组或标量。
primals – 雅可比矩阵
fun应在其上进行评估的原始值。应为数组元组、标量或标准 Python 容器。元组的长度等于fun的位置参数的数量。has_aux – 可选, bool. 指示
fun是否返回一个元组,其中第一个元素被认为是需要线性化的数学函数的输出,第二个元素是辅助数据。默认为 False。
- 返回:
如果
has_aux是False,返回一个元组,其中第一个元素是f(*primals)的值,第二个元素是一个函数,该函数在不重新进行线性化工作的情况下,计算在primals处评估的fun的前向模式雅可比向量积。如果has_aux是True,返回一个(primals_out, lin_fn, aux)元组,其中aux是fun返回的辅助数据。
在计算的值方面,
linearize()的行为非常类似于一个被柯里化的jvp(),这两个代码块计算相同的值:y, out_tangent = jax.jvp(f, (x,), (in_tangent,)) y, f_jvp = jax.linearize(f, x) out_tangent = f_jvp(in_tangent)
然而,区别在于
linearize()使用了部分求值,因此在调用f_jvp时不会重新线性化函数f。一般来说,这意味着内存使用量会随着计算规模的大小而缩放,很像在反向模式中那样。(实际上,linearize()的签名与vjp()相似!)如果你想多次应用
f_jvp,即在同一线性化点对多个不同的输入切向量进行前推评估,这个函数主要是有用的。此外,如果所有输入切向量都是一次性已知的,使用vmap()进行向量化会更高效,如下所示:pushfwd = partial(jvp, f, (x,)) y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
通过像这样一起使用
vmap()和jvp(),我们避免了随着计算深度增加而存储线性化的内存成本,这是由linearize()和vjp()引起的。以下是使用
linearize()的更完整示例:>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.) ... >>> jax.jvp(f, (2.,), (3.,)) (Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True)) >>> y, f_jvp = jax.linearize(f, 2.) >>> print(y) 3.2681944 >>> print(f_jvp(3.)) -5.007528 >>> print(f_jvp(4.)) -6.676704