Shortcuts

torch.nn.modules.flatten 的源代码

from .module import Module

from typing import Tuple, Union
from torch import Tensor
from torch.types import _size

__all__ = ['Flatten', 'Unflatten']

[docs]class Flatten(Module): r""" 将一个连续维度的张量展平。 用于与 :class:`~nn.Sequential` 一起使用,详情请参阅 :meth:`torch.flatten`。 形状: - 输入: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' 其中 :math:`S_{i}` 是维度 :math:`i` 的大小,:math:`*` 表示任意数量的维度,包括无。 - 输出: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`。 参数: start_dim: 开始展平的第一个维度 (默认 = 1)。 end_dim: 结束展平的最后一个维度 (默认 = -1)。 示例:: >>> input = torch.randn(32, 1, 5, 5) >>> # 使用默认参数 >>> m = nn.Flatten() >>> output = m(input) >>> output.size() torch.Size([32, 25]) >>> # 使用非默认参数 >>> m = nn.Flatten(0, 2) >>> output = m(input) >>> output.size() torch.Size([160, 5]) """ __constants__ = ['start_dim', 'end_dim'] start_dim: int end_dim: int def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: super().__init__() self.start_dim = start_dim self.end_dim = end_dim def forward(self, input: Tensor) -> Tensor: return input.flatten(self.start_dim, self.end_dim) def extra_repr(self) -> str: return f'start_dim={self.start_dim}, end_dim={self.end_dim}'
[docs]class Unflatten(Module): r""" 将一个张量维度扩展为所需的形状。用于与 :class:`~nn.Sequential` 一起使用。 * :attr:`dim` 指定要展平的输入张量的维度,它可以是 `int` 或 `str`,分别用于 `Tensor` 或 `NamedTensor`。 * :attr:`unflattened_size` 是张量展平维度的新形状,它可以是 `tuple` 或 `list` 或 `torch.Size` 用于 `Tensor` 输入;对于 `NamedTensor` 输入,它可以是 `NamedShape` (元组的元组)。 形状: - 输入: :math:`(*, S_{\text{dim}}, *)`,其中 :math:`S_{\text{dim}}` 是维度 :attr:`dim` 的大小,:math:`*` 表示任意数量的维度,包括无。 - 输出: :math:`(*, U_1, ..., U_n, *)`,其中 :math:`U` = :attr:`unflattened_size` 且 :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`。 参数: dim (Union[int, str]): 要展平的维度 unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): 展平维度的新形状 示例: >>> input = torch.randn(2, 50) >>> # 使用元组 >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # 使用 torch.Size >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # 使用 namedshape (元组的元组) >>> input = torch.randn(2, 50, names=('N', 'features')) >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5]) """ NamedShape = Tuple[Tuple[str, int]] __constants__ = ['dim', 'unflattened_size'] dim: Union[int, str] unflattened_size: Union[_size, NamedShape] def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None: super().__init__() if isinstance(dim, int): self._require_tuple_int(unflattened_size) elif isinstance(dim, str): self._require_tuple_tuple(unflattened_size) else: raise TypeError("invalid argument type for dim parameter") self.dim = dim self.unflattened_size = unflattened_size def _require_tuple_tuple(self, input): <span class="
优云智算