triton.language.dot¶
- triton.language.dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=triton.language.float32)¶
返回两个块的矩阵乘积。
这两个块必须都是二维或三维的,并且具有兼容的内部维度。 对于三维块,tl.dot执行批量矩阵乘积, 其中每个块的第一个维度代表批量维度。
- Parameters:
input (2D或3D张量,标量类型为{
int8,float8_e5m2,float16,bfloat16,float32}) - 第一个待相乘的张量。other (2D或3D张量,标量类型为{
int8,float8_e5m2,float16,bfloat16,float32}) - 要相乘的第二个张量。acc (2D或3D张量,标量类型为{
float16,float32,int32}) - 累加器张量。如果不为None,结果将被添加到此张量中。input_precision (字符串。NVIDIA可用选项:
"tf32","tf32x3","ieee"。默认值:"tf32"。AMD可用选项:"ieee", (仅CDNA3)"tf32"。) – 如何为f32 x f32运算使用张量核心。如果设备没有张量核心或输入不是f32类型,此选项将被忽略。对于具有张量核心的设备,默认精度为tf32。allow_tf32 – 已弃用。 如果为true,input_precision将被设置为"tf32"。
input_precision和allow_tf32中只能指定一个(即至少有一个必须是None)。