# Source code for torch.distributions.kl

import math
import warnings
from functools import total_ordering
from typing import Callable, Dict, Tuple, Type

import torch
from torch import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (
    _batch_lowrank_logdet,
    _batch_lowrank_mahalanobis,
    LowRankMultivariateNormal,
)
from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma

# Source of truth: maps a small number of general (type_p, type_q) pairs to
# their registered KL functions. Entries are added by ``register_kl``.
_KL_REGISTRY: Dict[Tuple[Type, Type], Callable] = {}

# Memoized version of the lookup: maps many specific (type_p, type_q) pairs to
# functions. ``register_kl`` clears it on every registration, because a new
# entry may change which registered function a lookup resolves to.
_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {}

# Public API of this module: the names exported by a star-import.
__all__ = ["register_kl", "kl_divergence"]


def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type, type) match ordered by subclass.
    If the match is ambiguous, a `RuntimeWarning` is raised. For example, to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.

    Returns:
        A decorator that stores the decorated function in the registry under
        the key ``(type_p, type_q)`` and returns the function unchanged.

    Raises:
        TypeError: If ``type_p`` or ``type_q`` is not a class that subclasses
            :class:`~torch.distributions.Distribution`.
    """
    # The whole conjunction must be negated. Written as ``not a and b`` it
    # parses as ``(not a) and b``, which silently accepted non-Distribution
    # classes such as ``int`` and never reached the intended error message.
    if not (isinstance(type_p, type) and issubclass(type_p, Distribution)):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not (isinstance(type_q, type) and issubclass(type_q, Distribution)):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        # Reset memoized lookups: the new entry may change the lookup order,
        # i.e. which registered function is the best match for a type pair.
        _KL_MEMOIZE.clear()
        return fun

    return decorator
@total_ordering
class _Match:
    """Orderable wrapper around a tuple of types.

    ``total_ordering`` derives the remaining rich comparisons from
    ``__eq__`` and ``__le__``. Under this ordering, a tuple whose first
    differing type is a subclass of the other's compares as smaller.
    """

    # Single slot; instances carry no per-instance ``__dict__``.
    __slots__ = ["types"]

    def __init__(self, *types):
        # Positional arguments are kept as a tuple.
        # NOTE(review): presumably (type_p, type_q) pairs; the caller is not visible here.
        self.types = types

    def __eq__(self, other):
        # Equal exactly when the wrapped type tuples are equal.
        return self.types == other.types

    def __le__(self, other):
        # Positions holding the identical type are skipped. At the first
        # position where the types differ, self <= other iff self's type is a
        # subclass of other's. Any later positions are not examined. If no
        # position differs (zip stops at the shorter tuple), the result is True.
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True


# NOTE(review): _dispatch_kl below is truncated in this copy; its remainder is not visible.
def _dispatch_kl(type_p, type_q): """ 假设单继承,找到最具体的近似匹配。 """ matches = [
优云智算