Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable, but they won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')
def naive_softmax(x):
"""Compute row-wise softmax of X using native pytorch
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
this shift.
"""
# read MN elements ; write M elements
x_max = x.max(dim=1)[0]
# read MN + M elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(z)
# read MN elements ; write M elements
denominator = numerator.sum(dim=1)
# read MN + M elements ; write MN elements
ret = numerator / denominator[:, None]
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret
When implemented naively in PyTorch, computing y = naive_softmax(x)
for \(x \in R^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we would prefer a custom "fused" kernel that reads
X only once and does all the necessary computations on-chip.
Doing so would require reading and writing back only \(MN\) elements each way, so we
could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically,
but, as we will see later, it is still far from ideal.
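As a quick sanity check on that ratio, here is the arithmetic with illustrative sizes (the \(2M\) and \(4M\) terms quickly become negligible as N grows):
M, N = 4096, 4096
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = M * N + M * N                              # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)                       # ~4.0005, i.e. roughly the 4x bound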
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shape:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr):
# starting row of the program
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
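The -float('inf') fill value used for the masked lanes is what makes this padding safe: padded entries do not change the row maximum, and they contribute exp(-inf) = 0 to the denominator. A minimal PyTorch sketch of the same idea (the names below are purely illustrative and not part of the kernel):
row = torch.tensor([1.0, 2.0, 3.0])          # n_cols = 3
padded = torch.full((4,), -float('inf'))     # BLOCK_SIZE = next_power_of_2(3) = 4
padded[:3] = row
num = torch.exp(padded - padded.max())       # the padded lane becomes exp(-inf) == 0
print((num / num.sum())[:3])                 # matches torch.softmax(row, dim=0)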
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
n_rows, n_cols = x.shape
# The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 8
# Number of software pipelining stages.
num_stages = 4 if SIZE_SMEM > 200000 else 2
# Allocate output
y = torch.empty_like(x)
# pre-compile kernel to get register usage and compute thread occupancy.
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared
if is_hip():
# NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
# However, this is not always the case. In most cases all registers can be used as regular purpose registers.
# ISA SECTION (3.6.4 for CDNA3)
# VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
# with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
# VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
# not required to be equal numbers of both types.
if is_cdna():
NUM_GPRS = NUM_REGS * 2
# MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
# When we divide this number with WARP_SIZE we get maximum number of waves that can
# execute on a CU (multi-processor) in parallel.
MAX_NUM_THREADS = properties["max_threads_per_sm"]
max_num_waves = MAX_NUM_THREADS // WARP_SIZE
occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
else:
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy
num_programs = min(num_programs, n_rows)
# Create a number of persistent programs.
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
return y
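To make the occupancy arithmetic above concrete, here is the CUDA branch of the computation with purely hypothetical numbers (the real values come from the device properties and from what the compiler reports for this particular kernel):
# Hypothetical values for illustration only
NUM_SM, NUM_REGS, WARP_SIZE, SIZE_SMEM = 108, 65536, 32, 228 * 1024
n_regs, size_smem, num_warps, n_rows = 32, 16 * 1024, 8, 4096
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)  # 65536 // 8192 = 8 programs per SM, limited by registers
occupancy = min(occupancy, SIZE_SMEM // size_smem)        # shared memory would allow 14, so registers remain the limit
num_programs = min(NUM_SM * occupancy, n_rows)            # 108 * 8 = 864 persistent programs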
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
line_names=[
"Triton",
"Torch",
], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
stream = getattr(torch, DEVICE.type).Stream()
getattr(torch, DEVICE.type).set_stream(stream)
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms = triton.testing.do_bench(lambda: softmax(x))
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms)
benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch
0 256.0 469.792103 690.479590
1 384.0 662.029674 811.735767
2 512.0 803.774048 933.690037
3 640.0 814.259863 951.310147
4 768.0 887.911139 1014.347900
5 896.0 941.744268 1074.276345
6 1024.0 1010.800111 1124.934718
7 1152.0 1113.810347 610.240230
8 1280.0 1149.105964 669.038125
9 1408.0 1162.764033 720.335250
10 1536.0 1184.085504 778.828310
11 1664.0 1219.239818 812.094662
12 1792.0 1231.477099 859.448345
13 1920.0 1256.886294 908.037628
14 2048.0 1277.718663 952.991168
15 2176.0 1253.725125 971.816070
16 2304.0 1268.320104 1007.719434
17 2432.0 1299.124105 1055.950085
18 2560.0 1300.365139 1082.730090
19 2688.0 1306.441687 1101.276232
20 2816.0 1319.990833 1132.168186
21 2944.0 1320.216423 1168.636828
22 3072.0 1350.997882 1181.619245
23 3200.0 1354.016834 1188.581583
24 3328.0 1351.098643 1221.780127
25 3456.0 1375.607968 1245.063417
26 3584.0 1376.988531 1261.092737
27 3712.0 1383.343078 1266.593209
28 3840.0 1388.555439 1301.730701
29 3968.0 1384.456602 1318.622662
30 4096.0 1396.556181 1323.277671
31 4224.0 1322.864751 1160.559436
32 4352.0 1334.810748 1176.315840
33 4480.0 1343.671592 1185.310404
34 4608.0 1359.287808 1195.690748
35 4736.0 1354.841830 1199.656322
36 4864.0 1369.997335 1221.139101
37 4992.0 1368.520261 1235.965433
38 5120.0 1371.279527 1254.102085
39 5248.0 1368.179618 1260.120516
40 5376.0 1373.936688 1284.889451
41 5504.0 1371.394573 1296.271412
42 5632.0 1377.587344 1315.503658
43 5760.0 1388.353173 1326.814807
44 5888.0 1387.030627 1342.728182
45 6016.0 1392.844652 1354.678115
46 6144.0 1399.846496 1370.978600
47 6272.0 1411.031522 1372.748209
48 6400.0 1414.164291 1386.946932
49 6528.0 1406.208670 1393.551380
50 6656.0 1419.471920 1405.945793
51 6784.0 1411.009759 1417.065023
52 6912.0 1421.251436 1425.455983
53 7040.0 1413.784194 1431.433972
54 7168.0 1421.633392 1436.439611
55 7296.0 1426.018564 1443.456418
56 7424.0 1421.495264 1444.651309
57 7552.0 1418.828284 1451.538160
58 7680.0 1427.975392 1462.513487
59 7808.0 1422.943836 1463.951685
60 7936.0 1427.397610 1466.297458
61 8064.0 1429.727475 1470.760793
62 8192.0 1435.952633 1482.562485
63 8320.0 1384.598040 1403.984147
64 8448.0 1377.945445 1404.082355
65 8576.0 1389.102853 1396.690335
66 8704.0 1387.508845 1402.348581
67 8832.0 1379.882152 1402.550067
68 8960.0 1389.749201 1410.167114
69 9088.0 1399.457750 1416.924510
70 9216.0 1395.016691 1424.998244
71 9344.0 1394.858859 1423.694109
72 9472.0 1392.206411 1437.172627
73 9600.0 1386.286661 1432.117636
74 9728.0 1394.420457 1438.556933
75 9856.0 1404.129055 1441.527918
76 9984.0 1390.927497 1449.154442
77 10112.0 1404.100537 1456.021603
78 10240.0 1410.556721 1465.939824
79 10368.0 1406.774314 1466.218845
80 10496.0 1404.666444 1466.183818
81 10624.0 1403.211702 1466.628643
82 10752.0 1399.436409 1471.135428
83 10880.0 1394.578227 1481.226466
84 11008.0 1413.670671 1478.431178
85 11136.0 1412.445292 1485.894046
86 11264.0 1424.475331 1484.586567
87 11392.0 1412.601028 1488.308142
88 11520.0 1415.350157 1494.075980
89 11648.0 1416.603651 1497.847659
90 11776.0 1421.355225 1498.840731
91 11904.0 1433.294484 1508.954920
92 12032.0 1412.955256 1508.118766
93 12160.0 1413.539787 1510.814239
94 12288.0 1428.464053 1392.213394
95 12416.0 1436.145762 1391.470332
96 12544.0 1436.673855 1392.656796
97 12672.0 1442.793392 1393.179870
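The GB/s figures above are derived from the total DRAM traffic of one softmax call: the input is read once and the output written once, hence the factor of 2 in the gbps lambda. A quick worked example with a hypothetical timing:
M, N = 4096, 1024
bytes_moved = 2 * M * N * 4              # one float32 read of x plus one float32 write of y
ms = 0.025                               # hypothetical kernel time in milliseconds
print(bytes_moved * 1e-9 / (ms * 1e-3))  # ~1342 GB/s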
In the above plot, we can see that:
- Triton is 4x faster than the Torch JIT. This confirms our suspicion that the Torch JIT does not perform any fusion here.
- Triton is noticeably faster than torch.softmax, in addition to being easier to read, understand, and maintain. Note, however, that the PyTorch softmax operation is more general and works on tensors of any shape.
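For example (with illustrative shapes), torch.softmax accepts a tensor of any rank, whereas the softmax helper defined above unpacks exactly two dimensions and therefore only handles 2D inputs:
x3 = torch.randn(8, 16, 32, device=DEVICE)
torch.softmax(x3, dim=-1)  # works on tensors of any shape
# softmax(x3)              # would fail: `n_rows, n_cols = x.shape` expects a 2D tensor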
Total running time of the script: (0 minutes 23.177 seconds)