Shortcuts

梯度检查机制

本笔记概述了gradcheck()gradgradcheck()函数的工作原理。

它将涵盖实值和复值函数的前向和反向模式自动微分(AD)以及高阶导数。 本笔记还涵盖了gradcheck的默认行为以及传递fast_mode=True参数的情况(以下称为快速gradcheck)。

符号和背景信息

在本笔记中,我们将使用以下约定:

  1. xx, yy, aa, bb, vv, uu, ururuiui 是实值向量,而 zz 是一个复值向量,可以重写为两个实值向量的形式 z=a+ibz = a + i b

  2. NNMM 是两个整数,我们将分别用于输入和输出空间的维度。

  3. f:RNRMf: \mathcal{R}^N \to \mathcal{R}^M 是我们基本的实数到实数函数,使得 y=f(x)y = f(x)

  4. g:CNRMg: \mathcal{C}^N \to \mathcal{R}^M 是我们的基本复数到实数函数,使得 y=g(z)y = g(z)

对于简单的实数到实数的情况,我们将其写为 JfJ_fff 相关联的雅可比矩阵,大小为 M×NM \times N。 该矩阵包含所有偏导数,使得位置 (i,j)(i, j) 处的条目包含 yixj\frac{\partial y_i}{\partial x_j}。 反向模式自动微分则是对于给定的向量 vv,大小为 MM,计算量 vTJfv^T J_f。 另一方面,前向模式自动微分则是对于给定的向量 uu,大小为 NN,计算量 JfuJ_f u

对于包含复数值的函数,情况要复杂得多。我们在这里只提供概要,完整的描述可以在复数的自动求导中找到。

满足复数可微性(柯西-黎曼方程)的约束对于所有实值损失函数来说过于严格,因此我们选择使用Wirtinger微积分。 在Wirtinger微积分的基础设置中,链式法则需要访问Wirtinger导数(称为WW)和共轭Wirtinger导数(称为CWCW)。 WWCWCW都需要传播,因为在一般情况下,尽管它们的名字如此,但它们并不是彼此的复共轭。

为了避免在反向模式自动微分时传播两个值,我们总是假设正在计算导数的函数要么是一个实值函数,要么是更大实值函数的一部分。这一假设意味着我们在反向传播过程中计算的所有中间梯度也都与实值函数相关联。 在实践中,这一假设在进行优化时并不具有限制性,因为这类问题需要实值目标(因为复数没有自然的顺序)。

在这个假设下,使用 WWCWCW 定义,我们可以证明 W=CWW = CW^*(我们在这里使用 * 表示复共轭),因此实际上只需要其中一个值“通过图向后传播”,因为另一个值可以很容易地恢复。 为了简化内部计算,PyTorch 使用 2CW2 * CW 作为用户请求梯度时它向后传播并返回的值。 与实数情况类似,当输出实际上在 RM\mathcal{R}^M 中时,反向模式自动微分不会计算 2CW2 * CW 而只计算 vT(2CW)v^T (2 * CW) 对于给定的向量 vRMv \in \mathcal{R}^M

对于前向模式自动微分,我们使用类似的逻辑,在这种情况下,假设函数是更大函数的一部分,其输入在 R\mathcal{R} 中。在此假设下,我们可以做出类似的声明,即每个中间结果对应于一个输入在 R\mathcal{R} 中的函数,在这种情况下,使用 WWCWCW 定义,我们可以证明 W=CWW = CW 对于中间函数。 为了确保前向模式和反向模式在单变量函数的基本情况下计算相同的量,前向模式也会计算 2CW2 * CW。 与实际情况类似,当输入实际上在 RN\mathcal{R}^N 中时,前向模式自动微分不会计算 2CW2 * CW 而只计算 (2CW)u(2 * CW) u 对于给定的向量 uRNu \in \mathcal{R}^N

默认反向模式梯度检查行为

实数到实数函数

要测试一个函数 f:RNRM,xyf: \mathcal{R}^N \to \mathcal{R}^M, x \to y,我们通过两种方式重建完整的雅可比矩阵 JfJ_f 的大小 M×NM \times N:解析法和数值法。 解析版本使用我们的反向模式自动微分,而数值版本使用有限差分。 然后,将两个重建的雅可比矩阵逐元素进行比较以验证相等性。

