分布式节点分类

在本教程中,我们将逐步介绍如何为节点分类任务执行分布式GNN训练。要理解分布式GNN训练,您需要先阅读多GPU训练的教程。本教程是在多GPU训练的基础上开发的,通过提供额外的步骤来分割图、修改训练脚本以及设置分布式训练环境。

Partition a graph

在本教程中,我们将使用OGBN products图作为示例来说明图分区。首先,我们将图加载到DGL图中。在这里,我们将节点标签存储为DGL图中的节点数据。

import os
os.environ['DGLBACKEND'] = 'pytorch'
import dgl
import torch as th
from ogb.nodeproppred import DglNodePropPredDataset
data = DglNodePropPredDataset(name='ogbn-products')
graph, labels = data[0]
labels = labels[:, 0]
graph.ndata['labels'] = labels

我们需要在图分区过程中将数据分为训练/验证/测试集。 因为这是一个节点分类任务,训练/验证/测试集包含节点ID。 我们建议用户将它们转换为布尔数组,其中True表示节点ID存在于集合中。 这样,我们可以将它们存储为节点数据。分区后,布尔数组将与图分区一起存储。

splitted_idx = data.get_idx_split()
train_nid, val_nid, test_nid = splitted_idx['train'], splitted_idx['valid'], splitted_idx['test']
train_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)
train_mask[train_nid] = True
val_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)
val_mask[val_nid] = True
test_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)
test_mask[test_nid] = True
graph.ndata['train_mask'] = train_mask
graph.ndata['val_mask'] = val_mask
graph.ndata['test_mask'] = test_mask

然后我们调用partition_graph函数来使用METIS对图进行分区,并将分区结果保存在指定的文件夹中。注意partition_graph在单台机器上以单线程运行。您可以访问我们的用户指南以查看更多关于分布式图分区的信息。

下面的代码展示了一个调用分区算法并生成四个分区的示例。 分区结果存储在一个名为4part_data的文件夹中。在对图进行分区时, 我们允许用户指定如何平衡分区。默认情况下,算法会尽可能平衡每个分区中的节点数量。 然而,这种平衡策略对于分布式GNN训练来说是不够的,因为某些分区可能比其他分区拥有更多的训练节点, 或者某些分区可能比其他分区拥有更多的边。因此,partition_graph提供了两个额外的参数 balance_ntypesbalance_edges来强制执行更多的平衡标准。例如,我们可以使用训练掩码 来平衡每个分区中的训练节点数量,如下面的示例所示。我们还可以打开 balance_edges标志,以确保所有分区拥有大致相同数量的边。

dgl.distributed.partition_graph(graph, graph_name='ogbn-products', num_parts=4,
                                out_path='4part_data',
                                balance_ntypes=graph.ndata['train_mask'],
                                balance_edges=True)

在对图进行分区时,DGL会打乱节点ID和边ID,以便分配给分区的节点/边具有连续的ID。这对于DGL维护全局节点/边ID和分区ID的映射是必要的。如果用户需要将打乱的节点/边ID映射回其原始ID,他们可以打开partition_graphreturn_mapping标志,该标志返回节点ID映射和边ID映射的向量。下面展示了使用ID映射在分布式训练后保存节点嵌入的示例。这是用户在希望在下游任务中使用训练后的节点嵌入时的常见用例。下面我们假设训练后的节点嵌入存储在node_emb张量中,该张量由打乱的节点ID索引。我们再次打乱嵌入并将它们存储在orig_node_emb张量中,该张量由原始节点ID索引。

nmap, emap = dgl.distributed.partition_graph(graph, graph_name='ogbn-products',
                                             num_parts=4,
                                             out_path='4part_data',
                                             balance_ntypes=graph.ndata['train_mask'],
                                             balance_edges=True,
                                             return_mapping=True)
orig_node_emb = th.zeros(node_emb.shape, dtype=node_emb.dtype)
orig_node_emb[nmap] = node_emb

Distributed training script

分布式训练脚本与多GPU训练脚本非常相似,只需进行少量修改。 它还依赖于Pytorch分布式组件来交换梯度并更新模型参数。 分布式训练脚本仅包含训练器的代码。

初始化网络通信

