Shortcuts

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook 的源代码

from collections import defaultdict
import logging
import math
from typing import Dict

import torch
import torch.distributed as dist

from . import default_hooks as default
from torch.distributed import distributed_c10d

__all__ = [
    "PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"
]

logger = logging.getLogger(__name__)


def _orthogonalize(matrices, epsilon=0):
    """
    决定使用Gram-Schmidt还是QR分解来正交化一批矩阵。

    QR分解在半精度下不起作用,但通常在秩 > 2 时更快。
    """
    assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]

    num_matrices = matrices.shape[0]
    rank = matrices.shape[2]
    dtype = matrices.dtype
    if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
        _orthogonalize_gram_schmidt(matrices, epsilon=epsilon)
    else:
        torch.linalg.qr(
            matrices,
            out=(
                matrices,
                torch.empty(num_matrices, rank, rank, device=matrices.device, dtype=dtype)
            )
        )

def _orthogonalize_gram_schmidt(matrices, epsilon=0):
    """
    应用Gram-Schmidt过程来正交化一批矩阵。

    如果epsilon为0,这等价于`torch.qr(matrices, out=(matrices, _))`,
    """
    num_cols = matrices.shape[2]
    for i in range(num_cols):
        # 归一化第i列。
        col = matrices[:, :, i : i + 1]
        # 如果没有在这里添加epsilon,可能会因梯度消失而导致除以零。
        # 如果输入的矩阵批次覆盖了神经网络中至少一层的梯度,则不需要这个epsilon。
        if epsilon == 0:
            # 注意,如果使用FP16,col ** 2可能会下溢/上溢。
            # 可能需要考虑乘以一个缩放因子并在之后除以它,或者使用bfloat16。
            try:
                col /= torch.norm(col, dim=1, keepdim=True)
            except ZeroDivisionError:
                logger.error(
                    "要正交化的矩阵至少有一列全为0。请在PowerSGD状态中设置一个小的值,例如1e-8作为`orthogonalization_epsilon`。"
                )
                # 将NaN恢复为0。
                col.fill_(0.0)
        else:
            col /= torch.norm(col, dim=1, keepdim=True) + epsilon
        # 将其投影到其余部分并移除。
        if i + 1 < num_cols:
            rest = matrices[:, :, i + 1 :]
            rest -= torch.sum(col * rest, dim=1, keepdim=True) * col


def _should_compress(
    num_rows, num_cols, matrix_approximation_rank, min_compression_rate
):
    """
    推荐是否值得压缩给定的张量。

    返回一个推荐,即是否值得压缩由参数描述的2D张量,
    包括描述压缩预期节省的统计信息。我们认为一个张量值得
    压缩的条件是 ``min_compression_rate`` < 未压缩大小 / 压缩大小,其中
    未压缩大小 = ``num_rows`` * ``num_cols``,
    压缩大小 = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``。

    此函数的结果是一个元组,形式为 (compression_recommendation, uncompressed_el_count, compressed_el_count),其中:

    compression_recommendation 为真,如果张量值得压缩,否则为假(见上文);

    uncompressed_el_count 是未压缩的元素计数,即 ``num_rows`` * ``num_cols``;

    compress_el_count 是压缩后的元素计数,即 (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``。
    """  # noqa: B950
    uncompressed_size = num_rows * num_cols
    compressed_size = (num_rows + num_cols) * matrix_approximation_rank
    return (
        compressed_size * min_compression_rate < uncompressed_size,
        uncompressed_size,
        compressed_size,
    )


def _report_compression_stats(bucket, state):
    """在PowerSGD状态中指定的``compression_stats_logging_frequency``频率下报告压缩统计信息。"""
    if (
        bucket.is_last()
        and state</