• Docs >
  • torch.utils.checkpoint
Shortcuts

torch.utils.checkpoint

注意

检查点机制通过在反向传播期间为每个检查点段重新运行前向传递段来实现。这可能导致像RNG状态这样的持久状态比没有检查点时更先进。默认情况下,检查点包括处理RNG状态的逻辑,使得使用RNG的检查点段(例如通过dropout)与非检查点段相比具有确定性的输出。存储和恢复RNG状态的逻辑可能会根据检查点操作的运行时间产生适度的性能影响。如果不需要与非检查点段相比的确定性输出,请提供preserve_rng_state=Falsecheckpointcheckpoint_sequential,以在每次检查点期间省略存储和恢复RNG状态。

存储逻辑保存并恢复了CPU和另一种设备类型(通过_infer_device_type推断设备类型,排除CPU张量)的RNG状态到run_fn。如果有多个设备,设备状态只会为单一设备类型的设备保存,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能会导致不正确的梯度。(请注意,如果检测到CUDA设备,它将被优先考虑;否则,将选择遇到的第一个设备。)如果没有CPU张量,默认设备类型状态(默认值为cuda,可以通过DefaultDeviceType设置为其他设备)将被保存和恢复。然而,逻辑无法预见用户是否会在run_fn内部将张量移动到新设备。因此,如果您在run_fn内部将张量移动到新设备(“新”意味着不属于[当前设备+张量参数设备]的集合),则无法保证与非检查点传递相比的确定性输出。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[源代码]

检查点一个模型或模型的一部分。

激活检查点是一种以计算换取内存的技术。 与在反向传播期间将用于反向传播的张量保持活动状态直到它们在梯度计算中使用不同,检查点区域中的前向计算在反向传播期间省略保存张量并重新计算它们。激活检查点可以应用于模型的任何部分。

目前有两种检查点实现方式可用,由use_reentrant参数决定。建议您使用use_reentrant=False。请参阅下面的注释以了解它们的区别。

警告

如果在反向传播过程中,function 的调用与前向传播不同,例如由于全局变量,检查点版本可能不等价,可能会导致错误被引发或导致梯度计算不正确。

警告

use_reentrant 参数应明确传递。在版本 2.4 中,如果我们没有传递 use_reentrant,将会引发异常。 如果您使用的是 use_reentrant=True 变体,请参阅下面的注释,了解重要的注意事项和潜在的限制。

注意

可重入的检查点变体(use_reentrant=True)和 非可重入的检查点变体(use_reentrant=False) 在以下方面有所不同:

  • 非重入式检查点在所有需要的中间激活值重新计算完成后立即停止重新计算。此功能默认启用,但可以通过set_checkpoint_early_stop()禁用。重入式检查点在反向传播过程中总是重新计算function的全部内容。

  • 可重入变体在正向传递期间不会记录自动梯度图,因为它在torch.no_grad()下运行正向传递。非可重入版本会记录自动梯度图,允许在检查点区域内对图进行反向传播。

  • 可重入检查点仅支持不带inputs参数的反向传播的 torch.autograd.backward() API,而不可重入版本支持所有执行反向传播的方式。

  • 至少一个输入和输出必须有 requires_grad=True 才能使用可重入变体。如果未满足此条件,模型的检查点部分将不会具有梯度。非可重入版本没有此要求。

  • 可重入版本不认为嵌套结构中的张量(例如,自定义对象、列表、字典等)参与自动求导,而不可重入版本则认为它们参与。

  • 可重入检查点不支持包含从计算图中分离的张量的检查点区域,而不可重入版本则支持。对于可重入变体,如果检查点段包含使用 detach()torch.no_grad() 分离的张量,反向传播将引发错误。这是因为 checkpoint 使所有输出都需要梯度,当张量在模型中被定义为不需要梯度时,这会导致问题。为避免这种情况,请在 checkpoint 函数外部分离张量。

Parameters
  • 函数 – 描述在模型的前向传递或模型的一部分中要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户传递(activation, hidden)函数应该正确地使用第一个输入作为activation,第二个输入作为hidden

  • preserve_rng_state (bool, 可选) – 在每次检查点期间省略保存和恢复RNG状态。请注意,在torch.compile下,此标志不起作用,我们始终保留RNG状态。 默认值: True

  • use_reentrant (bool) – 指定是否使用需要可重入自动微分的激活检查点变体。此参数应显式传递。在版本2.4中,如果我们没有传递use_reentrant,将会引发异常。如果use_reentrant=Falsecheckpoint将使用不需要可重入自动微分的实现。这允许checkpoint支持额外的功能,例如按预期与torch.autograd.grad一起工作,并支持将关键字参数输入到检查点函数中。

  • context_fn (可调用对象, 可选) – 一个返回两个上下文管理器元组的可调用对象。函数及其重新计算将分别在第一个和第二个上下文管理器下运行。此参数仅在 use_reentrant=False 时支持。

  • determinism_check (str, 可选) – 一个字符串,指定要执行的确定性检查。默认情况下,它设置为 "default", 该设置会比较重新计算的张量与保存的张量的形状、数据类型和设备。要关闭此检查,请指定 "none"。目前仅支持这两个值。 如果您希望看到更多的确定性检查,请提出问题。此参数仅在 use_reentrant=False 时支持, 如果 use_reentrant=True,则始终禁用确定性检查。

  • 调试 (布尔值, 可选) – 如果 True,错误消息还将包括 在原始前向计算期间运行的操作符的跟踪 以及重新计算。此参数仅在 use_reentrant=False 时支持。

  • args – 包含输入到 function 的元组

Returns

运行 function*args 上的输出

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[源代码]

检查点一个顺序模型以节省内存。

顺序模型按顺序执行一系列模块/函数。因此,我们可以将这样的模型分成多个段,并对每个段进行检查点操作。除了最后一个段之外,所有段都不会存储中间激活值。每个检查点段的输入将被保存,以便在反向传播过程中重新运行该段。

警告

应显式传递use_reentrant参数。在2.4版本中,如果我们没有传递use_reentrant,将会引发异常。 如果您使用的是use_reentrant=True` 变体, 请参阅 :func:`~torch.utils.checkpoint.checkpoint` 以了解 此变体的重要注意事项和限制。 建议您使用 ``use_reentrant=False

Parameters
  • 函数 – 一个 torch.nn.Sequential 或模块或函数的列表(构成模型)按顺序运行。

  • segments – 模型中要创建的块数

  • 输入 – 输入到 函数 的张量

  • preserve_rng_state (bool, 可选) – 在每个检查点期间省略保存和恢复RNG状态。 默认值: True

  • use_reentrant (bool) – 指定是否使用需要可重入自动微分的激活检查点变体。此参数应显式传递。在版本2.4中,如果我们没有传递use_reentrant,我们将引发异常。如果use_reentrant=Falsecheckpoint将使用不需要可重入自动微分的实现。这允许checkpoint支持额外的功能,例如按预期与torch.autograd.grad一起工作,并支持将关键字参数输入到检查点函数中。

Returns

按顺序运行 functions*inputs 上的输出

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[源代码]

上下文管理器,用于设置在运行时是否应打印额外的调试信息。有关更多信息,请参阅 checkpoint()debug 标志。请注意,当设置此上下文管理器时,它会覆盖传递给 checkpoint 的 debug 值。要推迟到本地设置,请将 None 传递给此上下文。

Parameters

enabled (bool) – 是否应打印调试信息。 默认值为‘None’。