torch.set_float32_matmul_precision¶
- torch.set_float32_matmul_precision(precision)[源代码]¶
设置float32矩阵乘法的内部精度。
在较低精度下运行 float32 矩阵乘法可能会显著提高性能,并且在某些程序中,精度的损失影响可以忽略不计。
支持三种设置:
“最高”,float32 矩阵乘法使用 float32 数据类型(24 尾数位,其中 23 位显式存储)进行内部计算。
“高”,float32 矩阵乘法要么使用 TensorFloat32 数据类型(显式存储 10 位尾数),要么将每个 float32 数字视为两个 bfloat16 数字的和(大约 16 位尾数,显式存储 14 位),如果相应的快速矩阵乘法算法可用的话。否则,float32 矩阵乘法将计算为如果精度为“最高”。有关 bfloat16 方法的更多信息,请参见下文。
“medium”,float32 矩阵乘法使用 bfloat16 数据类型(8 尾数位,其中 7 位显式存储)进行内部计算,如果可以使用该数据类型的快速矩阵乘法算法。否则,float32 矩阵乘法将按照“high”精度计算。
当使用“高”精度时,float32 乘法可能会使用基于 bfloat16 的算法,该算法比简单地截断到某些较小的尾数位(例如,TensorFloat32 为 10,显式存储的 bfloat16 为 7)更为复杂。有关此算法的完整描述,请参阅 [Henry2019]。简要解释如下,第一步是认识到我们可以将一个 float32 数字完美地编码为三个 bfloat16 数字的和(因为 float32 有 23 个尾数位,而 bfloat16 显式存储了 7 个,并且两者具有相同数量的指数位)。这意味着两个 float32 数字的乘积可以精确地由九个 bfloat16 数字的乘积之和给出。然后,我们可以通过丢弃其中一些乘积来用速度换取精度。“高”精度算法特别保留了三个最重要的乘积,这恰好排除了所有涉及任一输入的最后 8 个尾数位的乘积。这意味着我们可以将输入表示为两个 bfloat16 数字的和,而不是三个。由于 bfloat16 融合乘加(FMA)指令通常比 float32 指令快 10 倍以上,因此使用 bfloat16 精度进行三次乘法和两次加法比使用 float32 精度进行一次乘法更快。
注意
这不会改变float32矩阵乘法的输出数据类型,它控制矩阵乘法的内部计算方式。
注意
这不会改变卷积操作的精度。其他标志,如torch.backends.cudnn.allow_tf32,可能会控制卷积操作的精度。
注意
此标志目前仅影响一种本地设备类型:CUDA。 如果设置为“高”或“中”,则在计算float32矩阵乘法时将使用TensorFloat32数据类型,相当于设置 torch.backends.cuda.matmul.allow_tf32 = True。当设置为“最高”(默认值)时,内部计算将使用float32数据类型,相当于设置 torch.backends.cuda.matmul.allow_tf32 = False。
- Parameters
精度 (str) – 可以设置为“最高”(默认)、“高”或“中”(见上文)。