torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook 的源代码
from collections import defaultdict
import logging
import math
from typing import Dict
import torch
import torch.distributed as dist
from . import default_hooks as default
from torch.distributed import distributed_c10d
__all__ = [
"PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"
]
logger = logging.getLogger(__name__)
def _orthogonalize(matrices, epsilon=0):
"""
决定使用Gram-Schmidt还是QR分解来正交化一批矩阵。
QR分解在半精度下不起作用,但通常在秩 > 2 时更快。
"""
assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]
num_matrices = matrices.shape[0]
rank = matrices.shape[2]
dtype = matrices.dtype
if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
_orthogonalize_gram_schmidt(matrices, epsilon=epsilon)
else:
torch.linalg.qr(
matrices,
out=(
matrices,
torch.empty(num_matrices, rank, rank, device=matrices.device, dtype=dtype)
)
)
def _orthogonalize_gram_schmidt(matrices, epsilon=0):
"""
应用Gram-Schmidt过程来正交化一批矩阵。
如果epsilon为0,这等价于`torch.qr(matrices, out=(matrices, _))`,
"""
num_cols = matrices.shape[2]
for i in range(num_cols):
# 归一化第i列。
col = matrices[:, :, i : i + 1]
# 如果没有在这里添加epsilon,可能会因梯度消失而导致除以零。
# 如果输入的矩阵批次覆盖了神经网络中至少一层的梯度,则不需要这个epsilon。
if epsilon == 0:
# 注意,如果使用FP16,col ** 2可能会下溢/上溢。
# 可能需要考虑乘以一个缩放因子并在之后除以它,或者使用bfloat16。
try:
col /= torch.norm(col, dim=1, keepdim=True)
except ZeroDivisionError:
logger.error(
"要正交化的矩阵至少有一列全为0。请在PowerSGD状态中设置一个小的值,例如1e-8作为`orthogonalization_epsilon`。"
)
# 将NaN恢复为0。
col.fill_(0.0)
else:
col /= torch.norm(col, dim=1, keepdim=True) + epsilon
# 将其投影到其余部分并移除。
if i + 1 < num_cols:
rest = matrices[:, :, i + 1 :]
rest -= torch.sum(col * rest, dim=1, keepdim=True) * col
def _should_compress(
num_rows, num_cols, matrix_approximation_rank, min_compression_rate
):
"""
推荐是否值得压缩给定的张量。
返回一个推荐,即是否值得压缩由参数描述的2D张量,
包括描述压缩预期节省的统计信息。我们认为一个张量值得
压缩的条件是 ``min_compression_rate`` < 未压缩大小 / 压缩大小,其中
未压缩大小 = ``num_rows`` * ``num_cols``,
压缩大小 = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``。
此函数的结果是一个元组,形式为 (compression_recommendation, uncompressed_el_count, compressed_el_count),其中:
compression_recommendation 为真,如果张量值得压缩,否则为假(见上文);
uncompressed_el_count 是未压缩的元素计数,即 ``num_rows`` * ``num_cols``;
compress_el_count 是压缩后的元素计数,即 (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``。
""" # noqa: B950
uncompressed_size = num_rows * num_cols
compressed_size = (num_rows + num_cols) * matrix_approximation_rank
return (
compressed_size * min_compression_rate < uncompressed_size,
uncompressed_size,
compressed_size,
)
def _report_compression_stats(bucket, state):
"""在PowerSGD状态中指定的``compression_stats_logging_frequency``频率下报告压缩统计信息。"""
if (
bucket.is_last()
and state</