跳转到内容

Pandas与PyArrow

由于Lance构建在Apache Arrow之上,LanceDB与Python数据生态系统(包括Pandas和PyArrow)紧密集成。下面展示了一个典型工作流程中的步骤序列。

创建数据集

首先,我们需要连接到LanceDB数据库。

import lancedb

uri = "data/sample-lancedb"
db = lancedb.connect(uri)
import lancedb

uri = "data/sample-lancedb"
async_db = await lancedb.connect_async(uri)

我们可以直接将Pandas的DataFrame加载到LanceDB中。

import pandas as pd

data = pd.DataFrame(
    {
        "vector": [[3.1, 4.1], [5.9, 26.5]],
        "item": ["foo", "bar"],
        "price": [10.0, 20.0],
    }
)
table = db.create_table("pd_table", data=data)
import pandas as pd

data = pd.DataFrame(
    {
        "vector": [[3.1, 4.1], [5.9, 26.5]],
        "item": ["foo", "bar"],
        "price": [10.0, 20.0],
    }
)
await async_db.create_table("pd_table_async", data=data)

类似于pyarrow.write_dataset()方法,LanceDB的db.create_table()可以接受多种形式的数据。

如果你的数据集大小超过内存容量,可以使用Iterator[pyarrow.RecordBatch]来惰性加载数据创建表:

from typing import Iterable

import pyarrow as pa

def make_batches() -> Iterable[pa.RecordBatch]:
    for i in range(5):
        yield pa.RecordBatch.from_arrays(
            [
                pa.array([[3.1, 4.1], [5.9, 26.5]]),
                pa.array(["foo", "bar"]),
                pa.array([10.0, 20.0]),
            ],
            ["vector", "item", "price"],
        )


schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32())),
        pa.field("item", pa.utf8()),
        pa.field("price", pa.float32()),
    ]
)
table = db.create_table("iterable_table", data=make_batches(), schema=schema)
from typing import Iterable

import pyarrow as pa

def make_batches() -> Iterable[pa.RecordBatch]:
    for i in range(5):
        yield pa.RecordBatch.from_arrays(
            [
                pa.array([[3.1, 4.1], [5.9, 26.5]]),
                pa.array(["foo", "bar"]),
                pa.array([10.0, 20.0]),
            ],
            ["vector", "item", "price"],
        )


schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32())),
        pa.field("item", pa.utf8()),
        pa.field("price", pa.float32()),
    ]
)
await async_db.create_table(
    "iterable_table_async", data=make_batches(), schema=schema
)

您可以在入门指南API部分找到创建LanceDB数据集的详细说明。

我们现在可以通过LanceDB Python API执行相似性搜索。

# Open the table previously created.
table = db.open_table("pd_table")

query_vector = [100, 100]
# Pandas DataFrame
df = table.search(query_vector).limit(1).to_pandas()
print(df)
# Open the table previously created.
async_tbl = await async_db.open_table("pd_table_async")

query_vector = [100, 100]
# Pandas DataFrame
df = await (await async_tbl.search(query_vector)).limit(1).to_pandas()
print(df)
    vector     item  price    _distance
0  [5.9, 26.5]  bar   20.0  14257.05957

如果有一个简单的过滤条件,直接向LanceDB的search方法提供where子句会更快。 对于更复杂的过滤或聚合操作,您可以在执行搜索后使用底层的DataFrame方法。

# Apply the filter via LanceDB
results = table.search([100, 100]).where("price < 15").to_pandas()
assert len(results) == 1
assert results["item"].iloc[0] == "foo"

# Apply the filter via Pandas
df = results = table.search([100, 100]).to_pandas()
results = df[df.price < 15]
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
# Apply the filter via LanceDB
results = await (await async_tbl.search([100, 100])).where("price < 15").to_pandas()
assert len(results) == 1
assert results["item"].iloc[0] == "foo"

# Apply the filter via Pandas
df = results = await (await async_tbl.search([100, 100])).to_pandas()
results = df[df.price < 15]
assert len(results) == 1
assert results["item"].iloc[0] == "foo"