Shortcuts

torch.nn.modules.channelshuffle 的源代码

from .module import Module
from .. import functional as F

from torch import Tensor

__all__ = ['ChannelShuffle']

[docs]class ChannelShuffle(Module): r"""将张量中的通道进行划分和重新排列。 此操作将形状为 :math:`(*, C , H, W)` 的张量中的通道 划分为 g 组,并将其重新排列为 :math:`(*, \frac{C}{g}, g, H, W)`, 同时保持原始张量形状。 参数: groups (int): 划分通道的组数。 示例:: >>> # xdoctest: +IGNORE_WANT("FIXME: incorrect want") >>> channel_shuffle = nn.ChannelShuffle(2) >>> input = torch.randn(1, 4, 2, 2) >>> print(input) [[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]], [[13, 14], [15, 16]], ]] >>> output = channel_shuffle(input) >>> print(output) [[[[1, 2], [3, 4]], [[9, 10], [11, 12]], [[5, 6], [7, 8]], [[13, 14], [15, 16]], ]] """ __constants__ = ['groups'] groups: int def __init__(self, groups: int) -> None: super().__init__() self.groups = groups def forward(self, input: Tensor) -> Tensor: return F.channel_shuffle(input, self.groups) def extra_repr(self) -> str: return f'groups={self.groups}'
优云智算