Mar 10, 2022

Recommendation using embeddings and nearest neighbor search


Recommendations are widespread across the web.

  • 'Bought that item? Try these similar items.'
  • 'Enjoy that book? Try these similar titles.'
  • 'Not the help page you were looking for? Try these similar pages.'

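This notebook demonstrates how embeddings can be used to find similar items to recommend. In particular, we use AG's corpus of news articles as our dataset.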

Our model will answer the question: given an article, what other articles are most similar to it?

1. Imports

import pandas as pd
import pickle

from utils.embeddings_utils import (
    get_embedding,
    distances_from_embeddings,
    tsne_components_from_embeddings,
    chart_from_components,
    indices_of_nearest_neighbors_from_distances,
)

EMBEDDING_MODEL = "text-embedding-3-small"

2. Load data

Next, let's load the AG news data and see what it looks like.

# load data (full dataset available at http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
dataset_path = "data/AG_news_samples.csv"
df = pd.read_csv(dataset_path)

n_examples = 5
df.head(n_examples)
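   title                                              description                                        label_int  label
0  World Briefings                                    BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime M...  1          World
1  Nvidia Puts a Firewall on a Motherboard (PC Wo...  PC World - Upcoming chip set will include buil...  4          Sci/Tech
2  Olympic joy in Greek, Chinese press                Newspapers in Greece reflect a mixture of exhi...  2          Sports
3  U2 Can iPod with Pictures                          SAN JOSE, Calif. -- Apple Computer (Quote, Cha...  4          Sci/Tech
4  The Dream Factory                                  Any product, any shape, any size -- manufactur...  4          Sci/Tech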

Let's take a look at those same examples, but not truncated by ellipses.

# print the title, description, and label of each example
for idx, row in df.head(n_examples).iterrows():
    print("")
    print(f"Title: {row['title']}")
    print(f"Description: {row['description']}")
    print(f"Label: {row['label']}")
Title: World Briefings
Description: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases.
Label: World

Title: Nvidia Puts a Firewall on a Motherboard (PC World)
Description: PC World - Upcoming chip set will include built-in security features for your PC.
Label: Sci/Tech

Title: Olympic joy in Greek, Chinese press
Description: Newspapers in Greece reflect a mixture of exhilaration that the Athens Olympics proved successful, and relief that they passed off without any major setback.
Label: Sports

Title: U2 Can iPod with Pictures
Description: SAN JOSE, Calif. -- Apple Computer (Quote, Chart) unveiled a batch of new iPods, iTunes software and promos designed to keep it atop the heap of digital music players.
Label: Sci/Tech

Title: The Dream Factory
Description: Any product, any shape, any size -- manufactured on your desktop! The future is the fabricator. By Bruce Sterling from Wired magazine.
Label: Sci/Tech

3. Build cache to save embeddings

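Before getting embeddings for these articles, let's set up a cache to save the embeddings we generate. In general, it's a good idea to save your embeddings so you can reuse them later. If you don't save them, you'll pay again each time you recompute them.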

The cache is a dictionary that maps (text, model) tuples to an embedding, which is a list of floats. The cache is saved as a Python pickle file.

# establish a cache of embeddings to avoid recomputing
# cache is a dict of tuples (text, model) -> embedding, saved as a pickle file

# set path to embedding cache
embedding_cache_path = "data/recommendations_embeddings_cache.pkl"

# load the cache if it exists, and save a copy to disk
try:
    embedding_cache = pd.read_pickle(embedding_cache_path)
except FileNotFoundError:
    embedding_cache = {}
with open(embedding_cache_path, "wb") as embedding_cache_file:
    pickle.dump(embedding_cache, embedding_cache_file)

# define a function to retrieve embeddings from the cache if present, and otherwise request via the API
def embedding_from_string(
    string: str,
    model: str = EMBEDDING_MODEL,
    embedding_cache=embedding_cache
) -> list:
    """Return embedding of given string, using a cache to avoid recomputing."""
    if (string, model) not in embedding_cache.keys():
        embedding_cache[(string, model)] = get_embedding(string, model)
        with open(embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(embedding_cache, embedding_cache_file)
    return embedding_cache[(string, model)]

Let's check that it works by getting an embedding.

# as an example, take the first description from the dataset
example_string = df["description"].values[0]
print(f"\nExample string: {example_string}")

# print the first 10 dimensions of the embedding
example_embedding = embedding_from_string(example_string)
print(f"\nExample embedding: {example_embedding[:10]}...")
Example string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases.

Example embedding: [0.0545826330780983, -0.00428084097802639, 0.04785159230232239, 0.01587914116680622, -0.03640881925821304, 0.0143799539655447, -0.014267769642174244, -0.015175441280007362, -0.002344391541555524, 0.011075624264776707]...
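To see why keying the cache on (text, model) saves API calls, here is a minimal in-memory sketch of the same pattern. `fake_get_embedding` is a hypothetical offline stand-in for `get_embedding` (it does not call any API), and `make_cached_embedder` is an illustrative helper, not part of `utils.embeddings_utils`. The pickle-to-disk step is left out to keep the sketch focused on cache hits and misses.

```python
from typing import Callable

# Records every (hypothetical) "paid" call so we can count them.
calls: list[tuple[str, str]] = []


def fake_get_embedding(text: str, model: str) -> list[float]:
    """Hypothetical offline stand-in for get_embedding (not the OpenAI API)."""
    calls.append((text, model))
    # A deterministic toy vector; real embeddings have far more dimensions.
    return [float(len(text)), float(sum(map(ord, text)) % 97)]


def make_cached_embedder(
    embed: Callable[[str, str], list[float]],
) -> Callable[[str, str], list[float]]:
    """Wrap `embed` with a dict cache keyed by (text, model)."""
    cache: dict[tuple[str, str], list[float]] = {}

    def cached(text: str, model: str) -> list[float]:
        key = (text, model)
        if key not in cache:  # cache miss: pay for exactly one call
            cache[key] = embed(text, model)
        return cache[key]  # cache hit: no call at all

    return cached


embed = make_cached_embedder(fake_get_embedding)
first = embed("hello", "toy-model")
second = embed("hello", "toy-model")  # served from the cache
third = embed("hello", "other-model")  # different model, so a different key
print(f"Embedding calls made: {len(calls)}")
```

Including the model name in the key matters because the same text embedded by two different models yields two different, incomparable vectors.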

4. Recommend similar articles based on embeddings

To find similar articles, let's follow a three-step plan:

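  1. Get the similarity embeddings of all the article descriptions
  2. Calculate the distance between a source title and all other articles
  3. Print out the other articles closest to the source title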
def print_recommendations_from_strings(
    strings: list[str],
    index_of_source_string: int,
    k_nearest_neighbors: int = 1,
    model=EMBEDDING_MODEL,
) -> list[int]:
    """Print out the k nearest neighbors of a given string."""
    # get embeddings for all strings
    embeddings = [embedding_from_string(string, model=model) for string in strings]

    # get the embedding of the source string
    query_embedding = embeddings[index_of_source_string]

    # get distances between the source embedding and other embeddings (function from utils.embeddings_utils.py)
    distances = distances_from_embeddings(query_embedding, embeddings, distance_metric="cosine")
    
    # get indices of nearest neighbors (function from utils.embeddings_utils.py)
    indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)

    # print out source string
    query_string = strings[index_of_source_string]
    print(f"Source string: {query_string}")
    # print out its k nearest neighbors
    k_counter = 0
    for i in indices_of_nearest_neighbors:
        # skip any strings that are identical matches to the starting string
        if query_string == strings[i]:
            continue
        # stop after printing out k articles
        if k_counter >= k_nearest_neighbors:
            break
        k_counter += 1

        # print out the similar strings and their distances
        print(
            f"""
        --- Recommendation #{k_counter} (nearest neighbor {k_counter} of {k_nearest_neighbors}) ---
        String: {strings[i]}
        Distance: {distances[i]:0.3f}"""
        )

    return indices_of_nearest_neighbors
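The helpers `distances_from_embeddings` and `indices_of_nearest_neighbors_from_distances` come from `utils.embeddings_utils`, whose source is not shown here. As a hedged, dependency-free sketch of the same idea, assuming cosine distance is defined as 1 minus cosine similarity and that neighbors are ordered by ascending distance, the three steps look roughly like this. `cosine_distance` and `top_k_neighbors` are illustrative names, and the toy 2D vectors stand in for real embeddings.

```python
import math


def cosine_distance(a: list[float], b: list[float]) -> float:
    """1 - cosine similarity: 0.0 for the same direction, 1.0 for orthogonal.

    Assumes neither vector is all zeros (that would divide by zero).
    """
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def top_k_neighbors(
    strings: list[str],
    embeddings: list[list[float]],
    source_index: int,
    k: int,
) -> list[int]:
    """Return indices of the k strings closest to strings[source_index]."""
    query = embeddings[source_index]
    distances = [cosine_distance(query, e) for e in embeddings]
    # Indices sorted from nearest to farthest (sorted() is stable on ties).
    order = sorted(range(len(embeddings)), key=lambda i: distances[i])
    # Skip exact textual duplicates of the source, including the source itself.
    candidates = [i for i in order if strings[i] != strings[source_index]]
    return candidates[:k]


toy_strings = ["cats purr", "kittens purr", "stock markets fell", "cats purr"]
toy_embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [1.0, 0.0]]
neighbors = top_k_neighbors(toy_strings, toy_embeddings, source_index=0, k=2)
print(neighbors)  # [1, 2]
```

The duplicate at index 3 is dropped even though its distance is 0.0, mirroring the `query_string == strings[i]` check in `print_recommendations_from_strings`.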

5. Example recommendations

Let's look for articles similar to the first one, which was about Tony Blair.

article_descriptions = df["description"].tolist()

tony_blair_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # let's base similarity off of the article description
    index_of_source_string=0,  # articles similar to the first one about Tony Blair
    k_nearest_neighbors=5,  # 5 most similar articles
)
Source string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases.

        --- Recommendation #1 (nearest neighbor 1 of 5) ---
        String: The anguish of hostage Kenneth Bigley in Iraq hangs over Prime Minister Tony Blair today as he faces the twin test of a local election and a debate by his Labour Party about the divisive war.
        Distance: 0.514

        --- Recommendation #2 (nearest neighbor 2 of 5) ---
        String: THE re-election of British Prime Minister Tony Blair would be seen as an endorsement of the military action in Iraq, Prime Minister John Howard said today.
        Distance: 0.516

        --- Recommendation #3 (nearest neighbor 3 of 5) ---
        String: Israel is prepared to back a Middle East conference convened by Tony Blair early next year despite having expressed fears that the British plans were over-ambitious and designed 
        Distance: 0.546

        --- Recommendation #4 (nearest neighbor 4 of 5) ---
        String: Allowing dozens of casinos to be built in the UK would bring investment and thousands of jobs, Tony Blair says.
        Distance: 0.568

        --- Recommendation #5 (nearest neighbor 5 of 5) ---
        String: AFP - A battle group of British troops rolled out of southern Iraq on a US-requested mission to deadlier areas near Baghdad, in a major political gamble for British Prime Minister Tony Blair.
        Distance: 0.579

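Pretty good! All 5 recommendations explicitly mention Tony Blair. None of them is about climate change, though, so in this case the embeddings matched on the person more than on the topic.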

Let's see how our recommender does on the second example article, about NVIDIA's new chipset with more security.

chipset_security_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # let's base similarity off of the article description
    index_of_source_string=1,  # let's look at articles similar to the second one about a more secure chipset
    k_nearest_neighbors=5,  # let's look at the 5 most similar articles
)
Source string: PC World - Upcoming chip set will include built-in security features for your PC.

        --- Recommendation #1 (nearest neighbor 1 of 5) ---
        String: PC World - Updated antivirus software for businesses adds intrusion prevention features.
        Distance: 0.422

        --- Recommendation #2 (nearest neighbor 2 of 5) ---
        String: PC World - Symantec, McAfee hope raising virus-definition fees will move users to\  suites.
        Distance: 0.518

        --- Recommendation #3 (nearest neighbor 3 of 5) ---
        String: originally offered on notebook PCs -- to its Opteron 32- and 64-bit x86 processors for server applications. The technology will help servers to run 
        Distance: 0.522

        --- Recommendation #4 (nearest neighbor 4 of 5) ---
        String: PC World - Send your video throughout your house--wirelessly--with new gateways and media adapters.
        Distance: 0.532

        --- Recommendation #5 (nearest neighbor 5 of 5) ---
        String: Chips that help a computer's main microprocessors perform specific types of math problems are becoming a big business once again.\
        Distance: 0.532

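From the printed distances, you can see that the #1 recommendation is much closer than all the others (0.422 vs. 0.518+). And the #1 recommendation looks very similar to the starting article: it's another article from PC World about increasing computer security. Pretty good!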

Appendix: Using embeddings in more sophisticated recommenders

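A more sophisticated way to build a recommender system is to train a machine learning model that takes in tens or hundreds of signals, such as item popularity or user click data. Even in such a system, embeddings can be a very useful signal for the recommender, especially for items that are being "cold started" with no user data yet (e.g., a brand-new product added to the catalog without any clicks yet).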
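To make the cold-start point concrete, here is a hedged sketch of how an embedding similarity might sit alongside behavioral signals in a ranking model's input features. The feature names, the toy items, and `build_features` are all hypothetical; a real system would feed features like these to a trained model rather than inspect them by hand.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity; assumes neither vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def build_features(item: dict, query_embedding: list[float]) -> dict[str, float]:
    """Hypothetical ranking features: two behavioral signals plus one embedding signal."""
    return {
        "popularity": float(item["popularity"]),
        "clicks": float(item["clicks"]),
        "embedding_similarity": cosine_similarity(item["embedding"], query_embedding),
    }


# Toy items with made-up numbers: one established, one just added to the catalog.
query_embedding = [1.0, 0.0]  # e.g., the embedding of the article being read
established = {"popularity": 0.8, "clicks": 1200, "embedding": [0.0, 1.0]}
cold_start = {"popularity": 0.0, "clicks": 0, "embedding": [0.95, 0.05]}

established_features = build_features(established, query_embedding)
cold_start_features = build_features(cold_start, query_embedding)
print(established_features)
print(cold_start_features)
```

The new item has no clicks or popularity to go on, but its embedding similarity still marks it as a strong content match for the query article, which is exactly the gap embeddings fill for cold-start items.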

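Appendix: Using embeddings to visualize similar articles

To get a sense of what our nearest neighbor recommender is doing, let's visualize the article embeddings. Although we can't plot all 1536 dimensions of each embedding vector, we can use techniques like t-SNE or PCA to compress the embeddings down into 2 or 3 dimensions, which we can chart.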
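Of the two techniques, PCA is the simpler one, and unlike t-SNE it is deterministic. Here is a minimal NumPy sketch of projecting embeddings onto their top principal components, assuming the embeddings fit in memory as a matrix. Note that this is PCA, not the t-SNE used in the next cell, and `pca_components` is an illustrative name, not a function from `utils.embeddings_utils`. Random vectors stand in for real article embeddings.

```python
import numpy as np


def pca_components(embeddings, n_components: int = 2) -> np.ndarray:
    """Project embeddings onto their top `n_components` principal components."""
    X = np.asarray(embeddings, dtype=float)
    X_centered = X - X.mean(axis=0)  # PCA operates on mean-centered data
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ vt[:n_components].T  # shape: (n_samples, n_components)


rng = np.random.default_rng(0)
toy_embeddings = rng.normal(size=(20, 1536))  # random stand-ins, not real embeddings
components_2d = pca_components(toy_embeddings, n_components=2)
print(components_2d.shape)  # (20, 2)
```

The first output column always carries at least as much variance as the second, which is why the first two components are the natural choice for a 2D chart.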

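Before visualizing the nearest neighbors, let's visualize all of the article descriptions using t-SNE. Note that t-SNE is not deterministic, meaning that results may vary from run to run.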

# get embeddings for all article descriptions
embeddings = [embedding_from_string(string) for string in article_descriptions]
# compress the 1536-dimensional embeddings into 2 dimensions using t-SNE
tsne_components = tsne_components_from_embeddings(embeddings)
# get the article labels for coloring the chart
labels = df["label"].tolist()

chart_from_components(
    components=tsne_components,
    labels=labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="t-SNE components of article descriptions",
)

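As you can see in the chart above, even the highly compressed embeddings do a good job of clustering article descriptions by category. And it's worth emphasizing: this clustering is done with no knowledge of the labels themselves!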

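Also, if you look closely at the most egregious outliers, they are often due to mislabeling rather than poor embedding. For example, the majority of the blue World points in the green Sports cluster appear to be Sports stories.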

Next, let's recolor the points by whether they are a source article, one of its nearest neighbors, or other.

# create labels for the recommended articles
def nearest_neighbor_labels(
    list_of_indices: list[int],
    k_nearest_neighbors: int = 5
) -> list[str]:
    """Return a list of labels to color the k nearest neighbors."""
    labels = ["Other" for _ in list_of_indices]
    source_index = list_of_indices[0]
    labels[source_index] = "Source"
    for i in range(k_nearest_neighbors):
        nearest_neighbor_index = list_of_indices[i + 1]
        labels[nearest_neighbor_index] = f"Nearest neighbor (top {k_nearest_neighbors})"
    return labels


tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5)
chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5)
# a 2D chart of nearest neighbors of the Tony Blair article
chart_from_components(
    components=tsne_components,
    labels=tony_blair_labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="Nearest neighbors of the Tony Blair article",
    category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)

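Looking at the 2D chart above, we can see that the articles about Tony Blair are somewhat close together inside the World news cluster. Interestingly, although the 5 nearest neighbors (red) were closest in high-dimensional space, they are not the closest points in this compressed 2D space. Compressing the embeddings down to 2 dimensions discards much of their information, and the nearest neighbors in the 2D space don't seem to be as relevant as those in the full embedding space.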

# a 2D chart of nearest neighbors of the chipset security article
chart_from_components(
    components=tsne_components,
    labels=chipset_security_labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="Nearest neighbors of the chipset security article",
    category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)

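For the chipset security example, 4 of the 5 nearest neighbors in the full embedding space remain nearest neighbors in this compressed 2D visualization. The fifth is displayed as more distant, despite being closer in the full embedding space.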
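One way to quantify how much neighbor structure survives compression is to compare a point's top-k neighbors before and after projection. This sketch carries several assumptions: it uses PCA instead of t-SNE (to stay deterministic and need only NumPy), random vectors instead of the real article embeddings, cosine distance in the full space, and Euclidean distance in 2D (matching how the chart is read by eye). All helper names are illustrative.

```python
import numpy as np


def cosine_neighbors(X: np.ndarray, source: int, k: int) -> set[int]:
    """Top-k neighbors of row `source` by cosine distance in the full space."""
    X_unit = X / np.linalg.norm(X, axis=1, keepdims=True)
    distances = 1.0 - X_unit @ X_unit[source]
    order = np.argsort(distances)
    return set([int(i) for i in order if i != source][:k])


def euclidean_neighbors(Y: np.ndarray, source: int, k: int) -> set[int]:
    """Top-k neighbors of row `source` by straight-line distance in 2D."""
    distances = np.linalg.norm(Y - Y[source], axis=1)
    order = np.argsort(distances)
    return set([int(i) for i in order if i != source][:k])


def pca_2d(X: np.ndarray) -> np.ndarray:
    """Project rows of X onto their first two principal components."""
    X_centered = X - X.mean(axis=0)
    _, _, vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ vt[:2].T


rng = np.random.default_rng(42)
X = rng.normal(size=(200, 1536))  # random stand-ins for article embeddings
Y = pca_2d(X)
k = 5
full_space = cosine_neighbors(X, source=0, k=k)
flat_space = euclidean_neighbors(Y, source=0, k=k)
overlap = len(full_space & flat_space) / k
print(f"Fraction of full-space neighbors kept in 2D: {overlap:.1f}")
```

On real data you would build `X` from the `embeddings` list and pick a source index such as 0 or 1. An overlap below 1.0 is the numerical counterpart of the neighbors that drift apart in the charts above.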

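If you like, you can also make an interactive 3D plot of the embeddings with the function chart_from_components_3D. (Doing so will require recomputing the t-SNE components with n_components=3.)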