torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks 的源代码
from typing import Any
import torch
from torch.distributed import GradBucket
__all__ = ["noop_hook"]
[docs]def noop_hook(_: Any, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
"""
返回一个包装输入的未来,因此它是一个不产生任何通信开销的空操作。
此钩子应**仅**用于所有reduce优化的余量分析,
而不是正常的梯度同步。
例如,如果在注册此钩子后只能观察到不到10%的训练时间加速,
通常意味着在这种情况下,所有reduce不是性能瓶颈。
如果无法轻松获取GPU跟踪或跟踪分析复杂
某些因素,例如所有reduce与计算之间的重叠或跨等级的不同步,
这种工具可能特别有用。
示例::
>>> # xdoctest: +SKIP
>>> ddp_model.register_comm_hook(None, noop_hook)
"""
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
fut.set_result(bucket.buffer())
return fut