分布式GNN训练需要访问分区的图结构以及节点/边的特征,同时还需要从多个训练器中聚合模型参数的梯度。DGL的分布式组件负责访问分布式图结构以及分布式节点特征和边特征,而Pytorch分布式则负责交换模型参数的梯度。因此,我们需要在训练脚本开始时初始化DGL和Pytorch的分布式组件。

我们需要在分布式训练脚本的最开始调用DGL的initialize函数来初始化训练器的网络通信并连接到DGL的服务器。这个函数有一个参数,用于接受集群配置文件的路径。

import dgl
import torch as th
dgl.distributed.initialize(ip_config='ip_config.txt')

The configuration file ip_config.txt has the following format:

ip_addr1 [port1]
ip_addr2 [port2]

每一行代表一台机器。第一列是IP地址,第二列是用于连接到该机器上的DGL服务器的端口。端口是可选的,默认端口是30050。

在初始化DGL的网络通信后,用户可以初始化Pytorch的分布式通信。

th.distributed.init_process_group(backend='gloo')

Reference to the distributed graph

DGL的服务器会自动加载图分区。服务器加载分区后,训练器连接到服务器,并可以开始引用集群中的分布式图,如下所示。

g = dgl.distributed.DistGraph('ogbn-products')

As shown in the code, we refer to a distributed graph by its name. This name is basically the one passed to the partition_graph function as shown in the section above.

Get training and validation node IDs

对于分布式训练,每个训练器可以运行自己的一组训练节点。 整个图的训练节点存储在分布式张量中,作为train_mask节点数据, 这在我们分割图之前就已经构建好了。每个训练器可以调用node_split来获取其训练节点集。 node_split函数将完整的训练集均匀分割,并返回训练节点,其中大部分存储在本地分区中,以确保良好的数据局部性。

train_nid = dgl.distributed.node_split(g.ndata['train_mask'])

我们可以像上面那样分割验证节点。在这种情况下,每个训练器都会得到一组不同的验证节点。

valid_nid = dgl.distributed.node_split(g.ndata['val_mask'])

Define a GNN model

对于分布式训练,我们定义GNN模型的方式与 mini-batch训练全图训练完全相同。 下面的代码定义了GraphSage模型。

import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
import torch.optim as optim

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))

    def forward(self, blocks, x):
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            x = layer(block, x)
            if l != self.n_layers - 1:
                x = F.relu(x)
        return x

num_hidden = 256
num_labels = len(th.unique(g.ndata['labels'][0:g.num_nodes()]))
num_layers = 2
lr = 0.001
model = SAGE(g.ndata['feat'].shape[1], num_hidden, num_labels, num_layers)
loss_fcn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

For distributed training, we need to convert the model into a distributed model with Pytorch’s DistributedDataParallel.

model = th.nn.parallel.DistributedDataParallel(model)

Distributed mini-batch sampler

我们可以使用相同的DistNodeDataLoader,它是NodeDataLoader的分布式对应物,来创建一个用于节点分类的分布式小批量采样器。

sampler = dgl.dataloading.MultiLayerNeighborSampler([25,10])
train_dataloader = dgl.dataloading.DistNodeDataLoader(
                             g, train_nid, sampler, batch_size=1024,
                             shuffle=True, drop_last=False)
valid_dataloader = dgl.dataloading.DistNodeDataLoader(
                             g, valid_nid, sampler, batch_size=1024,
                             shuffle=False, drop_last=False)

Training loop

分布式训练的训练循环也与单进程训练完全相同。

import sklearn.metrics
import numpy as np

