Text Semantic Search with AutoMM


1. Introduction to Semantic Embedding

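Semantic embedding is one of the main workhorses behind modern search technology. Instead of matching a query to candidate documents directly by term frequency (e.g., BM25), a semantic search algorithm first converts the text \(x\) into a feature vector \(\phi(x)\) and then compares similarities using a distance metric defined in that vector space. These feature vectors, known as "vector embeddings", are often trained end-to-end on large text corpora so that they encode the semantic meaning of the text. For example, synonyms are embedded into a similar region of the vector space, and relationships between words are often revealed by algebraic operations (see Figure 1 for an example). For these reasons, a vector embedding of text is also known as a semantic embedding. With semantic embeddings of the query and the candidate documents, a search algorithm can often be reduced to finding the most similar vectors. This new approach to search is known as semantic search.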

Figure 1: Similar sentences have similar embeddings. Image from Medium.
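As a minimal sketch of comparing texts in an embedding space, the snippet below computes cosine similarity between a few vectors. The three-dimensional vectors are hand-picked for illustration and are not produced by any real model; the helper name `cosine_similarity` is likewise an assumption of this example.

```python
import numpy as np

def cosine_similarity(u, v):
    """Cosine of the angle between two embedding vectors."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

# Hypothetical 3-dimensional embeddings, hand-picked for illustration only.
toy_embeddings = {
    "car": np.array([0.9, 0.1, 0.0]),
    "automobile": np.array([0.85, 0.15, 0.05]),
    "banana": np.array([0.0, 0.2, 0.95]),
}

# Compare the query "car" against every other word and keep the closest one.
query = toy_embeddings["car"]
similarities = {
    word: cosine_similarity(query, vec)
    for word, vec in toy_embeddings.items()
    if word != "car"
}
nearest = max(similarities, key=similarities.get)
print(nearest, similarities)
```

With these made-up vectors the synonym ends up far closer to the query than the unrelated word, which is the behavior a trained embedding model is expected to approximate.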

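Using semantic embeddings for a search problem has three main advantages over classical information-retrieval methods (e.g., bag-of-words or TF/IDF). First, it returns candidates that are related by the meaning of the text rather than by similar word usage. This helps to discover paraphrased text and similar concepts described in very different ways. Second, semantic search is often more computationally efficient. Vector embeddings of the candidates can be pre-computed and stored in data structures. Highly scalable sketching techniques such as locality-sensitive hashing (LSH) and max-inner product search (MIPS) can be used to efficiently find similar vectors in the embedding space. Last but not least, the semantic embedding approach lets us generalize the same search algorithm beyond text, for example to multi-modal search. Can we use a text query to search for images without textual annotations? Can we search for a website using an image query? With semantic search, one can simply use the most appropriate vector embedding of these multi-modal objects and jointly train the embeddings using datasets that contain both text and images.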
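To make the pre-computation idea concrete, here is a small sketch of one common LSH variant, random-hyperplane hashing. It is an illustration under stated assumptions rather than what AutoMM does internally: the embeddings are random stand-ins, and the names `lsh_signature`, `hyperplanes`, and `buckets` are invented for this example.

```python
import numpy as np

def lsh_signature(vectors, hyperplanes):
    """Map each vector to a tuple of sign bits, one bit per random hyperplane."""
    bits = (np.atleast_2d(vectors) @ hyperplanes.T) >= 0
    return [tuple(row) for row in bits.astype(int)]

rng = np.random.default_rng(0)
dim, n_planes = 8, 4
hyperplanes = rng.standard_normal((n_planes, dim))

# Hypothetical candidate embeddings, hashed into buckets once, offline.
candidates = rng.standard_normal((100, dim))
buckets = {}
for idx, sig in enumerate(lsh_signature(candidates, hyperplanes)):
    buckets.setdefault(sig, []).append(idx)

# At query time, only candidates in the query's bucket are compared exactly.
query = candidates[0] + 0.01 * rng.standard_normal(dim)
query_sig = lsh_signature(query, hyperplanes)[0]
shortlist = buckets.get(query_sig, [])
print(len(shortlist), "of", len(candidates), "candidates share the query's bucket")
```

Vectors separated by a small angle tend to fall on the same side of most hyperplanes, so they usually share a bucket. The shortlist is approximate and can miss true neighbors, which is why practical systems typically combine several hash tables.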

This tutorial provides a gentle entry point for deploying AutoMM for semantic search.

%%capture
!pip3 install ir_datasets
import ir_datasets
import pandas as pd
pd.set_option('display.max_colwidth', None)

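2. Dataset

In this tutorial, we will use the NF Corpus (Nutrition Facts) dataset from the ir_datasets package. We also convert the query data, document data, and their relevance data into dataframes.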

%%capture
dataset = ir_datasets.load("beir/nfcorpus/test")

# prepare dataset
doc_data = pd.DataFrame(dataset.docs_iter())
query_data = pd.DataFrame(dataset.queries_iter())
labeled_data = pd.DataFrame(dataset.qrels_iter())
label_col = "relevance"
query_id_col = "query_id"
doc_id_col = "doc_id"
text_col = "text"
id_mappings={query_id_col: query_data.set_index(query_id_col)[text_col], doc_id_col: doc_data.set_index(doc_id_col)[text_col]}

The labeled data contains the query IDs, document IDs, and their relevance scores.

labeled_data.head()
query_id doc_id relevance iteration
0 PLAIN-2 MED-2427 2 0
1 PLAIN-2 MED-10 2 0
2 PLAIN-2 MED-2429 2 0
3 PLAIN-2 MED-2430 2 0
4 PLAIN-2 MED-2431 2 0

The query data stores the query IDs and their corresponding query content.

query_data.head()
query_id text url
0 PLAIN-2 Do Cholesterol Statin Drugs Cause Breast Cancer? http://nutritionfacts.org/2015/07/16/do-cholesterol-statin-drugs-cause-breast-cancer/
1 PLAIN-12 Exploiting Autophagy to Live Longer http://nutritionfacts.org/2015/06/11/exploiting-autophagy-to-live-longer/
2 PLAIN-23 How to Reduce Exposure to Alkylphenols Through Your Diet http://nutritionfacts.org/2015/04/28/how-to-reduce-exposure-to-alkylphenols-through-your-diet/
3 PLAIN-33 What’s Driving America’s Obesity Problem? http://nutritionfacts.org/2015/03/24/whats-driving-americas-obesity-problem/
4 PLAIN-44 Who Should be Careful About Curcumin? http://nutritionfacts.org/2015/02/12/who-should-be-careful-about-curcumin/

We need to remove the URLs, which are not used in search.

query_data = query_data.drop("url", axis=1)
query_data.head()
query_id text
0 PLAIN-2 Do Cholesterol Statin Drugs Cause Breast Cancer?
1 PLAIN-12 Exploiting Autophagy to Live Longer
2 PLAIN-23 How to Reduce Exposure to Alkylphenols Through Your Diet
3 PLAIN-33 What’s Driving America’s Obesity Problem?
4 PLAIN-44 Who Should be Careful About Curcumin?

The document data contains the document IDs as well as the corresponding content.

doc_data.head(1)
doc_id text title url
0 MED-10 Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants died, of which 3,619 (60.2%) was due to breast cancer. After adjustment for age, tumor characteristics, and treatment selection, both post-diagnostic and pre-diagnostic statin use were associated with lowered risk of breast cancer death (HR 0.46, 95% CI 0.38–0.55 and HR 0.54, 95% CI 0.44–0.67, respectively). The risk decrease by post-diagnostic statin use was likely affected by healthy adherer bias; that is, the greater likelihood of dying cancer patients to discontinue statin use as the association was not clearly dose-dependent and observed already at low-dose/short-term use. The dose- and time-dependence of the survival benefit among pre-diagnostic statin users suggests a possible causal effect that should be evaluated further in a clinical trial testing statins’ effect on survival in breast cancer patients. Statin Use and Breast Cancer Survival: A Nationwide Cohort Study from Finland http://www.ncbi.nlm.nih.gov/pubmed/25329299

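As with the query data, we remove the url column. We also need to concatenate all the valid text (the document body and its title) into one column.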

doc_data[text_col] = doc_data[[text_col, "title"]].apply(" ".join, axis=1)
doc_data = doc_data.drop(["title", "url"], axis=1)
doc_data.head(1)
doc_id text
0 MED-10 Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants died, of which 3,619 (60.2%) was due to breast cancer. After adjustment for age, tumor characteristics, and treatment selection, both post-diagnostic and pre-diagnostic statin use were associated with lowered risk of breast cancer death (HR 0.46, 95% CI 0.38–0.55 and HR 0.54, 95% CI 0.44–0.67, respectively). The risk decrease by post-diagnostic statin use was likely affected by healthy adherer bias; that is, the greater likelihood of dying cancer patients to discontinue statin use as the association was not clearly dose-dependent and observed already at low-dose/short-term use. The dose- and time-dependence of the survival benefit among pre-diagnostic statin users suggests a possible causal effect that should be evaluated further in a clinical trial testing statins’ effect on survival in breast cancer patients. Statin Use and Breast Cancer Survival: A Nationwide Cohort Study from Finland

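There are 323 queries, 3633 documents, and 12334 relevance scores in the dataset.

3. NDCG Evaluation

Users pay the most attention to the first result, then the second, and so on. As a result, precision matters most for the top-ranked results. In this tutorial, we use Normalized Discounted Cumulative Gain (NDCG) to measure ranking performance.

3.1 Defining CG, DCG, IDCG and NDCG

To understand the NDCG metric, we must first understand CG (Cumulative Gain) and DCG (Discounted Cumulative Gain), as well as the two assumptions we make when we use DCG and its related measures:

  1. Highly relevant documents are more useful when they appear earlier in the search engine results list.

  2. Highly relevant documents are more useful than marginally relevant documents, which are in turn more useful than non-relevant documents.

First, the primitive Cumulative Gain (CG), which adds up the relevance scores (\(rel\)) up to a specified rank position \(p\):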

\[ \mathrm{CG}_p = \sum_{i=1}^p \mathrm{rel}_i. \]

Then, the Discounted Cumulative Gain (DCG), which penalizes each relevance score logarithmically based on its position in the results:

\[ \mathrm{DCG}_p = \sum_{i=1}^p \frac{\mathrm{rel}_i}{\log_2(i + 1)}. \]

Next is the Ideal DCG (IDCG), which is the DCG of the best possible results based on the given ratings:

\[ \mathrm{IDCG}_p = \sum_{i=1}^{|\mathrm{REL}_p|} \frac{\mathrm{rel}_i}{\log_2(i + 1)}. \]

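where \(\mathrm{REL}_p\) is the list of relevant documents in the corpus (ordered by their relevance) up to position \(p\), and \(|\mathrm{REL}_p|\) is its length.

Finally, NDCG: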

\[ \mathrm{NDCG}_p = \frac{\mathrm{DCG}_p}{\mathrm{IDCG}_p}. \]
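The four definitions above can be sketched in a few lines of plain Python. This is only an illustration of the formulas: the helper names `dcg_at_p` and `ndcg_at_p` and the relevance values are made up for this example, and a library implementation may handle ties or missing judgments differently.

```python
import math

def dcg_at_p(relevances, p):
    """DCG_p = sum over ranks i = 1..p of rel_i / log2(i + 1)."""
    return sum(rel / math.log2(i + 1) for i, rel in enumerate(relevances[:p], start=1))

def ndcg_at_p(ranked_relevances, all_relevances, p):
    """Normalize DCG_p by the DCG of the ideal (descending) ordering."""
    ideal = sorted(all_relevances, reverse=True)
    idcg = dcg_at_p(ideal, p)
    return dcg_at_p(ranked_relevances, p) / idcg if idcg > 0 else 0.0

# Hypothetical graded relevance of the returned documents, in rank order.
# For simplicity we assume the corpus holds no other relevant documents.
ranked = [2, 0, 1]
cg_3 = sum(ranked[:3])
dcg_3 = dcg_at_p(ranked, 3)
ndcg_3 = ndcg_at_p(ranked, ranked, 3)
print(cg_3, dcg_3, ndcg_3)
```

Here CG ignores order entirely, while DCG and NDCG drop below their ideal values because the document with relevance 1 sits at rank 3 instead of rank 2.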

We provide a utility function to compute the ranking scores. It also supports measuring NDCG under different cutoff values.

from autogluon.multimodal.utils import compute_ranking_score
cutoffs = [5, 10, 20]

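4. Using BM25

BM25 (or Okapi BM25) is a popular ranking algorithm currently used by OpenSearch for scoring document relevancy to a query. In this tutorial, we will use the NDCG scores of BM25 as the baseline.

4.1 Defining the formula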

\[ score_{BM25} = \sum_i^n \mathrm{IDF}(q_i) \frac{f(q_i, D) \cdot (k1 + 1)}{f(q_i, D) + k1 \cdot (1 - b + b \cdot \frac{fieldLen}{avgFieldLen})}\]

where \(\mathrm{IDF}(q_i)\) is the inverse document frequency of the \(i^{th}\) query term, and the actual formula BM25 uses for this part is:

\[ \log\left(1 + \frac{docCount - f(q_i) + 0.5}{f(q_i) + 0.5}\right). \]

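\(k1\) is a tunable hyperparameter that limits how much a single query term can affect the score of a given document. In Elasticsearch, it defaults to 1.2.

\(b\) is another hyperparameter that determines the effect of a document's length relative to the average document length in the corpus. In Elasticsearch, it defaults to 0.75.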
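As a rough sketch of how the formula is evaluated, the function below scores a tiny, already-tokenized corpus by hand. It assumes that \(f(q_i, D)\) is the count of the term in the document and that \(f(q_i)\) in the IDF is the number of documents containing the term; the function name `bm25_scores` and the toy corpus are invented for this example. The numbers need not match those of a library implementation, which may use a different IDF variant.

```python
import math
from collections import Counter

def bm25_scores(query_tokens, tokenized_docs, k1=1.2, b=0.75):
    """Score every document against the query with the BM25 formula."""
    doc_count = len(tokenized_docs)
    avg_len = sum(len(d) for d in tokenized_docs) / doc_count
    # Number of documents that contain each term at least once.
    doc_freq = Counter(term for d in tokenized_docs for term in set(d))

    scores = []
    for doc in tokenized_docs:
        term_freq = Counter(doc)
        # Length normalization: longer-than-average documents are penalized.
        norm = 1 - b + b * len(doc) / avg_len
        score = 0.0
        for term in query_tokens:
            n_q = doc_freq.get(term, 0)
            idf = math.log(1 + (doc_count - n_q + 0.5) / (n_q + 0.5))
            f = term_freq.get(term, 0)
            score += idf * f * (k1 + 1) / (f + k1 * norm)
        scores.append(score)
    return scores

# Hypothetical toy corpus, for illustration only.
docs = [
    ["statin", "use", "breast", "cancer", "survival"],
    ["curcumin", "dietary", "supplement", "safety"],
    ["obesity", "diet", "sugar", "intake"],
]
scores = bm25_scores(["statin", "breast", "cancer"], docs)
print(scores)
```

Only the first document contains any query term, so it is the only one with a non-zero score.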

In this tutorial, we will use the rank_bm25 package to avoid the complexity of implementing the algorithm from scratch.

4.2 Define functions

%%capture
!pip3 install rank_bm25
from collections import defaultdict
import string
import nltk
import numpy as np
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi

nltk.download('stopwords')
nltk.download('punkt')

def tokenize_corpus(corpus):
    stop_words = set(stopwords.words("english") + list(string.punctuation))
    
    tokenized_docs = []
    for doc in corpus:
        tokens = nltk.word_tokenize(doc.lower())
        tokenized_doc = [w for w in tokens if w not in stop_words and len(w) > 2]
        tokenized_docs.append(tokenized_doc)
    return tokenized_docs

def rank_documents_bm25(queries_text, queries_id, docs_id, top_k, bm25):
    tokenized_queries = tokenize_corpus(queries_text)
    
    results = {qid: {} for qid in queries_id}
    for query_idx, query in enumerate(tokenized_queries):
        scores = bm25.get_scores(query)
        scores_top_k_idx = np.argsort(scores)[::-1][:top_k]
        for doc_idx in scores_top_k_idx:
            results[queries_id[query_idx]][docs_id[doc_idx]] = float(scores[doc_idx])
    return results

def get_qrels(dataset):
    """
    Get the ground truth of relevance score for all queries
    """
    qrel_dict = defaultdict(dict)
    for qrel in dataset.qrels_iter():
        qrel_dict[qrel.query_id][qrel.doc_id] = qrel.relevance
    return qrel_dict

def evaluate_bm25(doc_data, query_data, qrel_dict, cutoffs):
    
    tokenized_corpus = tokenize_corpus(doc_data[text_col].tolist())
    bm25_model = BM25Okapi(tokenized_corpus, k1=1.2, b=0.75)
    
    results = rank_documents_bm25(query_data[text_col].tolist(), query_data[query_id_col].tolist(), doc_data[doc_id_col].tolist(), max(cutoffs), bm25_model)
    ndcg = compute_ranking_score(results=results, qrel_dict=qrel_dict, metrics=["ndcg"], cutoffs=cutoffs)
    
    return ndcg
qrel_dict = get_qrels(dataset)
evaluate_bm25(doc_data, query_data, qrel_dict, cutoffs)
{'ndcg@5': 0.33858, 'ndcg@10': 0.2983, 'ndcg@20': 0.26408}

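5. Using AutoMM

AutoMM provides easy-to-use APIs to evaluate ranking performance, extract embeddings, and run semantic search.

5.1 Initialize predictor

For text data, we can initialize MultiModalPredictor with the problem type text_similarity. We need to specify the column names in the labeled_data dataframe that correspond to query, response, and label.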

%%capture
from autogluon.multimodal import MultiModalPredictor

predictor = MultiModalPredictor(
        query=query_id_col,
        response=doc_id_col,
        label=label_col,
        problem_type="text_similarity",
        hyperparameters={"model.hf_text.checkpoint_name": "sentence-transformers/all-MiniLM-L6-v2"}
    )

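5.2 Evaluate ranking

Evaluating ranking performance is straightforward with the evaluate API. During evaluation, the predictor automatically extracts embeddings, computes cosine similarities, ranks the results, and computes the scores.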

predictor.evaluate(
        labeled_data,
        query_data=query_data[[query_id_col]],
        response_data=doc_data[[doc_id_col]],
        id_mappings=id_mappings,
        cutoffs=cutoffs,
        metrics=["ndcg"],
    )
{'ndcg@5': 0.33672, 'ndcg@10': 0.30891, 'ndcg@20': 0.28207}

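Compared with BM25, the scores improve at the larger cutoffs (NDCG@10: 0.30891 vs. 0.2983; NDCG@20: 0.28207 vs. 0.26408), while NDCG@5 is roughly on par (0.33672 vs. 0.33858).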

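5.4 Extract embeddings

Extracting embeddings is important for deploying models to industry search engines. In general, a system extracts the embeddings of database items offline. During online search, it only needs to encode the query data and then efficiently match the query embeddings against the saved database embeddings.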
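The offline/online split described above can be sketched with plain NumPy. This is a simplified illustration rather than AutoMM's internal implementation: the two-dimensional vectors and document IDs are made up, and `normalize_rows` and `top_k_documents` are hypothetical helper names.

```python
import numpy as np

def normalize_rows(matrix):
    """Scale each row to unit length so a dot product equals cosine similarity."""
    matrix = np.asarray(matrix, dtype=float)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    return matrix / np.clip(norms, 1e-12, None)

def top_k_documents(query_vec, doc_matrix_normalized, doc_ids, k):
    """Return the k (doc_id, cosine similarity) pairs with the highest similarity."""
    q = normalize_rows(np.atleast_2d(query_vec))[0]
    sims = doc_matrix_normalized @ q
    order = np.argsort(-sims)[:k]
    return [(doc_ids[i], float(sims[i])) for i in order]

# Offline: hypothetical document embeddings, normalized and stored once.
doc_ids = ["MED-A", "MED-B", "MED-C"]
stored = normalize_rows([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]])

# Online: only the incoming query is encoded (here a made-up vector).
hits = top_k_documents(np.array([0.9, 0.1]), stored, doc_ids, k=2)
print(hits)
```

Because the document matrix is normalized ahead of time, each online query costs one matrix-vector product plus a sort. For large collections the exhaustive product would typically be replaced by an approximate nearest-neighbor index.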

query_embeds = predictor.extract_embedding(query_data[[query_id_col]], id_mappings=id_mappings, as_tensor=True)
doc_embeds = predictor.extract_embedding(doc_data[[doc_id_col]], id_mappings=id_mappings, as_tensor=True)

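6. Hybrid BM25

We propose a new search ranking method called Hybrid BM25, which combines BM25 and semantic embeddings for scoring. The key idea is to use BM25 as the first-stage retrieval method (say, it recalls 1000 documents for each query) and then use a pretrained language model (PLM) to score all the recalled documents (1000 documents).

We then re-rank the retrieved documents using scores computed with the following formula: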

\[ score = \beta * normalized\_BM25 + ( 1 - \beta) * score\_of\_plm \]

where

\[ normalized\_BM25(q_i, D_j) = \frac{\textsf{BM25}(q_i,D_j) - \min_{a\in \mathcal{Q},b\in\mathcal{D}}(\textsf{BM25}(a,b))}{\max_{a\in \mathcal{Q},b\in\mathcal{D}}(\textsf{BM25}(a,b)) - \min_{a\in \mathcal{Q},b\in\mathcal{D}}(\textsf{BM25}(a,b))},\]

and \(\beta\) is a tunable parameter, which we set to \(0.3\) by default in this tutorial.
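To see how the two terms interact, the following sketch applies the formula to made-up scores for a single query. The helper names `min_max_normalize` and `hybrid_scores`, the score values, and the IDs are all hypothetical; the only parts taken from the text are the min-max normalization and the \(\beta\)-weighted sum.

```python
def min_max_normalize(scores_by_pair):
    """Min-max normalize BM25 scores over all (query, document) pairs."""
    values = list(scores_by_pair.values())
    lo, hi = min(values), max(values)
    span = hi - lo
    return {
        pair: ((s - lo) / span if span > 0 else 0.0)
        for pair, s in scores_by_pair.items()
    }

def hybrid_scores(bm25_by_pair, plm_by_pair, beta=0.3):
    """score = beta * normalized_BM25 + (1 - beta) * score_of_plm."""
    normalized = min_max_normalize(bm25_by_pair)
    return {
        pair: beta * normalized[pair] + (1 - beta) * plm_by_pair[pair]
        for pair in bm25_by_pair
    }

# Hypothetical scores for one query "q1" and three recalled documents.
bm25 = {("q1", "d1"): 12.0, ("q1", "d2"): 6.0, ("q1", "d3"): 2.0}
plm = {("q1", "d1"): 0.40, ("q1", "d2"): 0.90, ("q1", "d3"): 0.10}

scores = hybrid_scores(bm25, plm, beta=0.3)
ranking = sorted(scores, key=scores.get, reverse=True)
print(ranking)
```

In this toy case BM25 alone would rank d1 first, but with \(\beta = 0.3\) the higher embedding similarity of d2 moves it to the top.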

6.1 Define functions

import torch
from autogluon.multimodal.utils import compute_semantic_similarity

def hybridBM25(query_data, query_embeds, doc_data, doc_embeds, recall_num, top_k, beta):
    # Recall documents with BM25 scores
    tokenized_corpus = tokenize_corpus(doc_data[text_col].tolist())
    bm25_model = BM25Okapi(tokenized_corpus, k1=1.2, b=0.75)
    bm25_scores = rank_documents_bm25(query_data[text_col].tolist(), query_data[query_id_col].tolist(), doc_data[doc_id_col].tolist(), recall_num, bm25_model)
    
    all_bm25_scores = [score for scores in bm25_scores.values() for score in scores.values()]
    max_bm25_score = max(all_bm25_scores)
    min_bm25_score = min(all_bm25_scores)

    q_embeddings = {qid: embed for qid, embed in zip(query_data[query_id_col].tolist(), query_embeds)}
    d_embeddings = {did: embed for did, embed in zip(doc_data[doc_id_col].tolist(), doc_embeds)}
    
    query_ids = query_data[query_id_col].tolist()
    results = {qid: {} for qid in query_ids}
    for idx, qid in enumerate(query_ids):
        rec_docs = bm25_scores[qid]
        rec_doc_emb = [d_embeddings[doc_id] for doc_id in rec_docs.keys()]
        rec_doc_id = [doc_id for doc_id in rec_docs.keys()]
        rec_doc_emb = torch.stack(rec_doc_emb)
        scores = compute_semantic_similarity(q_embeddings[qid], rec_doc_emb)
        scores[torch.isnan(scores)] = -1
        top_k_values, top_k_idxs = torch.topk(
            scores,
            min(top_k + 1, len(scores[0])),
            dim=1,
            largest=True,
            sorted=False,
        )

        for doc_idx, score in zip(top_k_idxs[0], top_k_values[0]):
            doc_id = rec_doc_id[int(doc_idx)]
            # Hybrid scores from BM25 and cosine similarity of embeddings
            results[qid][doc_id] = \
                (1 - beta) * float(score.numpy()) \
                + beta * (bm25_scores[qid][doc_id] - min_bm25_score) / (max_bm25_score - min_bm25_score)
    
    return results


def evaluate_hybridBM25(query_data, query_embeds, doc_data, doc_embeds, recall_num, beta, cutoffs):
    results = hybridBM25(query_data, query_embeds, doc_data, doc_embeds, recall_num, max(cutoffs), beta)
    ndcg = compute_ranking_score(results=results, qrel_dict=qrel_dict, metrics=["ndcg"], cutoffs=cutoffs)
    return ndcg
recall_num = 1000
beta = 0.3
query_embeds = predictor.extract_embedding(query_data[[query_id_col]], id_mappings=id_mappings, as_tensor=True)
doc_embeds = predictor.extract_embedding(doc_data[[doc_id_col]], id_mappings=id_mappings, as_tensor=True)
evaluate_hybridBM25(query_data, query_embeds, doc_data, doc_embeds, recall_num, beta, cutoffs)
{'ndcg@5': 0.36948, 'ndcg@10': 0.33297, 'ndcg@20': 0.29236}

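Hybrid BM25 improves the ranking scores over plain BM25 at every cutoff (for example, NDCG@10 rises from 0.2983 to 0.33297).

7. Summary

In this tutorial, we demonstrated how to use AutoMM for semantic search and compared it against the classical BM25 baseline, which it outperforms at NDCG@10 and NDCG@20. We further improved the ranking scores by combining BM25 and AutoMM (Hybrid BM25).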