自动求导机制¶
本笔记将概述autograd的工作原理和记录操作的方式。虽然严格来说理解这些并不是必须的,但我们建议熟悉它,因为它将帮助您编写更高效、更简洁的程序,并有助于调试。
autograd 如何编码历史记录¶
Autograd 是一个反向自动微分系统。从概念上讲,autograd 记录了一个图,记录了在执行操作时创建数据的所有操作,从而生成一个有向无环图,其叶子是输入张量,根是输出张量。通过从根到叶追踪这个图,您可以使用链式法则自动计算梯度。
在内部,autograd 将此图表示为
Function
对象(实际上是表达式)的图,可以
apply()
来计算评估图的结果。在计算前向传播时,autograd 同时执行请求的计算并构建一个图,表示计算梯度的函数(每个 torch.Tensor
的 .grad_fn
属性是此图的入口点)。当前向传播完成时,我们在反向传播中评估此图以计算梯度。
需要注意的一点是,图表在每次迭代时都会从头开始重新创建,而这正是允许使用任意Python控制流语句的原因,这些语句可以在每次迭代时改变图表的整体形状和大小。你不需要在启动训练之前编码所有可能的路径 - 你运行的是什么,你就可以对其进行微分。
保存的张量¶
一些操作需要在正向传播过程中保存中间结果,以便在反向传播过程中执行。例如,函数 保存输入 以计算梯度。
在定义自定义 Python Function
时,您可以使用
save_for_backward()
在正向传播过程中保存张量,并使用
saved_tensors
在反向传播过程中检索它们。有关更多信息,请参阅 扩展 PyTorch。
对于 PyTorch 定义的操作(例如 torch.pow()
),张量会根据需要自动保存。你可以探索(出于教育或调试目的)某个 grad_fn
保存了哪些张量,方法是查找以 _saved
前缀开头的属性。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # 真
print(x is y.grad_fn._saved_self) # 真
在前面的代码中,y.grad_fn._saved_self
指向与 x 相同的张量对象。
但情况并非总是如此。例如:
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # 真
print(y is y.grad_fn._saved_result) # 假
在底层,为了防止引用循环,PyTorch在保存时打包了张量,并在读取时将其解包到不同的张量中。这里,通过访问y.grad_fn._saved_result
得到的张量是一个不同于y
的张量对象(但它们仍然共享相同的存储)。
一个张量是否会被打包到不同的张量对象中,取决于它是否是其自身grad_fn的输出,这是一个可能会改变的实现细节,用户不应依赖于此。
您可以通过保存张量的钩子来控制PyTorch如何进行打包/解包。
非可微函数的梯度¶
使用自动微分进行梯度计算仅在每个基本函数都可微时有效。
不幸的是,我们在实践中使用的许多函数并不具备这一特性(例如,relu
或 sqrt
在 0
处)。
为了尝试减少不可微函数的影响,我们通过按以下顺序应用规则来定义基本操作的梯度:
如果函数是可微的,并且在当前点存在梯度,则使用它。
如果函数是凸的(至少是局部凸的),使用最小范数的次梯度(这是最陡下降方向)。
如果函数是凹的(至少是局部凹的),使用最小范数的超梯度(考虑 -f(x) 并应用前一点)。
如果函数已定义,通过连续性定义当前点的梯度(注意,这里可能出现
inf
,例如对于sqrt(0)
)。如果存在多个可能的值,则任意选择一个。如果函数未定义(例如
sqrt(-1)
、log(-1)
或大多数函数在输入为NaN
时),则用作梯度的值是任意的(我们也可能会引发错误,但这并不保证)。大多数函数将使用NaN
作为梯度,但由于性能原因,某些函数将使用其他值(例如log(-1)
)。如果函数不是一个确定性映射(即它不是一个数学函数),它将被标记为不可微分。这将在反向传播中导致错误,如果在需要梯度的张量上使用它,并且不在
no_grad
环境中。
局部禁用梯度计算¶
Python 提供了几种机制来在本地禁用梯度计算:
要禁用整个代码块的梯度,可以使用上下文管理器,如no-grad模式和推理模式。
对于更细粒度的排除子图从梯度计算中,可以设置张量的requires_grad
字段。
下面,除了讨论上述机制外,我们还描述了评估模式(nn.Module.eval()
),这是一种不用于禁用梯度计算的方法,但由于其名称,常常与上述三种方法混淆。
设置 requires_grad
¶
requires_grad
是一个标志,默认为 false 除非被包裹在 nn.Parameter
中,它允许从梯度计算中细粒度地排除子图。它在正向和反向传播中都起作用:
在前向传播过程中,只有当至少一个输入张量需要梯度时,操作才会被记录在反向图中。在反向传播过程中(.backward()
),只有具有requires_grad=True
的叶子张量才会将梯度累积到它们的.grad
字段中。
需要注意的是,尽管每个张量都有这个标志,但设置它只对叶子张量(没有grad_fn
的张量,例如nn.Module
的参数)有意义。非叶子张量(具有grad_fn
的张量)是与反向传播图相关联的张量。因此,它们的梯度将作为计算需要梯度的叶子张量的中间结果。根据这个定义,很明显所有非叶子张量将自动具有require_grad=True
。
设置 requires_grad
应该是你控制模型中哪些部分参与梯度计算的主要方式,例如,如果你需要在模型微调期间冻结预训练模型的一部分。
要冻结模型的一部分,只需对不希望更新的参数应用 .requires_grad_(False)
。正如上面所述,由于使用这些参数作为输入的计算不会在前向传播中被记录,因此它们在反向传播中不会更新其 .grad
字段,因为它们从一开始就不会成为反向图的一部分,正如所期望的那样。
因为这是一个常见的模式,requires_grad
也可以在模块级别通过 nn.Module.requires_grad_()
来设置。当应用于模块时,.requires_grad_()
会影响模块的所有参数(这些参数默认情况下具有 requires_grad=True
)。
梯度模式¶
除了设置requires_grad
之外,还有三种可以从Python中选择的梯度模式,这些模式可以影响PyTorch中的计算在autograd内部的处理方式:默认模式(梯度模式)、无梯度模式和推理模式,所有这些模式都可以通过上下文管理器和装饰器进行切换。
模式 |
排除操作不被记录在反向图中 |
跳过额外的自动求导跟踪开销 |
启用模式时创建的张量可以在稍后的梯度模式中使用 |
示例 |
---|---|---|---|---|
默认 |
✓ |
前向传播 |
||
无梯度 |
✓ |
✓ |
优化器更新 |
|
推理 |
✓ |
✓ |
数据处理,模型评估 |
默认模式(梯度模式)¶
“默认模式”是指在没有启用其他模式(如no-grad模式和推理模式)时我们隐含所处的模式。与“no-grad模式”相对比,默认模式有时也被称为“grad模式”。
关于默认模式最重要的一点是,这是唯一一种requires_grad
生效的模式。在其他两种模式中,requires_grad
总是被覆盖为False
。
无梯度模式¶
在无梯度模式下的计算行为就像所有输入都不需要梯度一样。
换句话说,即使在有无梯度模式下有输入设置了require_grad=True
,计算也不会被记录在反向图中。
当你需要执行不应被autograd记录的操作,但仍希望稍后在grad模式下使用这些计算的输出时,启用no-grad模式。这个上下文管理器使得无需临时将张量设置为requires_grad=False
,然后再设置回True
,就能方便地为一段代码或函数禁用梯度。
例如,在编写优化器时,无梯度模式可能非常有用:在进行训练更新时,您希望就地更新参数,而不让更新被自动梯度记录。您还打算在下一个前向传递中使用更新后的参数进行梯度模式下的计算。
在torch.nn.init中的实现也依赖于无梯度模式,当初始化参数时,以避免在原地更新初始化参数时自动梯度跟踪。
推理模式¶
推理模式是无梯度模式的极端版本。就像在无梯度模式中一样,推理模式中的计算不会记录在反向图中,但启用推理模式将允许PyTorch进一步加速您的模型。这种更好的运行时有一个缺点:在推理模式中创建的张量在退出推理模式后将无法用于自动梯度记录的计算中。
在执行不需要记录在反向图中的计算时启用推理模式,并且你不打算在任何后续需要自动求导记录的计算中使用推理模式中创建的张量。
建议您在不需要自动梯度跟踪的代码部分(例如,数据处理和模型评估)中尝试使用推理模式。如果它能在您的用例中开箱即用,这将是一个免费的性能提升。如果在启用推理模式后遇到错误,请检查您是否没有在退出推理模式后,在自动梯度记录的计算中使用在推理模式中创建的张量。如果您的用例中无法避免这种情况,您可以随时切换回无梯度模式。
有关推理模式的详细信息,请参阅 推理模式。
有关推理模式的实现细节,请参见 RFC-0011-InferenceMode。
评估模式 (nn.Module.eval()
)¶
评估模式并不是一种在本地禁用梯度计算的机制。 无论如何,它在这里被提及是因为有时会被误认为是这样一种机制。
从功能上讲,module.eval()
(或等效地 module.train(False)
)与无梯度模式和推理模式完全正交。model.eval()
如何影响您的模型完全取决于您的模型中使用的特定模块以及它们是否定义了任何训练模式特定的行为。
如果你的模型依赖于诸如torch.nn.Dropout
和
torch.nn.BatchNorm2d
等模块,这些模块的行为可能会根据训练模式而有所不同,例如,为了避免在验证数据上更新BatchNorm的运行统计数据,你需要负责调用model.eval()
和model.train()
。
建议您在训练时始终使用 model.train()
,在评估模型(验证/测试)时使用 model.eval()
,即使您不确定您的模型是否具有训练模式特定的行为,因为您使用的模块可能会更新为在训练和评估模式下表现不同。
使用autograd的就地操作¶
在 autograd 中支持就地操作是一个困难的问题,我们不鼓励在大多数情况下使用它们。Autograd 的积极缓冲区释放和重用使其非常高效,并且很少有场合通过就地操作显著降低内存使用。除非你在内存压力较大的情况下操作,否则你可能永远不需要使用它们。
限制就地操作适用性的主要原因有两个:
就地操作可能会覆盖计算梯度所需的值。
每个就地操作都需要实现重写计算图。非就地版本只是分配新对象并保持对旧图的引用,而就地操作则需要将所有输入的创建者更改为表示此操作的
Function
。这可能会很棘手,特别是如果有许多张量引用相同的存储(例如,通过索引或转置创建的),并且如果修改后的输入的存储被任何其他Tensor
引用,就地函数将引发错误。
就地正确性检查¶
每个张量都保持一个版本计数器,该计数器在每次操作中被标记为脏时递增。当一个函数保存任何张量以进行反向传播时,它们包含的张量的版本计数器也会被保存。一旦你访问
self.saved_tensors
,它会被检查,如果它大于保存的值,则会引发错误。这确保了如果你在使用就地函数且没有看到任何错误,你可以确定计算的梯度是正确的。
多线程自动微分¶
autograd 引擎负责运行所有必要的反向操作以计算反向传播。本节将描述所有细节,帮助您在多线程环境中充分利用它。(这仅与 PyTorch 1.6+ 相关,因为之前版本的行为有所不同。)
用户可以使用多线程代码(例如Hogwild训练)来训练他们的模型,并且不会阻塞在并发反向计算上,示例代码可以是:
# 定义一个用于不同线程的训练函数
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# 前向传播
y = (x + 3) * (x + 4) * 0.5
# 反向传播
y.sum().backward()
# 潜在的优化器更新
# 用户编写自己的线程代码来驱动train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
请注意用户应了解的一些行为:
CPU上的并发¶
当您在CPU上通过Python或C++ API在多个线程中运行backward()
或grad()
时,您期望看到额外的并发性,而不是在执行过程中按特定顺序序列化所有反向调用(PyTorch 1.6之前的行为)。
非确定性¶
如果你从多个线程并发调用 backward()
并且有共享输入(即 Hogwild CPU 训练),那么应该预期会出现不确定性。
这可能发生是因为参数在线程之间自动共享,因此,多个线程可能访问并尝试在梯度累积期间累积相同的 .grad
属性。这在技术上是不安全的,并且可能会导致竞争条件,结果可能无效。
开发具有共享参数的多线程模型的用户应牢记线程模型,并应理解上述问题。
功能性 API torch.autograd.grad()
可以用于计算梯度,而不是使用 backward()
以避免不确定性。
图保留¶
如果自动求导图的一部分在多个线程之间共享,即首先在单个线程中运行前向传播的一部分,然后在多个线程中运行第二部分,那么图的第一部分是共享的。在这种情况下,不同的线程在同一个图上执行grad()
或backward()
可能会出现问题,即一个线程在运行时破坏了图,而另一个线程在这种情况下会崩溃。自动求导会向用户抛出错误,类似于在没有retain_graph=True
的情况下调用backward()
两次,并让用户知道他们应该使用retain_graph=True
。
Autograd 节点的线程安全¶
由于 Autograd 允许调用线程驱动其反向执行以实现潜在的并行性,因此确保在 CPU 上进行线程安全处理非常重要,特别是在并行调用 backward()
时,这些调用共享 GraphTask 的部分或全部内容。
自定义 Python autograd.Function
由于 GIL 的原因,自动是线程安全的。
对于内置的 C++ Autograd 节点(例如 AccumulateGrad, CopySlices)和自定义
autograd::Function
,Autograd 引擎使用线程互斥锁来确保
可能具有状态写/读的 autograd 节点的线程安全。
C++钩子没有线程安全¶
Autograd 依赖用户编写线程安全的 C++ 钩子。如果你想在多线程环境中正确应用钩子,你需要编写适当的线程锁定代码,以确保钩子是线程安全的。
复数的自动微分¶
简而言之:
当你使用 PyTorch 对具有复数域和/或复数陪域的任何函数 进行微分时, 梯度是在假设该函数是更大实值 损失函数 的一部分的情况下计算的。计算的梯度是 (注意 z 的共轭),其负值正是梯度下降算法中使用的最陡下降方向。因此,所有现有的优化器都可以直接与复数参数一起使用。
此约定与TensorFlow的复杂微分约定相匹配,但与JAX不同(JAX计算 )。
如果你有一个内部使用复数运算的实数到实数函数,这里的约定并不重要:你总是会得到与仅使用实数运算实现时相同的结果。
如果你对数学细节感到好奇,或者想知道如何在PyTorch中定义复杂的导数,请继续阅读。
什么是复数导数?¶
复数可微性的数学定义采用了导数的极限定义,并将其推广到复数运算。考虑一个函数 ,
其中 和 是两个变量实值函数, 并且 是虚数单位。
使用导数定义,我们可以写成:
为了使这个极限存在,不仅必须和必须 是实可微的,而且还必须满足柯西-黎曼方程。 换句话说:为实部和虚部步长()计算的极限 必须相等。这是一个更严格的条件。
复可微函数通常被称为全纯函数。它们表现良好,具有您从实可微函数中看到的所有良好性质,但在优化领域中实际上没有用处。对于优化问题,研究社区仅使用实值目标函数,因为复数不属于任何有序域,因此具有复值损失并没有太大意义。
事实证明,没有任何有趣的实值目标函数满足柯西-黎曼方程。因此,同态函数理论不能用于优化,大多数人因此使用维尔廷微积分。
Wirtinger 微积分开始发挥作用 …¶
所以,我们有一个关于复数可微性和全纯函数的重要理论,但我们根本无法使用它,因为许多常用的函数并不是全纯的。可怜的数学家该怎么办呢?Wirtinger观察到,即使不是全纯的,也可以将其重写为一个二元函数,它总是全纯的。这是因为的实部和虚部可以表示为和的函数:
Wirtinger calculus suggests to study 而不是,如果 是实可微的(另一种思考方式是将其视为坐标系的变化,从 到 。)这个函数有偏导数 和 。我们可以使用链式法则来建立这些偏导数与关于 的实部和虚部的偏导数之间的关系。
从上述方程中,我们得到:
这是你在维基百科上可以找到的Wirtinger微积分的经典定义。
这一变化带来了很多美好的结果。
首先,柯西-黎曼方程简化为表示 (也就是说,函数 可以完全用 表示,而不需要参考 )。
另一个重要的(且有些反直觉的)结果,我们将在后面看到,当我们对实值损失进行优化时,在更新变量时应该采取的步骤由 给出(不是 )。
更多阅读,请查看:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger微积分在优化中有什么用处?¶
音频及其他领域的研究人员通常使用梯度下降法来优化具有复杂变量的实值损失函数。通常,这些人将实部和虚部视为可以更新的独立通道。对于步长 和损失 ,我们可以在 中写出以下方程:
这些方程如何转化为复数空间 ?
非常有趣的事情发生了:Wirtinger 微积分告诉我们,我们可以将上述复变量更新公式简化为仅涉及共轭 Wirtinger 导数 ,这正是我们在优化中所采取的步骤。
因为共轭Wirtinger导数为我们提供了实值损失函数的正确步长,PyTorch在您对具有实值损失的函数进行微分时会提供此导数。
PyTorch如何计算共轭Wirtinger导数?¶
通常,我们的导数公式将 grad_output 作为输入, 表示我们已经计算过的传入的向量-雅可比积,即,,其中 是整个计算的损失(产生实际损失), 是我们函数的输出。这里的目标是计算 ,其中 是函数的输入。事实证明,在实际损失的情况下,我们可以 仅 计算 , 尽管链式法则意味着我们还需要 能够访问 。如果你想 跳过这个推导,请查看本节中的最后一个方程 然后跳到下一节。
让我们继续使用 定义为 。如上所述, autograd 的梯度约定是围绕实值损失函数的优化而设计的,因此让我们假设 是更大 的实值损失函数 的一部分。使用链式法则,我们可以写:
(1)¶
现在使用Wirtinger导数的定义,我们可以写成:
这里需要注意的是,由于 和 是实函数,并且 是实数,根据我们的假设 是实值函数的一部分,我们有:
(2)¶
即, 等于 .
求解上述方程对于 和 ,我们得到:
(3)¶
使用(2),我们得到:
(4)¶
最后一个等式是编写自己的梯度时的重要等式,因为它将我们的导数公式分解为一个更简单的公式,便于手工计算。
如何为复杂函数编写自己的导数公式?¶
上述方框中的公式为我们提供了复函数上所有导数的一般公式。然而,我们仍然需要计算 和 。有两种方法可以做到这一点:
第一种方法是直接使用Wirtinger导数的定义,并计算 和 通过使用 和 (你可以用正常的方式计算)。
第二种方法是使用变量替换技巧,将重写为二元函数,并通过将和视为独立变量来计算共轭Wirtinger导数。这通常更容易;例如,如果所讨论的函数是全纯的,则只会使用(并且将为零)。
让我们以函数 为例,其中 。
使用第一种方法计算Wirtinger导数,我们有。
使用(4),以及grad_output = 1.0(这是在PyTorch中对标量输出调用backward()
时使用的默认梯度输出值),我们得到:
使用第二种方法计算Wirtinger导数,我们直接得到:
再次使用(4),我们得到 。正如你所见,第二种方法涉及较少的计算,并且在更快的计算中更为方便。
保存张量的钩子¶
您可以通过定义一对 pack_hook
/ unpack_hook
钩子来控制 如何打包/解包保存的张量。pack_hook
函数应将其单个张量参数作为输入,但可以返回任何 Python 对象(例如,另一个张量、一个元组,甚至是一个包含文件名的字符串)。unpack_hook
函数将其单个参数作为 pack_hook
的输出,并应返回一个张量以用于反向传播。unpack_hook
返回的张量只需要与传递给 pack_hook
的输入张量具有相同的内容。特别是,任何与自动求导相关的元数据都可以被忽略,因为它们将在解包过程中被覆盖。
这样的一个例子是:
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
请注意,unpack_hook
不应删除临时文件,因为它可能会被多次调用:临时文件应保持活动状态,直到返回的 SelfDeletingTempFile 对象被销毁。在上面的示例中,我们通过在不再需要时关闭它(在删除 SelfDeletingTempFile 对象时)来防止临时文件泄漏。
注意
我们保证pack_hook
只会被调用一次,但unpack_hook
可以根据反向传播的需要被多次调用,并且我们期望它每次都返回相同的数据。
警告
禁止对任何函数的输入执行就地操作,因为它们可能导致意外的副作用。如果对打包钩子的输入进行了就地修改,PyTorch 将抛出错误,但不会捕获对解包钩子输入进行就地修改的情况。
为保存的张量注册钩子¶
您可以通过在一个SavedTensor
对象上调用register_hooks()
方法来注册一对钩子。这些对象作为grad_fn
的属性暴露出来,并以_raw_saved_
前缀开头。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
当配对注册时,pack_hook
方法会被立即调用。
每当需要访问保存的张量时,unpack_hook
方法会被调用,无论是通过 y.grad_fn._saved_self
还是反向传播过程中。
警告
如果你在保存的张量被释放后(即在反向传播调用后)仍然保留对一个SavedTensor
的引用,调用它的register_hooks()
是被禁止的。PyTorch大多数情况下会抛出一个错误,但在某些情况下可能无法做到这一点,可能会导致未定义的行为。
注册保存张量的默认钩子¶
或者,您可以使用上下文管理器
saved_tensors_hooks
注册一对钩子,这些钩子将应用于在该上下文中创建的所有保存的张量。
示例:
# 仅保存磁盘上大小 >= 1000 的张量
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... 计算输出
output = x
return output
model = Model()
net = nn.DataParallel(model)
使用此上下文管理器定义的钩子是线程本地的。 因此,以下代码不会产生预期效果,因为钩子不会通过DataParallel传递。
# 示例:不要这样做
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
请注意,使用这些钩子会禁用所有用于减少Tensor对象创建的优化。例如:
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
没有钩子时,x
、y.grad_fn._saved_self
和
y.grad_fn._saved_other
都指向同一个张量对象。
使用钩子时,PyTorch 会将 x 打包并解包为两个新的张量对象,
这两个对象与原始的 x 共享相同的存储(没有进行复制)。
反向钩子执行¶
本节将讨论不同钩子何时触发或不触发。
然后将讨论它们的触发顺序。
将涵盖的钩子包括:通过
torch.Tensor.register_hook()
注册到Tensor的反向钩子,通过
torch.Tensor.register_post_accumulate_grad_hook()
注册到Tensor的后累加梯度钩子,通过
torch.autograd.graph.Node.register_hook()
注册到Node的后钩子,以及通过
torch.autograd.graph.Node.register_prehook()
注册到Node的前钩子。
特定钩子是否会被触发¶
通过 torch.Tensor.register_hook()
注册到张量的钩子在计算该张量的梯度时执行。(请注意,这并不要求张量的 grad_fn 被执行。例如,如果张量作为 inputs
参数传递给 torch.autograd.grad()
,张量的 grad_fn 可能不会被执行,但注册到该张量的钩子将始终被执行。)
通过 torch.Tensor.register_post_accumulate_grad_hook()
注册到张量的钩子在为该张量累积梯度后执行,这意味着张量的 grad 字段已被设置。而通过 torch.Tensor.register_hook()
注册的钩子在计算梯度时运行,通过 torch.Tensor.register_post_accumulate_grad_hook()
注册的钩子仅在张量的 grad 字段在反向传播结束时由 autograd 更新后触发。因此,post-accumulate-grad 钩子只能为叶子张量注册。通过 torch.Tensor.register_post_accumulate_grad_hook()
在非叶子张量上注册钩子将会出错,即使你调用 backward(retain_graph=True)。
使用
torch.autograd.graph.Node.register_hook()
或
torch.autograd.graph.Node.register_prehook()
注册到 torch.autograd.graph.Node
的钩子只有在注册到的节点被执行时才会触发。
一个特定的节点是否被执行可能取决于反向传播是否通过
torch.autograd.grad()
或 torch.autograd.backward()
调用。
具体来说,当你在一个与张量对应的节点上注册钩子时,你应该注意这些差异,该张量是你传递给 torch.autograd.grad()
或
torch.autograd.backward()
作为 inputs
参数的一部分。
如果你正在使用 torch.autograd.backward()
,所有上述提到的钩子都会被执行,
无论你是否指定了 inputs
参数。这是因为 .backward() 执行所有
节点,即使它们对应于指定为输入的张量。
(请注意,执行与作为 inputs
传递的张量对应的额外节点
通常是不必要的,但仍然会执行。此行为可能会发生变化;
你不应依赖于此。)
另一方面,如果你使用的是torch.autograd.grad()
,注册到与传递给input
的张量相对应的节点的反向钩子可能不会被执行,因为除非有另一个输入依赖于此节点的梯度结果,否则这些节点将不会被执行。
不同钩子触发的顺序¶
事情发生的顺序是:
注册到 Tensor 的钩子会被执行
注册到 Node 的 pre-hooks 会在 Node 执行时执行(如果 Node 被执行)。
对于保留梯度的张量,
.grad
字段会被更新节点被执行(受上述规则约束)
对于已经累积了
.grad
的叶子张量,后累积梯度钩子会被执行注册到节点的后置钩子会在节点执行时执行(如果节点被执行)
如果在同一个Tensor或Node上注册了多个相同类型的钩子,它们将按照注册的顺序执行。 后执行的钩子可以观察到由先前钩子对梯度所做的修改。
特殊钩子¶
torch.autograd.graph.register_multi_grad_hook()
是使用注册到张量的钩子实现的。每个单独的张量钩子按照上面定义的张量钩子顺序触发,并且在计算最后一个张量梯度时调用注册的多梯度钩子。
torch.nn.modules.module.register_module_full_backward_hook()
是使用注册到节点的钩子实现的。在计算前向传播时,钩子被注册到与模块的输入和输出相对应的 grad_fn。因为一个模块可能接受多个输入并返回多个输出,所以在前向传播之前,首先对模块的输入应用一个虚拟的自定义自动求导函数,并在模块的输出之前应用该函数,以确保这些张量共享一个 grad_fn,然后我们可以将钩子附加到该 grad_fn 上。
当张量被原地修改时张量钩子的行为¶
通常,注册到张量的钩子会接收到该张量相对于输出的梯度,其中张量的值被认为是计算反向传播时的值。
然而,如果你将钩子注册到一个张量,然后就地修改该张量,那么在就地修改之前注册的钩子同样会接收到关于该张量的输出的梯度,但该张量的值被视为其就地修改之前的值。
如果你更喜欢前一种情况下的行为, 你应该在对该张量进行所有就地修改之后再注册它们。 例如:
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
此外,了解以下内容可能会有所帮助:在底层,当钩子注册到一个张量时,它们实际上会永久绑定到该张量的grad_fn上,因此如果该张量被就地修改,即使该张量现在有了一个新的grad_fn,之前注册的钩子仍将继续与旧的grad_fn关联,例如,当自动求导引擎在图中到达该张量的旧grad_fn时,这些钩子将会触发。