多层全邻居采样器

class dgl.dataloading.MultiLayerFullNeighborSampler(num_layers, **kwargs)[source]

基础类:NeighborSampler

采样器通过从多层GNN的所有邻居获取消息来构建节点表示的计算依赖关系。

此采样器将使每个节点从每种边类型的每个邻居处收集消息。

Parameters:

示例

为了在一组节点train_nid上训练一个3层GNN进行节点分类,在一个同质图上,每个节点分别从第一层、第二层和第三层的所有邻居接收消息(假设后端是PyTorch):

>>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(blocks)

注释

For the concept of MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.