Triton 语义

Triton 主要遵循 NumPy 的语义,只有少量例外。本文档将介绍 Triton 支持的部分数组计算特性,并说明 Triton 与 NumPy 语义存在差异的地方。

类型提升

类型提升发生在使用不同数据类型的张量进行运算时。对于与dunder方法相关的二元运算以及三元函数tl.where的最后两个参数,Triton会自动按照类型层次结构(数据类型集合)将输入张量转换为共同的数据类型:{bool} < {integral dypes} < {floating point dtypes}

算法如下:

  1. 类型 如果一个张量的数据类型属于更高阶的类型,另一个张量会被提升为该类型:(int32, bfloat16) -> bfloat16

  2. 宽度 如果两个张量的数据类型属于同一类别,且其中一个具有更高的位宽,则另一个会被提升为该数据类型:(float32, float16) -> float32

  3. 优先使用float16 如果两个张量具有相同的位宽和符号属性但数据类型不同(float16bfloat16或不同的fp8类型),它们都会被提升为float16(float16, bfloat16) -> float16

  4. 优先使用无符号类型 否则(相同位宽但符号不同时),它们会被提升为无符号数据类型:(int32, uint32) -> uint32

当涉及标量时,规则会有些不同。这里的标量指的是数字字面量、标记为tl.constexpr的变量或这些的组合。它们由NumPy标量表示,类型为boolintfloat

当一个操作涉及张量和标量时:

  1. 如果标量的类型低于或等于张量,它将不参与类型提升:(uint8, int) -> uint8

  2. 如果标量属于更高类型,我们会在整数类型 int32 < uint32 < int64 < uint64 和浮点类型 float32 < float64 中选择能容纳该标量的最低精度数据类型。然后,张量和标量都会被提升到这个数据类型:(int16, 4.0) -> float32

广播

广播(Broadcasting)允许对不同形状的张量进行操作,通过自动扩展它们的形状到兼容的尺寸而不需要复制数据。这遵循以下规则:

  1. 如果其中一个张量形状较短,则在左侧填充1,直到两个张量具有相同的维度数:((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))

  2. 如果两个维度相等,或者其中一个维度为1,则这两个维度是兼容的。维度1将被扩展以匹配另一个张量的维度。((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))

与NumPy的差异

C语言整数除法中的舍入规则 Triton中的运算符遵循C语言语义而非Python语义以提高效率。因此,int // int对于混合符号的整数实现了向零舍入的C语言规则,而不是像Python那样向负无穷舍入。出于同样的原因,取模运算符int % int(定义为a % b = a - b * (a // b))也遵循C语言语义而非Python语义。

可能令人困惑的是,当所有输入都是标量时,整数除法和模运算遵循Python语义。