jax.scipy.linalg.det

目录

jax.scipy.linalg.det#

jax.scipy.linalg.det(a, overwrite_a=False, check_finite=True)[源代码][源代码]#

计算矩阵的行列式

JAX 实现的 scipy.linalg.det()

参数:
  • a (ArrayLike) – 输入数组,形状为 (..., N, N)

  • overwrite_a (bool) – JAX 未使用

  • check_finite (bool) – JAX 未使用

返回类型:

Array

返回

形状的行列式 a.shape[:-2]

参见

jax.numpy.linalg.det(): NumPy 风格的行列式 API

示例

小二维数组的行列式:

>>> x = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jax.scipy.linalg.det(x)
Array(-2., dtype=float32)

批量计算多个二维数组的行列式:

>>> x = jnp.array([[[1., 2.],
...                 [3., 4.]],
...                [[8., 5.],
...                 [7., 9.]]])
>>> jax.scipy.linalg.det(x)
Array([-2., 37.], dtype=float32)