Shortcuts

TorchRec 概念

在本节中,我们将学习TorchRec的关键概念,这些概念旨在使用PyTorch优化大规模推荐系统。我们将详细了解每个概念的工作原理,并回顾它如何与TorchRec的其他部分一起使用。

TorchRec 有其模块特定的输入/输出数据类型,以高效表示稀疏特征,包括:

  • JaggedTensor: 一个围绕长度/偏移量和值的张量的包装器,用于单一稀疏特征。

  • KeyedJaggedTensor: 高效地表示多个稀疏特征,可以将其视为多个JaggedTensor

  • KeyedTensor: 一个围绕 torch.Tensor 的包装器,允许通过键访问张量值。

以高性能和高效为目标,标准的torch.Tensor在表示稀疏数据时效率非常低。TorchRec引入了这些新的数据类型,因为它们提供了稀疏输入数据的高效存储和表示。正如你稍后将看到的,KeyedJaggedTensor使得在分布式环境中输入数据的通信非常高效,这是TorchRec提供的关键性能优势之一。

在端到端的训练循环中,TorchRec 包含以下主要组件:

  • 规划器: 接收嵌入表的配置、环境设置,并为模型生成优化的分片计划。

  • Sharder: 根据分片计划使用不同的分片策略对模型进行分片,包括数据并行、表级、行级、表级行级、列级和表级列级分片。

  • DistributedModelParallel: 结合了分片器、优化器,并提供了以分布式方式训练模型的入口点。

JaggedTensor

一个JaggedTensor通过长度、值和偏移量来表示稀疏特征。它被称为“锯齿状”是因为它有效地表示了具有可变长度序列的数据。相比之下,标准的torch.Tensor假设每个序列具有相同的长度,这在实际数据中往往不是这样。JaggedTensor有助于表示这种数据而无需填充,使其非常高效。

关键组件:

  • Lengths: 一个整数列表,表示每个实体的元素数量。

  • Offsets: 一个整数列表,表示扁平化值张量中每个序列的起始索引。这些提供了长度的替代方案。

  • Values: 一个一维张量,包含每个实体的实际值,连续存储。

这里是一个简单的示例,展示了每个组件的样子:

# User interactions:
# - User 1 interacted with 2 items
# - User 2 interacted with 3 items
# - User 3 interacted with 1 item
lengths = [2, 3, 1]
offsets = [0, 2, 5]  # Starting index of each user's interactions
values = torch.Tensor([101, 102, 201, 202, 203, 301])  # Item IDs interacted with
jt = JaggedTensor(lengths=lengths, values=values)
# OR
jt = JaggedTensor(offsets=offsets, values=values)

KeyedJaggedTensor

一个KeyedJaggedTensor通过引入键(通常是特征名称)来扩展JaggedTensor的功能,以标记不同的特征组,例如用户特征和物品特征。这是EmbeddingBagCollectionEmbeddingCollectionforward中使用的数据类型,因为它们用于表示表中的多个特征。

一个KeyedJaggedTensor有一个隐含的批量大小,这是特征数量除以lengths张量的长度。下面的例子有一个批量大小为2。与JaggedTensor类似,offsetslengths以相同的方式工作。你也可以通过从KeyedJaggedTensor访问键来访问特征的lengthsoffsetsvalues

keys = ["user_features", "item_features"]
# Lengths of interactions:
# - User features: 2 users, with 2 and 3 interactions respectively
# - Item features: 2 items, with 1 and 2 interactions respectively
lengths = [2, 3, 1, 2]
values = torch.Tensor([11, 12, 21, 22, 23, 101, 102, 201])
# Create a KeyedJaggedTensor
kjt = KeyedJaggedTensor(keys=keys, lengths=lengths, values=values)
# Access the features by key
print(kjt["user_features"])
# Outputs user features
print(kjt["item_features"])

计划器

TorchRec规划器帮助确定模型的最佳分片配置。它评估了分片嵌入表的多种可能性,并优化性能。规划器执行以下操作:

  • 评估硬件的内存限制。

  • 根据内存获取(如嵌入查找)估算计算需求。

  • 解决数据特定的因素。

  • 考虑其他硬件特性,如带宽,以生成最佳的分片计划。

为了确保准确考虑这些因素,Planner可以结合嵌入表、约束、硬件信息和拓扑数据,以帮助生成最佳计划。

嵌入表的分片

TorchRec分片器为各种使用场景提供了多种分片策略,我们概述了一些分片策略及其工作原理,以及它们的优点和局限性。通常,我们建议使用TorchRec规划器为您生成分片计划,因为它会为模型中的每个嵌入表找到最佳的分片策略。

