Shortcuts

视觉交叉注意力掩码

class torchtune.modules.transforms.VisionCrossAttentionMask(tile_size: int, patch_size: int, image_token_id: int)[source]

计算文本+图像输入的交叉注意力掩码。与图像标记参与交叉注意力的文本标记将在掩码中显示为True,并遵循Flamingo论文图7中展示的交错结构(https://arxiv.org/pdf/2204.14198):

  1. 紧跟在图像标记之后的文本标记,直到下一个图像标记

  2. 连续的图像标记关注后续的文本标记

     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
    <img1> <img2>These  are   two  dogs. <img3> This   is    a    cat.

结果掩码是每张图像构建的,形状为 (text_seq_len, image_seq_len), 其中 True 表示从图像编码器输出的标记在交叉注意力中关注文本序列中的标记。这些掩码的列表 返回的长度等于样本中的图像数量。

Parameters:
  • tile_size (int) – 图像变换中图像块的大小

  • patch_size (int) – 每个补丁的大小。用于将瓦片分割成补丁。 例如,对于 patch_size = 40,形状为 (400, 400) 的瓦片将有一个 10x10 的补丁网格, 每个补丁的形状为 (40, 40)。

  • image_token_id (int) – 图像特殊标记的标记ID。