ConcatDataset¶
- class torchtune.datasets.ConcatDataset(datasets: List[Dataset])[source]¶
一个用于将多个子数据集连接成单个数据集的数据集类。该类使得能够统一处理不同的数据集,就像它们是一个单一的数据集一样,简化了诸如同时在多个数据源上训练模型等任务。
该类内部管理不同数据集的聚合,并允许在它们之间进行透明索引。 然而,它要求所有组成的数据集都必须完全加载到内存中,这对于非常大的数据集可能不是最优的。
在初始化时,此类计算所有数据集的累积长度,并维护一个索引到相应数据集的内部映射。这种方法允许
ConcatDataset在访问特定索引时透明地将数据检索委托给适当的子数据集。注意
使用此类处理非常大的数据集可能会导致高内存消耗,因为它需要将所有数据集加载到内存中。对于大规模场景,请考虑其他可能按需流式传输数据的策略。
- Parameters:
datasets (List[Dataset]) – 要连接的数据集列表。每个数据集必须是派生自
Dataset类的实例。- Raises:
ValueError – 如果 PackedDataset 的实例在 datasets 中
示例
>>> dataset1 = MyCustomDataset(params1) >>> dataset2 = MyCustomDataset(params2) >>> concat_dataset = ConcatDataset([dataset1, dataset2]) >>> print(len(concat_dataset)) # Total length of both datasets >>> data_point = concat_dataset[1500] # Accesses an element from the appropriate dataset
这也可以通过将数据集列表传递到YAML配置中来完成:
dataset: - _component_: torchtune.datasets.instruct_dataset source: vicgalle/alpaca-gpt4 split: train train_on_input: True - _component_: torchtune.datasets.instruct_dataset source: samsum column_map: {"output": "summary"} split: train train_on_input: False
该类主要专注于提供一个统一的接口来访问多个数据集中的元素,增强处理多样化数据源以训练机器学习模型的灵活性。