User-defined embedding functions

To use your own custom embedding function, you only need to follow two simple steps:

  1. Create your embedding function by implementing the EmbeddingFunction interface
  2. Register your embedding function with the global EmbeddingFunctionRegistry

Let's see what this looks like in practice.

EmbeddingFunction and EmbeddingFunctionRegistry take care of the low-level details of serializing schema and model information as metadata. To build a custom embedding function, you don't need to worry about those details - just focus on setting up your model and let LanceDB handle the rest.
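
As a quick illustration of what that serialization buys you, here is a hedged sketch (the database path is illustrative, and "table" refers to the table created later on this page): reopening a table restores its embedding function from the stored metadata, so queries embed their inputs automatically.

import lancedb

db = lancedb.connect("/tmp/db")  # illustrative path
tbl = db.open_table("table")
# No model setup needed here: the embedding configuration was serialized
# into the table metadata at creation time and is reconstructed for us.
results = tbl.search("world").limit(5).to_pandas()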

The TextEmbeddingFunction interface

There is another optional layer of abstraction available: TextEmbeddingFunction. If your model isn't multi-modal and only needs to work with text, you can use this abstraction instead. In that case, the source field and the vector field behave identically as far as vectorization is concerned, so you only need to set up the model and TextEmbeddingFunction handles the rest. You can read more about the class and its attributes in the class reference.

Let's implement a SentenceTransformerEmbeddings class. All you need to do is implement the generate_embeddings() and ndims() functions to handle the expected input types, and register the class with the global EmbeddingFunctionRegistry.

from cachetools import cached

from lancedb.embeddings import TextEmbeddingFunction, register
from lancedb.util import attempt_import_or_raise

@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
    name: str = "all-MiniLM-L6-v2"
    # set more default instance vars like device, etc.

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._ndims = None

    def generate_embeddings(self, texts):
        # pass any extra encode() kwargs (batch size, device, ...) as needed
        return self._embedding_model().encode(list(texts)).tolist()

    def ndims(self):
        if self._ndims is None:
            self._ndims = len(self.generate_embeddings(["foo"])[0])
        return self._ndims

    @cached(cache={})
    def _embedding_model(self):
        # lazy import so the dependency is only required when actually used
        sentence_transformers = attempt_import_or_raise("sentence_transformers")
        return sentence_transformers.SentenceTransformer(self.name)

And the equivalent implementation in TypeScript:

import * as lancedb from "@lancedb/lancedb";
import {
  LanceSchema,
  TextEmbeddingFunction,
  getRegistry,
  register,
} from "@lancedb/lancedb/embedding";
import {
  type FeatureExtractionPipeline,
  pipeline,
} from "@huggingface/transformers";

@register("sentence-transformers")
class SentenceTransformersEmbeddings extends TextEmbeddingFunction {
  name = "Xenova/all-miniLM-L6-v2";
  #ndims!: number;
  extractor!: FeatureExtractionPipeline;

  async init() {
    this.extractor = await pipeline("feature-extraction", this.name, {
      dtype: "fp32",
    });
    this.#ndims = await this.generateEmbeddings(["hello"]).then(
      (e) => e[0].length,
    );
  }

  ndims() {
    return this.#ndims;
  }

  toJSON() {
    return {
      name: this.name,
    };
  }
  async generateEmbeddings(texts: string[]) {
    const output = await this.extractor(texts, {
      pooling: "mean",
      normalize: true,
    });
    return output.tolist();
  }
}

This is a trimmed-down version of our SentenceTransformerEmbeddings implementation, with some optimizations and default settings removed.

Using sensitive keys to prevent leaking secrets

To avoid leaking sensitive information such as API keys, you should add any sensitive parameters of your embedding function to the output of the sensitive_keys() / getSensitiveKeys() method. This prevents users from accidentally instantiating the embedding function with hard-coded secrets.
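
For instance, here is a minimal Python sketch (the MyEmbeddings class and its api_key field are hypothetical; sensitive_keys() is the method named above, though its exact signature may differ, and the set_var / "$var:" lines assume the registry's variable mechanism for supplying secrets at runtime):

from lancedb.embeddings import (
    EmbeddingFunctionRegistry,
    TextEmbeddingFunction,
    register,
)

@register("my-embeddings")
class MyEmbeddings(TextEmbeddingFunction):
    api_key: str  # secret - must never be hard-coded or persisted

    def sensitive_keys(self):
        # Keys listed here cannot be set to literal values; users must
        # supply them via registry variables or environment variables.
        return ["api_key"]

    def generate_embeddings(self, texts):
        ...  # call your embedding API with self.api_key

    def ndims(self):
        ...

# Supply the secret at runtime instead of hard-coding it:
registry = EmbeddingFunctionRegistry.get_instance()
registry.set_var("my_key", "sk-...")
func = registry.get("my-embeddings").create(api_key="$var:my_key")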

Now you can use this embedding function when defining your table schema, and that's it! After that you can ingest data and run queries directly, without ever vectorizing the inputs manually.

import lancedb
import pandas as pd

from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

db = lancedb.connect("/tmp/db")  # any database path works here

registry = EmbeddingFunctionRegistry.get_instance()
stransformer = registry.get("sentence-transformers").create()

class TextModelSchema(LanceModel):
    vector: Vector(stransformer.ndims()) = stransformer.VectorField()
    text: str = stransformer.SourceField()

tbl = db.create_table("table", schema=TextModelSchema)

tbl.add(pd.DataFrame({"text": ["hello", "world"]}))
result = tbl.search("world").limit(5).to_pandas()

The same flow in TypeScript:

const databaseDir = "/tmp/db"; // any database directory works here

const registry = getRegistry();

const sentenceTransformer = await registry
  .get<SentenceTransformersEmbeddings>("sentence-transformers")!
  .create();

const schema = LanceSchema({
  vector: sentenceTransformer.vectorField(),
  text: sentenceTransformer.sourceField(),
});

const db = await lancedb.connect(databaseDir);
const table = await db.createEmptyTable("table", schema, {
  mode: "overwrite",
});

await table.add([{ text: "hello" }, { text: "world" }]);

const results = await table.search("greeting").limit(1).toArray();

Note

You can always implement the EmbeddingFunction interface directly if you prefer or need to; TextEmbeddingFunction simply makes the process simpler and faster by setting up the boilerplate for the text-only use case.
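
For reference, the raw interface boils down to the following skeleton (a sketch only - the method names match those used in the OpenClip example below, but the bodies are left to you):

from lancedb.embeddings import EmbeddingFunction

class MyMultiModalEmbeddings(EmbeddingFunction):
    def ndims(self):
        ...  # dimensionality of the vectors this function produces

    def compute_query_embeddings(self, query, *args, **kwargs):
        ...  # embed a user query (text, image, ...)

    def compute_source_embeddings(self, inputs, *args, **kwargs):
        ...  # embed the source column during ingestion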

Multi-modal embedding function example

You can also use the EmbeddingFunction interface to implement more complex workflows, such as multi-modal embedding functions.

LanceDB implements the OpenClipEmbeddings class, which supports multi-modal search. Here is its implementation, which you can use as a reference for building your own multi-modal embedding function.

@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
    name: str = "ViT-B-32"
    pretrained: str = "laion2b_s34b_b79k"
    device: str = "cpu"
    batch_size: int = 64
    normalize: bool = True
    _model = PrivateAttr()
    _preprocess = PrivateAttr()
    _tokenizer = PrivateAttr()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        open_clip = attempt_import_or_raise("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
        model, _, preprocess = open_clip.create_model_and_transforms(
            self.name, pretrained=self.pretrained
        )
        model.to(self.device)
        self._model, self._preprocess = model, preprocess
        self._tokenizer = open_clip.get_tokenizer(self.name)
        self._ndims = None

    def ndims(self):
        if self._ndims is None:
            self._ndims = self.generate_text_embeddings("foo").shape[0]
        return self._ndims

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query

        Parameters
        ----------
        query : Union[str, PIL.Image.Image]
            The query to embed. A query can be either text or an image.
        """
        if isinstance(query, str):
            return [self.generate_text_embeddings(query)]
        else:
            PIL = attempt_import_or_raise("PIL", "pillow")
            if isinstance(query, PIL.Image.Image):
                return [self.generate_image_embedding(query)]
            else:
                raise TypeError("OpenClip supports str or PIL Image as query")

    def generate_text_embeddings(self, text: str) -> np.ndarray:
        torch = attempt_import_or_raise("torch")
        text = self.sanitize_input(text)
        text = self._tokenizer(text)
        text.to(self.device)
        with torch.no_grad():
            text_features = self._model.encode_text(text.to(self.device))
            if self.normalize:
                text_features /= text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy().squeeze()

    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(images, (str, bytes)):
            images = [images]
        elif isinstance(images, pa.Array):
            images = images.to_pylist()
        elif isinstance(images, pa.ChunkedArray):
            images = images.combine_chunks().to_pylist()
        return images

    def compute_source_embeddings(
        self, images: IMAGES, *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Get the embeddings for the given images
        """
        images = self.sanitize_input(images)
        embeddings = []
        for i in range(0, len(images), self.batch_size):
            j = min(i + self.batch_size, len(images))
            batch = images[i:j]
            embeddings.extend(self._parallel_get(batch))
        return embeddings

    def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
        """
        Issue concurrent requests to retrieve the image data
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate_image_embedding, image)
                for image in images
            ]
            return [future.result() for future in futures]

    def generate_image_embedding(
        self, image: Union[str, bytes, "PIL.Image.Image"]
    ) -> np.ndarray:
        """
        Generate the embedding for a single image

        Parameters
        ----------
        image : Union[str, bytes, PIL.Image.Image]
            The image to embed. If the image is a str, it is treated as a uri.
            If the image is bytes, it is treated as the raw image bytes.
        """
        torch = attempt_import_or_raise("torch")
        # TODO handle retry and errors for https
        image = self._to_pil(image)
        image = self._preprocess(image).unsqueeze(0)
        with torch.no_grad():
            return self._encode_and_normalize_image(image)

    def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
        PIL = attempt_import_or_raise("PIL", "pillow")
        if isinstance(image, bytes):
            return PIL.Image.open(io.BytesIO(image))
        if isinstance(image, PIL.Image.Image):
            return image
        elif isinstance(image, str):
            parsed = urlparse.urlparse(image)
            # TODO handle drive letter on windows.
            if parsed.scheme == "file":
                return PIL.Image.open(parsed.path)
            elif parsed.scheme == "":
                return PIL.Image.open(image if os.name == "nt" else parsed.path)
            elif parsed.scheme.startswith("http"):
                return PIL.Image.open(io.BytesIO(url_retrieve(image)))
            else:
                raise NotImplementedError("Only local and http(s) urls are supported")

    def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
        """
        encode a single image tensor and optionally normalize the output
        """
        image_features = self._model.encode_image(image_tensor)
        if self.normalize:
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().squeeze()

The TypeScript version is coming soon! Follow this issue to track the latest status!
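
In the meantime, here is a hedged Python usage sketch for the multi-modal function above (the table name, field names, and paths are illustrative):

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

clip = get_registry().get("open-clip").create()

class Pets(LanceModel):
    vector: Vector(clip.ndims()) = clip.VectorField()
    image_uri: str = clip.SourceField()  # embedded via compute_source_embeddings()

db = lancedb.connect("/tmp/multimodal-demo")
table = db.create_table("pets", schema=Pets)
table.add([{"image_uri": "path/to/cat.png"}])  # local path or http(s) URL

# Text and PIL.Image queries both route through compute_query_embeddings():
hits = table.search("a fluffy cat").limit(3).to_pandas()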