PyTorch集成¶
机器学习用户可以使用LanceDataset,它是torch.utils.data.IterableDataset的子类,可直接在PyTorch训练和推理循环中使用Lance数据。
首先需要创建一个用于训练的机器学习数据集。借助Lance ❤️ HuggingFace,只需一行Python代码即可将HuggingFace数据集转换为Lance数据集。
# Huggingface datasets
import datasets
import lance
hf_ds = datasets.load_dataset(
"poloclub/diffusiondb",
split="train",
# name="2m_first_1k", # for a smaller subset of the dataset
)
lance.write_dataset(hf_ds, "diffusiondb_train.lance")
然后,您可以在PyTorch训练和推理循环中使用Lance数据集。
注意:
PyTorch数据集会自动将数据转换为
torch.Tensor
2. lance不支持多进程fork操作。如果使用多进程,请改用spawn方式。 安全的数据加载器采用了spawn方法。
不安全的数据加载器
import torch
import lance.torch.data
# Load lance dataset into a PyTorch IterableDataset.
# with only columns "image" and "prompt".
dataset = lance.torch.data.LanceDataset(
"diffusiondb_train.lance",
columns=["image", "prompt"],
batch_size=128,
batch_readahead=8, # Control multi-threading reads.
)
# Create a PyTorch DataLoader
dataloader = torch.utils.data.DataLoader(dataset)
# Inference loop
for batch in dataloader:
inputs, targets = batch["prompt"], batch["image"]
outputs = model(inputs)
...
安全数据加载器
from lance.torch.data import SafeLanceDataset, get_safe_loader
dataset = SafeLanceDataset(temp_lance_dataset)
# use spawn method to avoid fork-safe issue
loader = get_safe_loader(
dataset,
num_workers=2,
batch_size=16,
drop_last=False,
)
total_samples = 0
for batch in loader:
total_samples += batch["id"].shape[0]
LanceDataset 可以与 Sampler 类组合使用
来控制采样策略。例如,您可以使用 ShardedFragmentSampler
在分布式训练环境中使用它。如果未指定,则默认为全表扫描。
from lance.sampler import ShardedFragmentSampler
from lance.torch.data import LanceDataset
# Load lance dataset into a PyTorch IterableDataset.
# with only columns "image" and "prompt".
dataset = LanceDataset(
"diffusiondb_train.lance",
columns=["image", "prompt"],
batch_size=128,
batch_readahead=8, # Control multi-threading reads.
sampler=ShardedFragmentSampler(
rank=1, # Rank of the current process
world_size=8, # Total number of processes
),
)
可用的采样器:
lance.sampler.ShardedFragmentSamplerlance.sampler.ShardedBatchSampler
警告
在多进程环境下,您可能不应使用fork方式,因为lance内部采用多线程机制,而fork与多线程的配合效果不佳。具体可参阅此讨论。