Tensorflow 集成

Lance可以作为常规的tf.data.DatasetTensorflow中使用。

警告

此功能处于实验阶段,API接口未来可能会发生变化。

从Lance读取数据

使用lance.tf.data.from_lance(),您可以轻松创建一个tf.data.Dataset

import tensorflow as tf
import lance

# Create tf dataset
ds = lance.tf.data.from_lance("s3://my-bucket/my-dataset")

# Chain tf dataset with other tf primitives

for batch in ds.shuffling(32).map(lambda x: tf.io.decode_png(x["image"])):
    print(batch)

基于Lance 列式存储格式,使用lance.tf.data.from_lance()支持高效的列选择、过滤等功能。

ds = lance.tf.data.from_lance(
    "s3://my-bucket/my-dataset",
    columns=["image", "label"],
    filter="split = 'train' AND collected_time > timestamp '2020-01-01'",
    batch_size=256)

默认情况下,Lance会从投影列中推断Tensor规范。您也可以手动指定tf.TensorSpec

batch_size = 256
ds = lance.tf.data.from_lance(
    "s3://my-bucket/my-dataset",
    columns=["image", "labels"],
    batch_size=batch_size,
    output_signature={
        "image": tf.TensorSpec(shape=(), dtype=tf.string),
        "labels": tf.RaggedTensorSpec(
            dtype=tf.int32, shape=(batch_size, None), ragged_rank=1),
    },

分布式训练与数据混洗

由于Lance数据集是一组片段(Fragments),我们可以将这些片段分发并随机分配到不同的工作节点。

import tensorflow as tf
from lance.tf.data import from_lance, lance_fragments

world_size = 32
rank = 10
seed = 123  #
epoch = 100

dataset_uri = "s3://my-bucket/my-dataset"

# Shuffle fragments distributedly.
fragments =
    lance_fragments("s3://my-bucket/my-dataset")
    .shuffling(32, seed=seed)
    .repeat(epoch)
    .enumerate()
    .filter(lambda i, _: i % world_size == rank)
    .map(lambda _, fid: fid)

ds = from_lance(
    uri,
    columns=["image", "label"],
    fragments=fragments,
    batch_size=32
    )
for batch in ds:
    print(batch)

警告

在多进程环境下,您可能不应使用fork方式,因为lance内部采用多线程机制,而fork与多线程的配合效果不佳。具体可参阅此讨论