默认实数输入数值评估

如果我们考虑一维函数的基本情况(N=M=1N = M = 1),那么我们可以使用维基百科文章中的基本有限差分公式。我们使用“中心差分”以获得更好的数值特性:

yxf(x+eps)f(xeps)2eps\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}

这个公式很容易推广到多个输出(M>1M \gt 1),通过让yx\frac{\partial y}{\partial x}成为一个大小为M×1M \times 1的列向量,例如f(x+eps)f(x + eps)。 在这种情况下,上述公式可以原封不动地重复使用,并且只需两次用户函数的评估(即f(x+eps)f(x + eps)f(xeps)f(x - eps))即可近似整个雅可比矩阵。

处理多个输入的情况(N>1N \gt 1)在计算上更为昂贵。在这种情况下,我们依次遍历所有输入,并对xx的每个元素依次应用epseps扰动。这使我们能够逐列重建JfJ_f矩阵。

默认实数输入的解析评估

对于分析评估,我们使用上述事实,即反向模式自动微分计算vTJfv^T J_f。 对于具有单个输出的函数,我们简单地使用v=1v = 1在一次反向传递中恢复完整的雅可比矩阵。

对于具有多个输出的函数,我们采用for循环遍历每个输出,其中每个vv是一个对应于每个输出的独热向量。这允许我们逐行重建JfJ_f矩阵。

复数到实数函数

要测试一个函数 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y 其中 z=a+ibz = a + i b,我们重建包含 2CW2 * CW 的(复数值)矩阵。

默认复杂输入数值评估

考虑最简单的情况,其中 N=M=1N = M = 1 首先。我们从(第3章)这篇研究论文中知道:

CW:=yz=12(ya+iyb)CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})

请注意,ya\frac{\partial y}{\partial a}yb\frac{\partial y}{\partial b},在上面的等式中,是 RR\mathcal{R} \to \mathcal{R} 导数。 为了数值评估这些,我们使用上述方法来处理实数到实数的情况。 这使我们能够计算 CWCW 矩阵,然后将其乘以 22

请注意,截至撰写本文时,代码以一种稍微复杂的方式计算此值:

# 代码来自 https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# 此代码块中的符号变化:
# 这里的 s 是上面的 y
# 这里的 x, y 是上面的 a, b

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# 共轭 Wirtinger 导数
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# Wirtinger 导数
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# 由于 grad_out 总是 1,并且 W 和 CW 是彼此的复共轭,最后一行最终计算出 `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`。

默认复杂输入分析评估

由于反向模式AD正好计算了CWCW导数的两次,我们在这里简单地使用了与实数到实数情况相同的技巧,并在有多个实数输出时逐行重建矩阵。

具有复杂输出的函数

在这种情况下,用户提供的函数不符合自动微分(autograd)的假设,即我们为其计算反向自动微分的函数是实值的。这意味着直接对该函数使用自动微分是没有明确定义的。为了解决这个问题,我们将替换函数 h:PNCMh: \mathcal{P}^N \to \mathcal{C}^M(其中 P\mathcal{P} 可以是 R\mathcal{R}C\mathcal{C}),使用两个函数:hrhrhihi,使得:

hr(q):=real(f(q))hi(q):=imag(f(q))\begin{aligned} hr(q) &:= real(f(q)) \\ hi(q) &:= imag(f(q)) \end{aligned}

其中 qPq \in \mathcal{P}。 然后我们对 hrhrhihi 进行基本的梯度检查,使用上述实数到实数或复数到实数的情况,具体取决于 P\mathcal{P}

请注意,截至撰写本文时,代码并未显式创建这些函数,而是通过传递grad_out\text{grad\_out}参数到不同的函数,手动执行链式法则与realrealimagimag函数。 当grad_out=1\text{grad\_out} = 1时,我们考虑的是hrhr。 当grad_out=1j\text{grad\_out} = 1j时,我们考虑的是hihi

快速反向模式梯度检查

虽然上述的gradcheck公式非常出色,既能确保正确性又能便于调试,但由于它需要重建完整的雅可比矩阵,因此速度非常慢。 本节介绍了一种在不牺牲正确性的前提下,以更快速度执行gradcheck的方法。 当检测到错误时,可以通过添加特殊逻辑来恢复调试能力。在这种情况下,我们可以运行默认版本,重建完整的矩阵,以向用户提供详细信息。

