使用Lance创建多模态数据集

得益于Lance文件格式能够存储不同模态数据的能力,LanceDB在存储多模态数据集方面表现出色的重要应用场景之一。在这个简短的示例中,我们将介绍如何将多模态数据集存储为Lance文件格式。

这里选择的数据集是Flickr8k数据集。Flickr8k是一个基于句子的图像描述和搜索的基准集合,包含8,000张图片,每张图片都配有五个不同的标题,这些标题清晰地描述了突出的实体和事件。 这些图片选自六个不同的Flickr群组,通常不包含任何知名人物或地点,而是经过人工筛选以展现各种场景和情境。

我们将利用上述Flickr8k数据集创建一个图像-标题配对数据集,用于多模态模型训练,并以Lance数据集格式保存,包含图像文件名、每张图像的所有标题(保持顺序)以及图像本身(二进制格式)。

导入与设置

我们假设您已下载数据集,具体来说是"Flickr8k.token.txt"文件和"Flicker8k_Dataset/"文件夹,且两者都位于当前目录中。 这些文件可以从这里下载(请同时下载数据集和文本压缩文件)。

我们还假设您已安装pyarrow和pylance,以及opencv(用于读取图像)和tqdm(用于美观的进度条)。

现在让我们从导入和定义字幕文件及图像数据集文件夹开始。

import os
import cv2
import random

import lance
import pyarrow as pa

import matplotlib.pyplot as plt

from tqdm.auto import tqdm

captions = "Flickr8k.token.txt"
image_folder = "Flicker8k_Dataset/"

加载与处理

在flickr8k数据集中,每张图片都有多个对应的有序描述文本。 我们将把这些描述文本放入一个列表中,每个列表项对应一张图片,列表中的位置代表它们原始出现的顺序。 让我们将标注数据(图片路径和对应的描述文本)加载到一个列表中,每个列表元素是一个由图片名称、描述编号和描述文本本身组成的元组。

with open(captions, "r") as fl:
    annotations = fl.readlines()

# Converts the annotations where each element of this list is a tuple consisting of image file name, caption number and caption itself
annotations = list(map(lambda x: tuple([*x.split('\t')[0].split('#'), x.split('\t')[1]]), annotations))

现在,对于同一张图片的所有标题,我们将按顺序将它们排序并放入一个列表中。

captions = []
image_ids = set(ann[0] for ann in annotations)
for img_id in tqdm(image_ids):
    current_img_captions = []
    for ann_img_id, num, caption in annotations:
        if img_id == ann_img_id:
            current_img_captions.append((num, caption))

    # Sort by the annotation number
    current_img_captions.sort(key=lambda x: x[0])
    captions.append((img_id, tuple([x[1] for x in current_img_captions])))

转换为Lance数据集

现在我们的字幕列表已采用适当格式,我们将编写一个process()函数,该函数将接收上述字幕作为参数,并生成一个由image_id()image()captions()组成的Pyarrow记录批次。此记录批次中的图像将以二进制格式存储,而图像的所有字幕将保留其顺序存储在列表中。

def process(captions):
    for img_id, img_captions in tqdm(captions):
        try:
            with open(os.path.join(image_folder, img_id), 'rb') as im:
                binary_im = im.read()

        except FileNotFoundError:
            print(f"img_id '{img_id}' not found in the folder, skipping.")
            continue

        img_id = pa.array([img_id], type=pa.string())
        img = pa.array([binary_im], type=pa.binary())
        capt = pa.array([img_captions], pa.list_(pa.string(), -1))

        yield pa.RecordBatch.from_arrays(
            [img_id, img, capt],
            ["image_id", "image", "captions"]
        )

我们还定义相同的模式,以告知Pyarrow表中应预期的数据类型。

schema = pa.schema([
    pa.field("image_id", pa.string()),
    pa.field("image", pa.binary()),
    pa.field("captions", pa.list_(pa.string(), -1)),
])

我们包含了image_id()(即原始图像名称),以便将来更容易引用和调试。

最后,我们定义一个读取器来迭代读取这些记录批次,然后将它们写入磁盘上的lance数据集。

reader = pa.RecordBatchReader.from_batches(schema, process(captions))
lance.write_dataset(reader, "flickr8k.lance", schema)

基本上就是这样!如果你想以笔记本形式执行此操作,可以在我们的深度学习示例库中查看这个例子这里

如需更多使用Lance数据集的深度学习相关示例,请务必查看lance-deeplearning-recipes代码库!