跳转到内容

算法微分背景

正如每个计算机程序都是由一系列简单的算术运算组成的,即

\[ a \rightarrow b \rightarrow \ldots \rightarrow u \rightarrow v \rightarrow \ldots \rightarrow z \]

其中输入\(a\)会分阶段修改以获得最终输出\(z\)。当每个运算的单独导数已知时,可以通过链式法则的递归应用来计算最终导数。这种方法被称为算法微分,包含两种模式:前向(或正切线性)模式,即从输入到输出;以及伴随(或反向)模式,即从输出到输入。

本节将介绍计算计算机程序导数的理论基础。我们先回顾传统的有限差分法(通常称为扰动法),然后介绍前向和伴随算法微分。

有限差分法

计算这些导数的传统方法是采用有限差分近似。也就是说,依次对每个输入变量进行微调,然后利用结果的变化来估计敏感度:

\[ \begin{align} \frac{\partial f(x, \pmb{y})}{\partial x} &= \lim_{h\rightarrow 0}\frac{f(x+h, \pmb{y}) - f(x,\pmb{y})}{h} \\ \frac{\partial f(x, \pmb{y})}{\partial x} &= \lim_{h\rightarrow 0} \frac{f(x+h, \pmb{y}) - f(x-h, \pmb{y})}{2h} \end{align} \]

其中\(f(x, \pmb{y})\)是我们感兴趣的对输入参数\(x\)求导的函数。向量值参数\(\pmb{y}\)表示其余函数参数。第一个方程代表前向有限差分,需要两次函数求值。第二个方程给出中心有限差分,可能具有更高精度,计算导数需要两次函数求值,另需一次求值来获得函数值。

在实践中,选择\(h\)的值需要足够小以逼近理论极限,但又要足够大以使结果产生超出典型数值误差水平的可检测变化。显然,这一选择会影响近似计算的准确性。

此外,这种方法意味着函数需要被评估一次以获取结果,并对我们感兴趣的每个导数再进行一次评估。当需要计算多个导数时,这会导致整体计算复杂度显著增加。

因此,有限差分方法存在精度和性能上的局限性。

前向模式

理论

前向模式将 \(\dot{u}\) 定义为 \(u\) 关于 \(a\) 的导数,即

\[ \dot{u} = \frac{\partial u}{\partial a} \]

应用微分的链式法则并假设中间变量是向量,\(\dot{v}\)的元素可以计算为

\[ \dot{v}_i = \sum_j \frac{\partial v_i}{\partial u_{j}} \dot{u}_j \]

将这个原理应用到从输入到输出的每个操作步骤链中,就可以计算出\(\dot{z}\)的值。这就是算法微分的前向模式

对于一个函数\(f,{:},\mathbb{R}^n,{\rightarrow},\mathbb{R}^m\),前向模式的一次应用可以得到所有\(m\)个输出相对于一个输入参数的敏感度。需要重新计算\(n\)次才能获得所有敏感度。计算成本在输出变量数量\(m\)上是恒定的,在输入变量数量\(n\)上是线性的。

示例

我们通过示例函数展示前向模式:

\[ z = \sin x_1 + x_1 x_2 \]

这可以在计算机程序中实现为:

a = sin(x1);
b = x1 * x2;
z = a + b;

我们关注的是在输入值\(x_1 = \pi\)\(x_2 = 2\)时关于\(x_1\)的导数。下图展示了前向模式算法微分如何应用于这个问题:

Forward mode example

左侧我们看到代表方程的计算图,右侧表格展示了执行步骤。

在步骤0中,我们初始化输入值并设定这些输入的导数种子。由于我们关注的是关于\(x_1\)的导数,因此将其导数设为1,其他设为0。

接下来我们通过正弦函数计算\(a\)。当\(a\)的值为零时,\(\dot{a}\)的计算方法是将正弦函数对\(x_1\)的偏导数(即余弦函数)与\(\dot{x_1}\)相乘。这样得到的值为-1。

在下一步中,\(b\)的值照常计算,而\(\dot{b}\)的计算方式与\(\dot{a}\)类似,但这次同时依赖于\(\dot{x_1}\)\(\dot{x_2}\)。最终得到的值为2。

最终语句将\(a\)\(b\)相加,得到结果为\(2\pi\)。为了计算\(\dot{z}\),我们可以看到\(\dot{a}\)\(\dot{b}\)可以直接相加,因为它们的偏导数都是1。最终得到的导数为1。

因此:

\[ \left.\frac{\partial z}{\partial x_1}\right|_{(\pi,2)} = 1 \]

这一点可以通过解析方法轻松验证。

伴随模式

理论

伴随模式以反向方式应用链式法则,从输出到输入。使用标准符号,我们定义

\[ \bar{u}_i = \frac{\partial z}{\partial u_i} \]

其中 \(i\) 是向量 \(\pmb{u}\) 中的索引。应用链式法则可得

\[ \frac{\partial z}{\partial u_i} = \sum_j \frac{\partial z}{\partial v_j} \frac{\partial v_j}{\partial u_i} \]

这导致了伴随模式方程

\[ \bar{u}_i = \sum_j \frac{\partial v_j}{\partial u_i} \bar{v}_j \]

设定种子值\(\bar{z} = 1\)后,可以从输出到输入逐步应用伴随模式方程,得到\(\bar{\pmb{a}}\),即输出\(z\)相对于每个输入变量\(\pmb{a}\)的导数。

对于一个函数\(f,{:},\mathbb{R}^n,{\rightarrow},\mathbb{R}^m\),伴随模式给出了一个输出相对于所有\(n\)个输入参数的敏感度。需要重新评估\(m\)次才能获得所有敏感度。计算成本在输入变量数量\(n\)上是恒定的,在输出变量数量\(m\)上是线性的。

示例

我们使用与上述相同的例子来说明伴随模式:

\[ z = \sin x_1 + x_1 x_2 \]

实现方式:

a = sin(x1);
b = x1 * x2;
z = a + b;

通过伴随模式,我们可以在单次执行中同时获取输出的两个偏导数,输入值为\(x_1 = \pi\)\(x_2 = 2\)。如下图所示:

Adjoint mode example

当伴随模式从输出回溯到输入时,我们照常执行完整的数值计算,直到得到\(z\)的输出值为\(2\pi\)

然后在最后一步我们将\(z\)的伴随设为1,并反向计算输入的伴随值。

在步骤2中,我们可以通过将\(z\)的伴随乘以\(z\)关于\(b\)的偏导数(即1)来计算\(b\)的伴随。

同样的操作在步骤1中执行,用于计算\(a\)的伴随,结果同样得到1。

然后通过将\(b\)关于\(x_2\)的偏导数与\(b\)的伴随相乘来计算\(x_2\)的伴随,得到值\(\pi\)

同样的方法被用来计算\(x_1\)的伴随值,得到结果为1。

因此,我们感兴趣的两个导数是:

\[ \begin{align} \left.\frac{\partial z}{\partial x_1}\right|_{(\pi,2)} &= 1 &,&& \left.\frac{\partial z}{\partial x_2}\right|_{(\pi,2)} &= \pi \end{align} \]

这可以通过解析方法轻松验证。

高阶导数

高阶导数可以通过嵌套上述原理获得。例如,在伴随模式上应用前向模式算法微分可以得到二阶导数。这种方法可以扩展到任意阶数。