jax.scipy.linalg.qr#
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: Literal['full', 'economic'] = 'full', pivoting: bool = False, check_finite: bool = True) tuple[Array, Array][源代码][源代码]#
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool, lwork: Any, mode: Literal['r'], pivoting: bool = False, check_finite: bool = True) tuple[Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: bool = False, check_finite: bool = True) tuple[Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = 'full', pivoting: bool = False, check_finite: bool = True) tuple[Array] | tuple[Array, Array]
计算数组的QR分解
JAX 实现的
scipy.linalg.qr()。矩阵 A 的 QR 分解由以下公式给出:
\[A = QR\]其中 Q 是一个酉矩阵(即 \(Q^HQ=I\)),而 R 是一个上三角矩阵。
- 参数:
a – 形状为 (…, M, N) 的数组
mode – 计算模式。支持的值有: -
"full"``(默认):返回形状为 ``(M, M)的 Q 和形状为(M, N)的 R。 -"r":仅返回 R -"economic":返回形状为(M, K)的 Q 和形状为(K, N)的 R,其中 K = min(M, N)。pivoting – 在 JAX 中未实现。
overwrite_a – 在 JAX 中未使用
lwork – 在 JAX 中未使用
check_finite – 在 JAX 中未使用
- 返回:
一个元组
(Q, R)``(如果 ``mode不是"r")否则是一个数组R,其中:-Q是一个形状为(..., M, M)的正交矩阵(如果mode是"full")或者(..., M, K)``(如果 ``mode是"economic")。-R是一个形状为(..., M, N)的上三角矩阵(如果mode是"r"或"full")或者(..., K, N)``(如果 ``mode是"economic"),其中K = min(M, N)。
参见
jax.numpy.linalg.qr(): NumPy 风格的 QR 分解 APIjax.lax.linalg.qr(): XLA 风格的 QR 分解 API
示例
计算矩阵的QR分解:
>>> a = jnp.array([[1., 2., 3., 4.], ... [5., 4., 2., 1.], ... [6., 3., 1., 5.]]) >>> Q, R = jax.scipy.linalg.qr(a) >>> Q Array([[-0.12700021, -0.7581426 , -0.6396022 ], [-0.63500065, -0.43322435, 0.63960224], [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) >>> R Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], [ 0. , -1.7870499, -2.6534991, -1.028908 ], [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32)
检查
Q是否为正交矩阵:>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) Array(True, dtype=bool)
重建输入:
>>> jnp.allclose(Q @ R, a) Array(True, dtype=bool)