这里的高层次策略是找到一个标量值,该值可以通过数值和解析方法高效计算,并且能够很好地代表慢速gradcheck计算的完整矩阵,以确保它能够捕捉到Jacobians中的任何差异。

实数到实数函数的快速梯度检查

我们在这里想要计算的标量是 vTJfuv^T J_f u 对于给定的随机向量 vRMv \in \mathcal{R}^M 和一个随机单位范数向量 uRNu \in \mathcal{R}^N

对于数值评估,我们可以高效地计算

Jfuf(x+ueps)f(xueps)2eps.J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.

然后我们在这个向量和vv之间执行点积,以获得感兴趣的标量值。

对于分析版本,我们可以使用反向模式AD来计算 vTJfv^T J_f 直接计算。然后我们与 uu 进行点积运算以获得期望值。

复杂到实数函数的快速梯度检查

与实数到实数的情况类似,我们希望对整个矩阵进行约简。但 2CW2 * CW 矩阵是复数值的,因此在这种情况中,我们将与复数标量进行比较。

由于在数值情况下我们能够高效计算的某些限制,并且为了将数值评估的次数保持在最低限度,我们计算了以下(尽管令人惊讶的)标量值:

s:=2vT(real(CW)ur+iimag(CW)ui)s := 2 * v^T (real(CW) ur + i * imag(CW) ui)

其中 vRMv \in \mathcal{R}^M, urRNur \in \mathcal{R}^NuiRNui \in \mathcal{R}^N

快速复杂输入数值评估

我们首先考虑如何用数值方法计算 ss。为此,记住我们正在考虑 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,并且 CW=12(ya+iyb)CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}),我们将其重写如下:

s=2vT(real(CW)ur+iimag(CW)ui)=2vT(12yaur+i12ybui)=vT(yaur+iybui)=vT((yaur)+i(ybui))\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}

在这个公式中,我们可以看到 yaur\frac{\partial y}{\partial a} urybui\frac{\partial y}{\partial b} ui 可以像实数到实数情况下的快速版本一样进行评估。 一旦这些实数值被计算出来,我们就可以在右侧重建复数向量,并与实数值的 vv 向量进行点积。

快速复杂输入分析评估

对于分析情况,事情更简单,我们将公式重写为:

s=2vT(real(CW)ur+iimag(CW)ui)=vTreal(2CW)ur+ivTimag(2CW)ui)=real(vT(2CW))ur+iimag(vT(2CW))ui\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}

因此,我们可以利用反向模式自动微分(AD)提供了一种有效的方法来计算 vT(2CW)v^T (2 * CW),然后对实部与 urur 和虚部与 uiui 进行点积运算,最后重建最终的复数标量 ss

为什么不使用复杂的 uu

在这一点上,你可能会想知道为什么我们没有选择一个复杂的 uu 而只是执行了简化 2vTCWu2 * v^T CW u'。 为了深入探讨这一点,在本段中,我们将使用 uu 的复杂版本,记为 u=ur+iuiu' = ur' + i ui'。 使用这种复杂的 uu',问题是在进行数值评估时,我们需要计算:

2CWu=(ya+iyb)(ur+iui)=yaur+iyaui+iyburybui\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}

这将需要对实数到实数的有限差分进行四次评估(是上述方法的两倍)。 由于这种方法没有更多的自由度(实值变量的数量相同),并且我们在这里尝试获得最快的评估速度,因此我们使用上述其他公式。

快速梯度检查对于具有复杂输出的函数

就像在慢速情况下一样,我们考虑两个实值函数,并使用上述适当的规则来处理每个函数。

Gradgradcheck 实现

PyTorch 还提供了一个工具来验证二阶梯度。这里的目标是确保反向传播的实现也是正确可微的,并且计算正确。

此功能通过考虑函数 F:x,vvTJfF: x, v \to v^T J_f 并在此函数上使用上述定义的 gradcheck。 请注意,在这种情况下,vv 只是一个与 f(x)f(x) 相同类型的随机向量。

gradgradcheck 的快速版本是通过对该函数 FF 使用 gradcheck 的快速版本来实现的。