使用Lance为LLM训练创建文本数据集¶
Lance可用于创建和缓存文本(或代码)数据集,用于大型语言模型的预训练/微调。当需要在数据子集上训练模型或分块处理数据而无需一次性将所有数据下载到磁盘时,就会出现这种需求。当您只需要处理TB或PB级数据集的子集时,这就成为一个相当棘手的问题。
在这个示例中,我们将通过分块下载文本数据集、进行分词处理并保存为Lance数据集来绕过这个问题。 您可以根据需要处理任意数量的数据样本,平均内存消耗约为3-4GB!
在本示例中,我们使用的是wikitext数据集,该数据集包含从维基百科上经过验证的优秀和特色文章中提取的超过1亿个标记。
准备和预处理原始数据集¶
首先定义数据集和分词器
import lance
import pyarrow as pa
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm.auto import tqdm # optional for progress tracking
tokenizer = AutoTokenizer.from_pretrained('gpt2')
dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', streaming=True)['train']
dataset = dataset.shuffle(seed=1337)
load_dataset中的streaming参数尤为重要,因为如果不将其设置为True,datasets库会先下载整个数据集,即使你只想使用其中的一个子集。将streaming设为True后,样本将在需要时才会被下载。
现在我们将定义一个函数来帮助我们逐个对样本进行分词。
def tokenize(sample, field='text'):
return tokenizer(sample[field])['input_ids']
该函数将从huggingface数据集中接收一个样本,并对field列中的值进行分词处理。这是您需要进行分词的主要文本内容。
创建Lance数据集¶
现在我们已经设置了原始数据集和预处理代码,让我们定义主函数,该函数接收数据集、样本数量和字段,并返回一个pyarrow批次,稍后将被写入lance数据集。
def process_samples(dataset, num_samples=100_000, field='text'):
current_sample = 0
for sample in tqdm(dataset, total=num_samples):
# If we have added all 5M samples, stop
if current_sample == num_samples:
break
if not sample[field]:
continue
# Tokenize the current sample
tokenized_sample = tokenize(sample, field)
# Increment the counter
current_sample += 1
# Yield a PyArrow RecordBatch
yield pa.RecordBatch.from_arrays(
[tokenized_sample],
names=["input_ids"]
)
该函数将遍历huggingface数据集,每次处理一个样本,对样本进行分词并生成一个包含所有词元的pyarrow RecordBatch。我们将持续这个过程,直到达到num_samples指定的样本数量或数据集末尾(以先到者为准)。
请注意,这里的"sample"指的是原始数据集中的一个样本(行)。具体一个样本的含义取决于数据集本身,它可能是一行文本或整个文本文件。在本示例中,样本长度从单行文本到段落文本不等。
我们还需要定义一个模式来告诉Lance表中预期的数据类型。由于我们的数据集仅包含长整型的token,int64是最合适的数据类型。
schema = pa.schema([
pa.field("input_ids", pa.int64())
])
最后,我们需要定义一个读取器,它将从我们的process_samples()
函数中读取记录批次流,该函数会生成由单独标记化样本组成的记录批次。
reader = pa.RecordBatchReader.from_batches(
schema,
process_samples(dataset, num_samples=500_000, field='text') # For 500K samples
)
最后我们使用lance.write_dataset()
将数据集写入磁盘。
# Write the dataset to disk
lance.write_dataset(
reader,
"wikitext_500K.lance",
schema
)
如果您想在将令牌保存到磁盘之前应用其他预处理(如掩码等),可以在process_samples函数中添加。
就这样!您的数据集已经完成分词并保存到磁盘!