torch.nn.modules.activation 的源代码
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor
from .linear import NonDynamicallyQuantizableLinear
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F
__all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU',
'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink',
'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']
[docs]class Threshold(Module):
r"""对输入张量的每个元素进行阈值处理。
阈值定义为:
.. math::
y =
\begin{cases}
x, &\text{ 如果 } x > \text{threshold} \\
\text{value}, &\text{ 否则 }
\end{cases}
参数:
threshold: 阈值
value: 替换值
inplace: 可以选择就地操作。默认值: ``False``
形状:
- 输入: :math:`(*)`,其中 :math:`*` 表示任意数量的维度。
- 输出: :math:`(*)`,与输入形状相同。
示例::
>>> m = nn.Threshold(0.1, 20)
>>> input = torch.randn(2)
>>> output = m(input)
"""
__constants__ = ['threshold', 'value', 'inplace']
threshold: float
value: float
inplace: bool
def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
super().__init__()
self.threshold = threshold
self.value = value
self.inplace = inplace
# TODO: check in THNN (if inplace == True, then assert value <= threshold)
def forward(self, input: Tensor) -> Tensor:
return F.threshold(input, self.threshold, self.value, self.inplace)
def extra_repr(self):
inplace_str = ', inplace=True' if self.inplace else ''
return f'threshold={self.threshold}, value={self.value}{inplace_str}'
[docs]class ReLU(Module):
r"""逐元素应用修正线性单元函数。
:math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
参数:
inplace: 可以选择就地操作。默认值: ``False``
形状:
- 输入: :math:`(*)`,其中 :math:`*` 表示任意数量的维度。
- 输出: :math:`(*)`,与输入形状相同。
.. image:: ../scripts/activation_images/ReLU.png
示例::
>>> m = nn.ReLU()
>>> input = torch.randn(2)
>>> output = m(input)
CReLU 的实现 - https://arxiv.org/abs/1603.05201
>>> m = nn.ReLU()
>>> input = torch.randn(2).unsqueeze(0)
>>> output = torch.cat((m(input), m(-input)))
"""
__constants__ = ['inplace']
inplace: bool
def __init__(self, inplace: bool = False):
super().__init__()
self.inplace = inplace
def forward(self, input: Tensor) -> Tensor:
return F.relu(input, inplace=self.inplace)
def extra_repr(self) -> str:
inplace_str = 'inplace=True' if self.inplace else ''
return inplace_str
[docs]class RReLU(Module):
r"""逐元素应用随机泄漏修正线性单元函数。
该方法在论文中描述:</span