每个分片策略决定了如何进行表的分割,是否应该切割表以及如何切割,是否保留某些表的一个或几个副本,等等。从分片结果中得到的表的每一部分,无论是一个嵌入表还是它的一部分,都被称为一个分片。

Visualizing the difference of sharding types offered in TorchRec

图1:可视化TorchRec中提供的不同分片方案下的表分片位置

以下是TorchRec中可用的所有分片类型列表:

  • 表级别(TW):顾名思义,嵌入表作为一个整体保留并放置在一个等级上。

  • 按列(CW):表格沿emb_dim维度分割, 例如,emb_dim=256被分割成4个分片:[64, 64, 64, 64]

  • 行方式(RW):表格沿着hash_size维度分割,通常在所有等级之间均匀分割。

  • 表级行级(TWRW):表放置在一个主机上,在该主机上的各个等级之间按行分割。

  • 网格分片(GS):一个表被CW分片,每个CW分片被放置在一台主机上,采用TWRW方式。

  • 数据并行(DP):每个等级都保留表的一个副本。

一旦分片,这些模块会被转换为它们的分片版本,在TorchRec中称为ShardedEmbeddingCollectionShardedEmbeddingBagCollection。这些模块处理输入数据的通信、嵌入查找和梯度。

使用TorchRec分片模块进行分布式训练

有许多分片策略可供选择,我们如何确定使用哪一种?每种分片方案都有相关的成本,结合模型大小和GPU数量,可以确定哪种分片策略最适合模型。

在没有分片的情况下,每个GPU都保留嵌入表的副本(DP),主要成本是计算,其中每个GPU在前向传递中查找其内存中的嵌入向量,并在后向传递中更新梯度。

使用分片技术时,会有一个额外的通信成本:每个GPU都需要向其他GPU请求嵌入向量查找,并传递计算出的梯度。这通常被称为all2all通信。在TorchRec中,对于给定GPU上的输入数据,我们确定数据的每个部分的嵌入分片位于何处,并将其发送到目标GPU。然后,目标GPU将嵌入向量返回给原始GPU。在反向传播过程中,梯度被发送回目标GPU,并使用优化器相应地更新分片。

如上所述,分片要求我们传递输入数据和嵌入查找。TorchRec 主要在三个阶段处理这个问题,我们将这称为用于 TorchRec 模型训练和推理的分片嵌入模块前向传播:

  • 功能 全部到全部/输入分布 (input_dist)

    • 将输入数据(以KeyedJaggedTensor的形式)传递到包含相关嵌入表分片的适当设备

  • 嵌入查找

    • 使用特征全交换后形成的新输入数据查找嵌入

  • 嵌入所有到所有/输出分布 (output_dist)

    • 将嵌入查找数据传回请求它的适当设备(根据设备接收到的输入数据)

  • 反向传播执行相同的操作,但顺序相反。

下图展示了它的工作原理:

Visualizing the forward pass including the input_dist, lookup, and output_dist of a sharded TorchRec module

图2:包括分片TorchRec模块的input_dist、lookup和output_dist的表格分片前向传递

分布式模型并行

以上所有内容最终汇聚成 TorchRec 用于分片和集成计划的主要入口点。在高层次上,DistributedModelParallel 执行以下操作:

  • 通过设置进程组和分配设备类型来初始化环境。

  • 如果没有提供着色器,则使用默认着色器,默认包括 EmbeddingBagCollectionSharder

  • 接受提供的分片计划,如果没有提供,则生成一个。

  • 创建模块的分片版本并替换原始模块,例如,将EmbeddingCollection转换为ShardedEmbeddingCollection

  • 默认情况下,将DistributedModelParallelDistributedDataParallel包装在一起,使模块同时具备模型和数据并行功能。

优化器

TorchRec模块提供了一个无缝的API,用于在训练中融合反向传播和优化器步骤,从而显著优化性能并减少内存使用,同时还可以为不同的模型参数分配不同的优化器。

Visualizing fusing of optimizer in backward to update sparse embedding table

图3:将嵌入反向与稀疏优化器融合

推理

推理环境与训练环境不同,它们对性能和模型的大小非常敏感。TorchRec推理优化的两个关键差异是:

  • 量化:推理模型经过量化以降低延迟并减少模型大小。这种优化使我们能够使用尽可能少的设备进行推理,以最小化延迟。

  • C++ 环境:为了进一步减少延迟,模型在 C++ 环境中运行。

TorchRec 提供了以下功能,用于将 TorchRec 模型转换为推理就绪状态:

  • 用于量化模型的API,包括使用FBGEMM TBE自动进行的优化

  • 用于分布式推理的分片嵌入

  • 将模型编译为TorchScript(与C++兼容)

另请参阅