Shortcuts

消息

消息是torchtune中的一个核心组件,负责管理文本和多模态内容的标记化。它作为所有标记器和数据集API操作的通用接口。消息包含有关文本内容的信息,发送文本内容的角色,以及模型标记器中特殊标记相关的其他信息。有关消息的各个参数的更多信息,请参阅消息的API参考。

创建消息

消息可以通过标准类构造函数创建,也可以直接从字典创建。

from torchtune.data import Message

msg = Message(
    role="user",
    content="Hello world!",
    masked=True,
    eot=True,
    ipython=False,
)
# This is identical
msg = Message.from_dict(
    {
        "role": "user",
        "content": "Hello world!",
        "masked": True,
        "eot": True,
        "ipython": False,
    },
)
print(msg.content)
# [{'type': 'text', 'content': 'Hello world!'}]

内容被格式化为字典列表。这是因为消息还可以包含多模态内容,例如图像。

消息中的图片

对于多模态数据集,您需要将图像作为Image添加到相应的消息中。要将其添加到消息的开头,只需将其前置到内容列表中。

import PIL
from torchtune.data import Message

img_msg = Message(
    role="user",
    content=[
        {
            "type": "image",
            # Place your image here
            "content": PIL.Image.new(mode="RGB", size=(4, 4)),
        },
        {"type": "text", "content": "What's in this image?"},
    ],
)

这将指示模型分词器在何处添加图像特殊标记,并由模型转换器适当处理。

在许多情况下,您将拥有一个图像路径而不是原始的Image。您可以使用load_image()工具来处理本地路径和远程路径。

import PIL
from torchtune.data import Message, load_image

image_path = "path/to/image.jpg"
img_msg = Message(
    role="user",
    content=[
        {
            "type": "image",
            # Place your image here
            "content": load_image(image_path),
        },
        {"type": "text", "content": "What's in this image?"},
    ],
)

如果你的数据集包含图像标签,或占位符文本以指示图像应插入文本中的位置, 你可以使用format_content_with_images()将文本拆分为正确的内容列表, 然后可以将其传递到Message的内容字段中。

import PIL
from torchtune.data import format_content_with_images

content = format_content_with_images(
    "<|image|>hello <|image|>world",
    image_tag="<|image|>",
    images=[PIL.Image.new(mode="RGB", size=(4, 4)), PIL.Image.new(mode="RGB", size=(4, 4))]
)
print(content)
# [
#     {"type": "image", "content": <PIL.Image.Image>},
#     {"type": "text", "content": "hello "},
#     {"type": "image", "content": <PIL.Image.Image>},
#     {"type": "text", "content": "world"}
# ]

消息转换

消息转换是将原始数据格式化为一系列torchtune 消息对象的便捷工具。

from torchtune.data import InputOutputToMessages

sample = {
    "input": "What is your name?",
    "output": "I am an AI assistant, I don't have a name."
}
transform = InputOutputToMessages()
output = transform(sample)
for message in output["messages"]:
    print(message.role, message.text_content)
# user What is your name?
# assistant I am an AI assistant, I don't have a name.

请参阅消息转换以获取更多讨论。

使用提示模板格式化消息

提示模板提供了一种将消息格式化为结构化文本模板的方式。您可以简单地调用任何继承自PromptTemplateInterface的类,并将其应用于消息列表,它将在内容列表中添加适当的文本。

from torchtune.models.mistral import MistralChatTemplate
from torchtune.data import Message

msg = Message(
    role="user",
    content="Hello world!",
    masked=True,
    eot=True,
    ipython=False,
)
template = MistralChatTemplate()
templated_msg = template([msg])
print(templated_msg[0].content)
# [{'type': 'text', 'content': '[INST] '},
# {'type': 'text', 'content': 'Hello world!'},
# {'type': 'text', 'content': ' [/INST] '}]

访问消息中的文本内容

from torchtune.models.mistral import MistralChatTemplate
from torchtune.data import Message

msg = Message(
    role="user",
    content="Hello world!",
    masked=True,
    eot=True,
    ipython=False,
)
template = MistralChatTemplate()
templated_msg = template([msg])
print(templated_msg[0].text_content)
# [INST] Hello world! [/INST]

访问消息中的图像

from torchtune.data import Message
import PIL

msg = Message(
    role="user",
    content=[
        {
            "type": "image",
            # Place your image here
            "content": PIL.Image.new(mode="RGB", size=(4, 4)),
        },
        {"type": "text", "content": "What's in this image?"},
    ],
)
if msg.contains_media:
    print(msg.get_media())
# [<PIL.Image.Image image mode=RGB size=4x4 at 0x7F8D27E72740>]

标记化消息

所有模型的分词器都有一个tokenize_messsages方法,该方法将消息对象列表转换为令牌ID和损失掩码。

from torchtune.models.mistral import mistral_tokenizer
from torchtune.data import Message

m_tokenizer = mistral_tokenizer(
    path="/tmp/Mistral-7B-v0.1/tokenizer.model",
    prompt_template="torchtune.models.mistral.MistralChatTemplate",
    max_seq_len=8192,
)
msgs = [
    Message(
        role="user",
        content="Hello world!",
        masked=True,
        eot=True,
        ipython=False,
    ),
    Message(
        role="assistant",
        content="Hi, I am an AI assistant.",
        masked=False,
        eot=True,
        ipython=False,
    )
]
tokens, mask = m_tokenizer.tokenize_messages(msgs)
print(tokens)
# [1, 733, 16289, 28793, 22557, 1526, 28808, 28705, 733, 28748, 16289, 28793, 15359, 28725, 315, 837, 396, 16107, 13892, 28723, 2]
print(mask)  # User message is masked from the loss
# [True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False]
print(m_tokenizer.decode(tokens))
# [INST] Hello world!  [/INST] Hi, I am an AI assistant.