from .module import Module
from .. import functional as F
from torch import Tensor
from typing import Optional
from ..common_types import _size_2_t, _ratio_2_t, _size_any_t, _ratio_any_t
__all__ = ['Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d']
[docs]class Upsample(Module):
r"""对给定的多通道1D(时间)、2D(空间)或3D(体积)数据进行上采样。
输入数据的形式假设为
`minibatch x channels x [可选深度] x [可选高度] x 宽度`。
因此,对于空间输入,我们期望一个4D张量,对于体积输入,我们期望一个5D张量。
可用的上采样算法包括最近邻和线性,
双线性、双三次和三线性分别用于3D、4D和5D输入张量。
用户可以提供:attr:`scale_factor`或目标输出:attr:`size`来
计算输出大小。(不能同时提供两者,因为这是模棱两可的)
参数:
size (int 或 Tuple[int] 或 Tuple[int, int] 或 Tuple[int, int, int], 可选):
输出空间大小
scale_factor (float 或 Tuple[float] 或 Tuple[float, float] 或 Tuple[float, float, float], 可选):
空间大小的乘数。如果它是一个元组,则必须与输入大小匹配。
mode (str, 可选): 上采样算法: 可以是 ``'nearest'``、
``'linear'``、``'bilinear'``、``'bicubic'`` 和 ``'trilinear'`` 之一。
默认: ``'nearest'``
align_corners (bool, 可选): 如果为 ``True``,输入
和输出张量的角像素对齐,从而保留这些像素的值。
这仅在:attr:`mode`为
``'linear'``、``'bilinear'``、``'bicubic'`` 或 ``'trilinear'`` 时有效。
默认: ``False``
recompute_scale_factor (bool, 可选): 重新计算用于插值计算的scale_factor。如果 `recompute_scale_factor` 为 ``True``,则
`scale_factor` 必须传入,并且 `scale_factor` 用于计算
输出 `size`。计算出的输出 `size` 将用于推断新的插值比例。
请注意,当 `scale_factor` 为浮点数时,由于舍入和精度问题,它可能与重新计算的 `scale_factor` 不同。
如果 `recompute_scale_factor` 为 ``False``,则 `size` 或 `scale_factor` 将
直接用于插值。
形状:
- 输入: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` 或 :math:`(N, C, D_{in}, H_{in}, W_{in})`
- 输出: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`
或 :math:`(N, C, D_{out}, H_{out}, W_{out})`,其中
.. math::
D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor
.. math::
H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
.. math::
W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
.. 警告::
当 ``align_corners = True`` 时,线性插值模式
(`linear`, `bilinear`, `bicubic`, 和 `trilinear`) 不会按比例
对齐输出和输入像素,因此输出值可能会依赖
于输入大小。在0.3.1版本之前,这些模式默认行为是
``align_corners = False``。请参阅下面的具体示例,了解这
如何影响输出。
.. 注意::
如果你想进行下采样/一般调整大小,你应该使用 :func:`~nn.functional.interpolate`。
示例::
>>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
>>> input
tensor([[[[1., 2.],
[3., 4.]]]])
>>> m = nn.Upsample(scale_factor=2, mode='nearest')
>>> m(input)
tensor([[[[1., 1., 2., 2.],
[1., 1., 2., 2.],
[3., 3., 4., 4.],
[3., 3., 4., 4.]]]])
>>> # xdoctest: +IGNORE_WANT("其他测试似乎会修改打印样式")
>>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
>>> m(input)
tensor([[[[1.0000, 1.2500, 1.7500, 2.0000],
[1.5000, 1.7500, 2.2500, 2.5000],
[2.5000, 2.7500, 3.2500, 3.5000],
[3.0000, 3.2500, 3.7500, 4.0000]]]])
>>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
>>> m(input)
tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
[1.6667, 2.0000, 2.3333, 2.6667],
[2.3333, 2.6667, 3.0000, 3.3333],
[3.0000, 3.3333, 3.6667, 4.0000]]]])
>>> # 尝试在更大的张量中缩放相同的数据
>>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3)
>>> input_3x3[:, :, :2, :2].copy_(input)
tensor([[[[1., 2.],
[3., 4.]]]])
>>> input_3x3
tensor([[[[1., 2., 0.],
[3., 4., 0.],
[0., 0., 0.]]]])
>>> # xdoctest: +IGNORE_WANT("似乎在同一会话中运行其他测试时失败")
>>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
>>> # 注意左上角值与小输入相同(除了边界)
>>> m(input_3x3)
tensor([[[[1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000],
[1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000],
[2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000],
[2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000],
[0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
>>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
>>> # 注意左上角值现在已更改
>>> m(input_3x3)
tensor([[[[1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000],
[1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000],
[2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000],
[2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000],
[1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
"""
__constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name', 'recompute_scale_factor']
name: str
size: Optional</span