Trilu

Trilu - 14

版本

  • 名称: Trilu (GitHub)

  • 域名: main

  • since_version: 14

  • 函数: False

  • support_level: SupportType.COMMON

  • 形状推断: True

此版本的运算符自版本14起可用。

摘要

给定一个二维矩阵或一批二维矩阵,返回张量的上三角或下三角部分。 属性“upper”决定是保留上三角部分还是下三角部分。如果设置为true, 则保留上三角矩阵。否则保留下三角矩阵。 “upper”属性的默认值为true。 Trilu接受一个形状为[*, N, M]的输入张量,其中*是零个或多个批次维度。上三角部分包括 给定对角线(k)上及其上方的元素。下三角部分包括对角线上及其下方的元素。 矩阵中的所有其他元素都设置为零。 如果k = 0,则保留主对角线上及其上方/下方的三角部分。 如果upper设置为true,则正数k保留上三角矩阵,不包括主对角线及其上方的(k-1)条对角线。 负数k值保留主对角线及其下方的|k|条对角线。 如果upper设置为false,则正数k保留下三角矩阵,包括主对角线及其上方的k条对角线。 负数k值不包括主对角线及其下方的(|k|-1)条对角线。

属性

  • upper - INT (默认为 '1'):

    布尔值。指示是否保留矩阵的上部或下部。默认为 true。

输入

在1和2个输入之间。

  • 输入 (异构) - T:

    输入张量的秩为2或更高。

  • k (可选, 异构) - tensor(int64):

    一个0维张量,包含一个与主对角线上下要排除或包含的对角线数量相对应的单一值。如果未指定,默认值为0。

输出

  • 输出 (异构) - T:

    输出张量与输入张量具有相同的类型和形状。

类型约束

  • T 在 ( tensor(bfloat16), tensor(bool), tensor(complex128), tensor(complex64), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) ):

    将输入和输出类型限制为所有张量类型。