Triton 语义¶
Triton 主要遵循 NumPy 的语义,只有少量例外。本文档将介绍 Triton 支持的部分数组计算特性,并说明 Triton 与 NumPy 语义存在差异的地方。
类型提升¶
类型提升发生在使用不同数据类型的张量进行运算时。对于与dunder方法相关的二元运算以及三元函数tl.where
的最后两个参数,Triton会自动按照类型层次结构(数据类型集合)将输入张量转换为共同的数据类型:{bool} < {integral dypes} < {floating point dtypes}
。
算法如下:
类型 如果一个张量的数据类型属于更高阶的类型,另一个张量会被提升为该类型:
(int32, bfloat16) -> bfloat16
宽度 如果两个张量的数据类型属于同一类别,且其中一个具有更高的位宽,则另一个会被提升为该数据类型:
(float32, float16) -> float32
优先使用float16 如果两个张量具有相同的位宽和符号属性但数据类型不同(
float16
和bfloat16
或不同的fp8
类型),它们都会被提升为float16
。(float16, bfloat16) -> float16
优先使用无符号类型 否则(相同位宽但符号不同时),它们会被提升为无符号数据类型:
(int32, uint32) -> uint32
当涉及标量时,规则会有些不同。这里的标量指的是数字字面量、标记为tl.constexpr的变量或这些的组合。它们由NumPy标量表示,类型为bool
、int
和float
。
当一个操作涉及张量和标量时:
如果标量的类型低于或等于张量,它将不参与类型提升:
(uint8, int) -> uint8
如果标量属于更高类型,我们会在整数类型
int32
<uint32
<int64
<uint64
和浮点类型float32
<float64
中选择能容纳该标量的最低精度数据类型。然后,张量和标量都会被提升到这个数据类型:(int16, 4.0) -> float32
广播¶
广播(Broadcasting)允许对不同形状的张量进行操作,通过自动扩展它们的形状到兼容的尺寸而不需要复制数据。这遵循以下规则:
如果其中一个张量形状较短,则在左侧填充1,直到两个张量具有相同的维度数:
((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))
如果两个维度相等,或者其中一个维度为1,则这两个维度是兼容的。维度1将被扩展以匹配另一个张量的维度。
((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))
与NumPy的差异¶
C语言整数除法中的舍入规则 Triton中的运算符遵循C语言语义而非Python语义以提高效率。因此,int // int
对于混合符号的整数实现了向零舍入的C语言规则,而不是像Python那样向负无穷舍入。出于同样的原因,取模运算符int % int
(定义为a % b = a - b * (a // b)
)也遵循C语言语义而非Python语义。
可能令人困惑的是,当所有输入都是标量时,整数除法和模运算遵循Python语义。