for epoch in range(10):
    # Loop over the dataloader to sample mini-batches.
    losses = []
    with model.join():
        for step, (input_nodes, seeds, blocks) in enumerate(train_dataloader):
            # Load the input features as well as output labels
            batch_inputs = g.ndata['feat'][input_nodes]
            batch_labels = g.ndata['labels'][seeds]

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            losses.append(loss.detach().cpu().numpy())
            optimizer.step()

    # validation
    predictions = []
    labels = []
    with th.no_grad(), model.join():
        for step, (input_nodes, seeds, blocks) in enumerate(valid_dataloader):
            inputs = g.ndata['feat'][input_nodes]
            labels.append(g.ndata['labels'][seeds].numpy())
            predictions.append(model(blocks, inputs).argmax(1).numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {}: Validation Accuracy {}'.format(epoch, accuracy))

设置分布式训练环境

在分区图并准备训练脚本之后,我们现在需要设置分布式训练环境并启动训练任务。基本上,我们需要创建一个机器集群,并将训练脚本和分区数据上传到集群中的每台机器上。推荐的在集群中共享训练脚本和分区数据的解决方案是使用NFS(网络文件系统)。

对于任何不熟悉NFS的用户,以下是在现有集群中设置NFS的小教程。

NFS 服务器端设置(仅限 Ubuntu)

首先,在存储服务器上安装必要的库

sudo apt-get install nfs-kernel-server

下面我们假设用户账户是ubuntu,并且我们在主目录中创建了一个工作区目录。

mkdir -p /home/ubuntu/workspace

我们假设所有服务器都在一个IP范围为192.168.0.0到192.168.255.255的子网下。 我们需要将以下行添加到/etc/exports

/home/ubuntu/workspace  192.168.0.0/16(rw,sync,no_subtree_check)

然后重新启动NFS,服务器端的设置就完成了。

sudo systemctl restart nfs-kernel-server

有关配置详情,请参考NFS ArchWiki (https://wiki.archlinux.org/index.php/NFS)。

NFS客户端设置(仅限Ubuntu)

要使用NFS,客户端还需要安装必要的软件包

sudo apt-get install nfs-common

您可以手动挂载NFS

mkdir -p /home/ubuntu/workspace
sudo mount -t nfs <nfs-server-ip>:/home/ubuntu/workspace /home/ubuntu/workspace

或者将以下行添加到/etc/fstab中,以便文件夹可以自动挂载

<nfs-server-ip>:/home/ubuntu/workspace   /home/ubuntu/workspace   nfs   defaults    0 0

然后运行

mount -a

现在转到/home/ubuntu/workspace并将训练脚本和分区数据保存在文件夹中。

SSH访问

启动脚本通过SSH访问集群中的机器。用户应按照此文档中的说明在集群中的每台机器上设置无密码SSH登录。设置无密码SSH后,用户需要验证与每台机器的连接,并将其密钥指纹添加到~/.ssh/known_hosts中。这可以在我们首次SSH到机器时自动完成。

启动分布式训练任务

一切准备就绪后,我们现在可以使用DGL提供的启动脚本来启动集群中的分布式训练任务。我们可以在集群中的任何机器上运行该启动脚本。

python3 ~/workspace/dgl/tools/launch.py   --workspace ~/workspace/   --num_trainers 1   --num_samplers 0   --num_servers 1   --part_config 4part_data/ogbn-products.json   --ip_config ip_config.txt   "python3 train_dist.py"

如果我们将图分成四个分区,如教程开头所示,集群必须有四台机器。上述命令在集群中的每台机器上启动一个训练器和一个服务器。ip_config.txt列出了集群中所有机器的IP地址,如下所示:

ip_addr1
ip_addr2
ip_addr3
ip_addr4

使用GraphBolt进行邻居采样

自 DGL 2.0 以来,我们引入了一个新的数据加载框架 GraphBolt, 其中采样相比 DGL 之前的实现有了显著改进。 因此,我们将 GraphBolt 引入到分布式训练中,以提高 分布式采样的性能。此外,图分区可以比以前小得多,这对 分布式训练期间的加载速度和内存使用都有好处。

图分区

为了从GraphBolt中受益以进行分布式采样,我们需要将分区从DGL格式转换为GraphBolt格式。这可以通过dgl.distributed.dgl_partition_to_graphbolt函数完成。或者,我们可以使用dgl.distributed.partition_graph函数直接生成GraphBolt格式的分区。

  1. 将分区从DGL格式转换为GraphBolt格式。

part_config = "4part_data/ogbn-products.json"
dgl.distributed.dgl_partition_to_graphbolt(part_config)

新分区将存储在与原始分区相同的目录中。

2. 直接在GraphBolt格式中生成分区。只需在partition_graph函数中将use_graphbolt标志设置为True

dgl.distributed.partition_graph(graph, graph_name='ogbn-products', num_parts=4,
                                out_path='4part_data',
                                balance_ntypes=graph.ndata['train_mask'],
                                balance_edges=True,
                                use_graphbolt=True)

在训练脚本中启用GraphBolt采样

只需在dgl.distributed.initialize函数中将use_graphbolt标志设置为True。这是在训练脚本中启用GraphBolt采样所需的唯一更改。

dgl.distributed.initialize('ip_config.txt', use_graphbolt=True)

Total running time of the script: (0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery