视觉交叉注意力掩码¶
- class torchtune.modules.transforms.VisionCrossAttentionMask(tile_size: int, patch_size: int, image_token_id: int)[source]¶
计算文本+图像输入的交叉注意力掩码。与图像标记参与交叉注意力的文本标记将在掩码中显示为True,并遵循Flamingo论文图7中展示的交错结构(https://arxiv.org/pdf/2204.14198):
紧跟在图像标记之后的文本标记,直到下一个图像标记
连续的图像标记关注后续的文本标记
┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ img2 │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ img3 │ │ │ │ │ │ │ │ │ │ │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ <img1> <img2>These are two dogs. <img3> This is a cat.结果掩码是每张图像构建的,形状为 (text_seq_len, image_seq_len), 其中 True 表示从图像编码器输出的标记在交叉注意力中关注文本序列中的标记。这些掩码的列表 返回的长度等于样本中的图像数量。