Tensorflow 集成¶
Lance可以作为常规的tf.data.Dataset在Tensorflow中使用。
警告
此功能处于实验阶段,API接口未来可能会发生变化。
从Lance读取数据¶
使用lance.tf.data.from_lance(),您可以轻松创建一个tf.data.Dataset。
import tensorflow as tf
import lance
# Create tf dataset
ds = lance.tf.data.from_lance("s3://my-bucket/my-dataset")
# Chain tf dataset with other tf primitives
for batch in ds.shuffling(32).map(lambda x: tf.io.decode_png(x["image"])):
print(batch)
基于Lance 列式存储格式,使用lance.tf.data.from_lance()支持高效的列选择、过滤等功能。
ds = lance.tf.data.from_lance(
"s3://my-bucket/my-dataset",
columns=["image", "label"],
filter="split = 'train' AND collected_time > timestamp '2020-01-01'",
batch_size=256)
默认情况下,Lance会从投影列中推断Tensor规范。您也可以手动指定tf.TensorSpec。
batch_size = 256
ds = lance.tf.data.from_lance(
"s3://my-bucket/my-dataset",
columns=["image", "labels"],
batch_size=batch_size,
output_signature={
"image": tf.TensorSpec(shape=(), dtype=tf.string),
"labels": tf.RaggedTensorSpec(
dtype=tf.int32, shape=(batch_size, None), ragged_rank=1),
},
分布式训练与数据混洗¶
由于Lance数据集是一组片段(Fragments),我们可以将这些片段分发并随机分配到不同的工作节点。
import tensorflow as tf
from lance.tf.data import from_lance, lance_fragments
world_size = 32
rank = 10
seed = 123 #
epoch = 100
dataset_uri = "s3://my-bucket/my-dataset"
# Shuffle fragments distributedly.
fragments =
lance_fragments("s3://my-bucket/my-dataset")
.shuffling(32, seed=seed)
.repeat(epoch)
.enumerate()
.filter(lambda i, _: i % world_size == rank)
.map(lambda _, fid: fid)
ds = from_lance(
uri,
columns=["image", "label"],
fragments=fragments,
batch_size=32
)
for batch in ds:
print(batch)
警告
在多进程环境下,您可能不应使用fork方式,因为lance内部采用多线程机制,而fork与多线程的配合效果不佳。具体可参阅此讨论。