Shortcuts

Transformer解码器

class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)[源代码]

TransformerDecoder 是由 N 个解码器层组成的堆栈。

Parameters
  • decoder_layer (TransformerDecoderLayer) – TransformerDecoderLayer() 类的一个实例(必需)。

  • num_layers (int) – 解码器中子解码器层的数量(必需)。

  • norm (可选[模块]) – 层归一化组件(可选)。

Examples::
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)[源代码]

依次将输入(和掩码)通过解码器层。

Parameters
  • tgt (张量) – 传递给解码器的序列(必需)。

  • memory (Tensor) – 来自编码器最后一层的序列(必需)。

  • tgt_mask (可选[张量]) – 目标序列的掩码(可选)。

  • memory_mask (可选[张量]) – 用于记忆序列的掩码(可选)。

  • tgt_key_padding_mask (可选[张量]) – 每个批次的目标键的掩码(可选)。

  • memory_key_padding_mask (可选[张量]) – 每个批次内存键的掩码(可选)。

  • tgt_is_causal (可选[布尔值]) – 如果指定,则应用因果掩码作为 tgt mask。 默认值: None; 尝试检测因果掩码。 警告: tgt_is_causal 提供了一个提示,即 tgt_mask 是 因果掩码。提供错误的提示可能会导致 错误的执行,包括向前和向后 兼容性。

  • memory_is_causal (bool) – 如果指定,则应用因果掩码作为 memory mask。 默认值:False。 警告: memory_is_causal 提供了一个提示,即 memory_mask 是因果掩码。提供错误的提示可能会导致错误的执行,包括 向前和向后的兼容性问题。

Return type

张量

Shape:

查看 Transformer 文档。

优云智算