Low-Memory Dropout

In this tutorial, you will write a memory-efficient implementation of dropout whose state consists of a single int32 seed. This differs from more conventional implementations of dropout, whose state usually consists of a bit mask tensor of the same shape as the input.

In doing so, you will learn about:

  • The limitations of naive approaches to implementing Dropout with PyTorch.

  • Parallel pseudo-random number generation in Triton.

Baseline

The dropout operator was first introduced in [SRIVASTAVA2014] as a way to improve the performance of deep neural networks in low-data regimes (i.e. regularization).

It takes a vector as input and produces a vector of the same shape as output. Each scalar in the output has a probability \(p\) of being set to zero, and otherwise the input value is passed through unchanged. This forces the network to perform well even when only \(1 - p\) of the input scalars are available.

At evaluation time we want to use the full power of the network, so we set \(p=0\). Naively, this would increase the norm of the output (which can be a bad thing: for example, it can lead to an artificial decrease in the output softmax temperature). To prevent this, we multiply the output by \(\frac{1}{1 - p}\), which keeps the norm consistent regardless of the dropout probability.
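
To see why this scaling works, note that each output element equals \(\frac{x_i}{1 - p}\) with probability \(1 - p\) and \(0\) with probability \(p\), so in expectation

\[\mathbb{E}[\text{output}_i] = (1 - p)\cdot\frac{x_i}{1 - p} + p\cdot 0 = x_i,\]

i.e. the expected value of every element matches the input regardless of the dropout probability.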

Let's first take a look at the baseline implementation.

import tabulate
import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def _dropout(
    x_ptr,  # pointer to the input
    x_keep_ptr,  # pointer to a mask of 0s and 1s
    output_ptr,  # pointer to the output
    n_elements,  # number of elements in the `x` tensor
    p,  # probability that an element of `x` is changed to zero
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # The line below is the crucial part, described in the paragraph above!
    # (`x_keep != 0` gives a boolean condition, as `tl.where` expects.)
    output = tl.where(x_keep != 0, x / (1 - p), 0.0)
    # Write-back output
    tl.store(output_ptr + offsets, output, mask=mask)


def dropout(x, x_keep, p):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Input tensor
x = torch.randn(size=(10, ), device=DEVICE)
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32)
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist(),
]))
---------  ---------  -------  --------  -------  --------  -------  ---------  --------  --------  -------
input      -0.940469  0.17792  0.529538  0.13197  0.135063  1.64092  -0.309264  0.618883  -1.53066  0.46037
keep mask   0         0        0         0        0         1         0         0          1        1
output      0         0        0         0        0         3.28183   0         0         -3.06132  0.92074
---------  ---------  -------  --------  -------  --------  -------  ---------  --------  --------  -------

Seeded dropout

The above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly we need to store the dropout mask for backpropagation. Secondly, dropout state management can get very tricky when using recompute/checkpointing (e.g. see all the notes about preserve_rng_state in https://pytorch.org/docs/stable/checkpoint.html). In this tutorial we’ll describe an alternative implementation that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management of persisting randomness across multiple invocations of the kernel.
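
To put point (1) in perspective: a mask with the same shape as the input costs one byte per element when stored as torch.bool (four bytes as the int32 mask used above), while the seeded variant only has to remember a single int32. A quick back-of-the-envelope comparison (the activation size below is purely an illustrative assumption):

# Hypothetical activation: a batch of 64 rows with 4096 hidden units (illustrative only).
n_elements = 64 * 4096
mask_bytes_bool = n_elements        # torch.bool uses one byte per element
mask_bytes_int32 = n_elements * 4   # int32 mask, as in the baseline above
seed_bytes = 4                      # a single int32 seed
print(mask_bytes_bool, mask_bytes_int32, seed_bytes)  # 262144 1048576 4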

Pseudo-random number generation in Triton is simple! In this tutorial we will use the triton.language.rand function, which generates a block of uniformly distributed float32 values in [0, 1) given a seed and a block of int32 offsets. Triton also provides other random number generation strategies if you need them.
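
The value produced for a given offset is fully determined by the (seed, offset) pair, which is exactly what lets us re-create a dropout mask on demand. Here is a small sketch of that property; the `_rand_kernel` below is illustrative and not part of the tutorial:

@triton.jit
def _rand_kernel(out_ptr, seed, BLOCK_SIZE: tl.constexpr):
    # Each offset gets a value determined solely by (seed, offset).
    offsets = tl.arange(0, BLOCK_SIZE)
    tl.store(out_ptr + offsets, tl.rand(seed, offsets))


out1 = torch.empty(16, device=DEVICE)
out2 = torch.empty(16, device=DEVICE)
_rand_kernel[(1, )](out1, 42, BLOCK_SIZE=16)
_rand_kernel[(1, )](out2, 42, BLOCK_SIZE=16)
assert torch.equal(out1, out2)  # same seed and offsets -> identical values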

Note

Triton's implementation of PRNG is based on the Philox algorithm (described in [SALMON2011]).

Let's put it all together.

@triton.jit
def _seeded_dropout(
    x_ptr,
    output_ptr,
    n_elements,
    p,
    seed,
    BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10, ), device=DEVICE)
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

print(
    tabulate.tabulate([
        ["input"] + x.tolist(),
        ["output (seed = 123)"] + output.tolist(),
        ["output (seed = 123)"] + output2.tolist(),
        ["output (seed = 512)"] + output3.tolist(),
    ]))
-------------------  -------  ---------  ---------  -------  --------  --------  -------  --------  -------  ---------
input                1.48333  -0.239537  -0.640795  1.62631  0.263036  -0.71516  1.99474  -1.09546  1.81107  -0.170083
output (seed = 123)  0        -0.479074   0         0        0         -1.43032  0         0        3.62215  -0.340165
output (seed = 123)  0        -0.479074   0         0        0         -1.43032  0         0        3.62215  -0.340165
output (seed = 512)  0         0         -1.28159   3.25261  0         -1.43032  3.98947   0        0         0
-------------------  -------  ---------  ---------  -------  --------  --------  -------  --------  -------  ---------

Et voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same! If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you to study python/triton/language/random.py!
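
One place where this property pays off is backpropagation: since the mask can be regenerated from the seed, an autograd wrapper only needs to remember the seed and the dropout probability rather than the mask itself. Below is a minimal sketch of such a wrapper built on top of `seeded_dropout`; the `SeededDropout` class and the way the seed is drawn are illustrative assumptions, not part of the tutorial:

class SeededDropout(torch.autograd.Function):
    # A minimal sketch (not part of the tutorial): only the dropout probability
    # and the seed are remembered between the forward and backward passes.

    @staticmethod
    def forward(ctx, x, p):
        seed = int(torch.randint(0, 2**31 - 1, (1, )).item())  # fresh seed per call
        ctx.p = p
        ctx.seed = seed
        return seeded_dropout(x, p, seed)

    @staticmethod
    def backward(ctx, grad_output):
        # On kept elements the forward pass computes x / (1 - p), so the gradient
        # is grad_output / (1 - p) there and 0 elsewhere -- which is exactly
        # seeded dropout applied to grad_output with the *same* seed.
        return seeded_dropout(grad_output.contiguous(), ctx.p, ctx.seed), None


inp = torch.randn(size=(10, ), device=DEVICE, requires_grad=True)
out = SeededDropout.apply(inp, 0.5)
out.sum().backward()
print(inp.grad)

Because the backward pass is just another launch of the same kernel, nothing besides the two scalars (p, seed) has to persist between the forward and backward passes.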

Exercises

  1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.

  2. Add support for strides.

  3. (challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed.

References

[SALMON2011]
John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011

[SRIVASTAVA2014]
Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014
