jax.numpy.diag#
- jax.numpy.diag(v, k=0)[源代码][源代码]#
返回指定的对角线或构造一个对角线数组。
JAX 实现的
numpy.diag()
。JAX 版本总是返回输入的一个副本,尽管如果这在 JIT 编译中使用,编译器可能会避免复制。
- 参数:
v (ArrayLike) – 输入数组。可以是1-D数组以创建对角矩阵,或者是2-D数组以提取对角线。
k (int) – 可选,默认=0。对角线偏移。正值将对角线放置在主对角线上方,负值将其放置在主对角线下方。
- 返回:
如果 v 是一个二维数组,则返回包含对角元素的一维数组。如果 v 是一个一维数组,则返回一个二维数组,其中输入元素沿着指定的对角线放置。
- 返回类型:
示例
从一维数组创建对角矩阵:
>>> jnp.diag(jnp.array([1, 2, 3])) Array([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=int32)
指定对角偏移:
>>> jnp.diag(jnp.array([1, 2, 3]), k=1) Array([[0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 3], [0, 0, 0, 0]], dtype=int32)
从二维数组中提取对角线:
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> jnp.diag(x) Array([1, 5, 9], dtype=int32)