jax.numpy.tril

目录

jax.numpy.tril#

jax.numpy.tril(m, k=0)[源代码][源代码]#

返回数组的下三角部分。

JAX 实现的 numpy.tril()

参数:
  • m (ArrayLike) – 输入数组。必须满足 m.ndim >= 2

  • k (int) – k: 可选, int, 默认=0. 指定数组中元素被设置为零的次对角线。k=0 指主对角线, k<0 指主对角线下方的次对角线, k>0 指主对角线上方的次对角线。

返回:

一个与输入形状相同的数组,包含给定数组的上三角部分,其中元素在由 k 指定的次对角线以下被设置为零。

返回类型:

Array

参见

示例

>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> jnp.tril(x)
Array([[ 1,  0,  0,  0],
       [ 5,  6,  0,  0],
       [ 9, 10, 11,  0]], dtype=int32)
>>> jnp.tril(x, k=1)
Array([[ 1,  2,  0,  0],
       [ 5,  6,  7,  0],
       [ 9, 10, 11, 12]], dtype=int32)
>>> jnp.tril(x, k=-1)
Array([[ 0,  0,  0,  0],
       [ 5,  0,  0,  0],
       [ 9, 10,  0,  0]], dtype=int32)

m.ndim > 2 时,jnp.tril 会对尾随轴进行批处理操作。

>>> x1 = jnp.array([[[1, 2],
...                  [3, 4]],
...                 [[5, 6],
...                  [7, 8]]])
>>> jnp.tril(x1)
Array([[[1, 0],
        [3, 4]],

       [[5, 0],
        [7, 8]]], dtype=int32)