Shortcuts

FusionEmbedding

class torchtune.modules.model_fusion.FusionEmbedding(vocab_size: int, fusion_vocab_size: int, embed_dim: int)[source]

融合嵌入支持在保持原始嵌入冻结的同时训练额外的特殊标记。当将新模型与语言模型融合时,可能需要一些额外的标记来支持融合的语言模型。例如,添加视觉编码器可能需要额外的标记,如<|image|>,以指示图像在文本中的位置,并需要学习该标记的嵌入。FusionEmbedding在保持原始嵌入冻结的同时,为额外的标记学习一个更小的第二嵌入。在前向传播过程中,该模块将标记路由到适当的嵌入表。

在你的模型中使用这个作为torch.nn.Embedding的直接替代品。

示例

>>> embedding = FusionEmbedding(vocab_size=100, fusion_vocab_size=10, embed_dim=128)
>>> model = TransformerDecoder(tok_embeddings=embedding, ...)
>>>
>>> # Original model state_dict still works
>>> model.load_state_dict(..., strict=False)

注意

本模块假设范围 [0, vocab_size) 内的所有标记都是原始嵌入表的一部分,而范围 [vocab_size, vocab_size + fusion_vocab_size) 内的所有新标记

Parameters:
  • vocab_size (int) – 语言模型的词汇表大小

  • fusion_vocab_size (int) – 融合模型的额外标记

  • embed_dim (int) – 两个嵌入表的嵌入维度

forward(input: Tensor) Tensor[source]
Parameters:

输入 (torch.Tensor) – 输入整数张量,形状为 [batch_size x seq_length]

Returns:

输出张量嵌入的形状为

[batch_size x seq_length x embed_dim]`

Return type:

张量

fusion_params() List[str][source]

返回融合嵌入参数。