节点嵌入
- class dgl.nn.pytorch.sparse_emb.NodeEmbedding(num_embeddings, embedding_dim, name, init_func=None, device=None, partition=None)[source]
Bases:
object
用于存储节点嵌入的类。
该类针对训练大规模节点嵌入进行了优化。它以稀疏方式更新嵌入,并且可以扩展到具有数百万个节点的图。它还支持将分区到多个GPU(在单台机器上)以进一步加速。它不支持跨机器的分区。
目前,DGL 提供了两种与此 NodeEmbedding 类一起使用的优化器:
SparseAdagrad
和SparseAdam
。该实现基于torch.distributed包。它依赖于pytorch默认的分布式进程组来收集多进程信息,并使用
torch.distributed.TCPStore
在多个GPU进程之间共享元数据信息。它使用本地地址'127.0.0.1:12346'来初始化TCPStore。注意:NodeEmbedding的支持是实验性的。
- Parameters:
示例
在启动多个GPU进程之前
>>> def initializer(emb): th.nn.init.xavier_uniform_(emb) return emb
在每个训练过程中
>>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer) >>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001) >>> for blocks in dataloader: ... ... ... feats = emb(nids, gpu_0) ... loss = F.sum(feats + 1, 0) ... loss.backward() ... optimizer.step()