jax.scipy.linalg.lu_factor

目录

jax.scipy.linalg.lu_factor#

jax.scipy.linalg.lu_factor(a, overwrite_a=False, check_finite=True)[源代码][源代码]#

基于LU分解的线性求解的因式分解

JAX 实现的 scipy.linalg.lu_factor()

此函数返回一个适合与 jax.scipy.linalg.lu_solve() 一起使用的结果。对于直接的LU分解,首选 jax.scipy.linalg.lu()

参数:
  • a (ArrayLike) – 形状为 (..., M, N) 的输入数组。

  • overwrite_a (bool) – JAX 未使用

  • check_finite (bool) – JAX 未使用

返回:

一个元组 (lu, piv) - lu 是一个形状为 (..., M, N) 的数组,包含 L 在其下三角部分和 U 在其上三角部分。 - piv 是一个形状为 (..., K) 的数组,其中 K = min(M, N),用于编码枢轴。

返回类型:

tuple[Array, Array]

示例

通过LU分解求解一个小型线性系统:

>>> a = jnp.array([[2., 1.],
...                [1., 2.]])

通过 lu_factor() 计算 lu 分解,并使用它通过 lu_solve() 求解线性方程。

>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)

检查结果是否一致:

>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)