torch.nn.modules.adaptive 的源代码
from collections import namedtuple
import torch
from torch import Tensor
from typing import List, Sequence
from . import Sequential, ModuleList, Linear
from .module import Module
from ..functional import log_softmax
__all__ = ['AdaptiveLogSoftmaxWithLoss']
_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])
[docs]class AdaptiveLogSoftmaxWithLoss(Module):
r"""高效的softmax近似。
如描述中所述
`Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
Moustapha Cissé, David Grangier, and Hervé Jégou
`__。
自适应softmax是一种用于训练具有大输出空间的模型的近似策略。当标签分布高度不平衡时,它最为有效,例如在自然语言建模中,单词频率分布近似遵循`Zipf's law`_。
自适应softmax根据标签的频率将标签划分为几个簇。这些簇可能包含不同数量的目标。
此外,包含较少频率标签的簇为这些标签分配较低维度的嵌入,从而加快了计算速度。
对于每个小批量,仅评估至少有一个目标存在的簇。
其思想是,频繁访问的簇(如包含最频繁标签的第一个簇)也应该计算成本低廉——即包含少量分配的标签。
我们强烈建议查看原始论文以获取更多详细信息。
* :attr:`cutoffs` 应该是一个按递增顺序排序的整数序列。
它控制簇的数量以及将目标分配到簇的方式。例如,设置 ``cutoffs = [10, 100, 1000]``
意味着前`10`个目标将被分配到自适应softmax的“头部”,目标`11, 12, ..., 100`将被分配到第一个簇,目标`101, 102, ..., 1000`将被分配到第二个簇,而目标
`1001, 1002, ..., n_classes - 1` 将被分配到最后一个,第三个簇。
* :attr:`div_value` 用于计算每个附加簇的大小,其计算公式为
:math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
其中 :math:`idx` 是簇的索引(较少频率单词的簇具有较大的索引,索引从 :math:`1` 开始)。
* :attr:`head_bias` 如果设置为True,则在自适应softmax的“头部”添加一个偏置项。有关详细信息,请参阅论文。在官方实现中设置为False。
.. 警告::
传递给此模块的标签应根据其频率进行排序。这意味着最频繁的标签应由索引`0`表示,而最不频繁的标签应由索引`n_classes - 1`表示。
.. 注意::
此模块返回一个``NamedTuple``,包含``output``和``loss``字段。有关详细信息,请参阅进一步的文档。
.. 注意::
要计算所有类别的对数概率,可以使用``log_prob``方法。
参数:
in_features (int): 输入张量中的特征数量
n_classes (int): 数据集中的类别数量
cutoffs (Sequence): 用于将目标分配到其桶的截止值
div_value (float, optional): 用于计算簇大小的指数值。默认值: 4.0
head_bias (bool, optional): 如果为``True``,则在自适应softmax的“头部”添加一个偏置项。默认值: ``False``
返回:
``NamedTuple``,包含``output``和``loss``字段:
* **output** 是一个大小为``N``的张量,包含每个示例计算的目标对数概率
* **loss** 是一个标量,表示计算的负对数似然损失
形状:
- input: :math:`(N, \texttt{in\_features})` 或 :math:`(\texttt{in\_features})`
- target: :math:`(N)` 或 :math:`()`,其中每个值满足 :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes}`
- output1: :math:`(N)` 或 :math:`()`
- output2: ``Scalar``
.. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
"""
in_features: int
n_classes: int
cutoffs: List[int]
div_value: float
head_bias: bool
head: Linear
tail: ModuleList
def __init__(
self,
in_features: int,
n_classes: int,
cutoffs: Sequence[int],
div_value: float = 4.,
head_bias: bool = False,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
cutoffs = list(cutoffs)
if (len(cutoffs) == 0):
raise ValueError("cutoffs should be a sequence of length larger than 0")
if (cutoffs != sorted(cutoffs)) \
or (min(cutoffs) <= 0) \
or (max(cutoffs) > (n_classes - 1)) \
or (len(set(cutoffs)) != len(cutoffs)) \
or any(int(c) != c for c in cutoffs):
raise ValueError("cutoffs should be a sequence of unique, positive "
"integers sorted in an increasing order, where "
"each value is between 1 and n_classes-1")
self.in_features = in_features
self.n_classes = n_classes
self.cutoffs = cutoffs + [n_classes]
self.div_value = div_value
self.head_bias = head_bias
self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1