多层全邻居采样器
- class dgl.dataloading.MultiLayerFullNeighborSampler(num_layers, **kwargs)[source]
基础类:
NeighborSampler
采样器通过从多层GNN的所有邻居获取消息来构建节点表示的计算依赖关系。
此采样器将使每个节点从每种边类型的每个邻居处收集消息。
- Parameters:
num_layers (int) – 要采样的GNN层数。
kwargs – 传递给
dgl.dataloading.NeighborSampler
。
示例
为了在一组节点
train_nid
上训练一个3层GNN进行节点分类,在一个同质图上,每个节点分别从第一层、第二层和第三层的所有邻居接收消息(假设后端是PyTorch):>>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3) >>> dataloader = dgl.dataloading.DataLoader( ... g, train_nid, sampler, ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) >>> for input_nodes, output_nodes, blocks in dataloader: ... train_on(blocks)
注释
For the concept of MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.