DeepWalk

class dgl.nn.pytorch.DeepWalk(g, emb_dim=128, walk_length=40, window_size=5, neg_weight=1, negative_size=5, fast_neg=True, sparse=True)[source]

Bases: Module

The DeepWalk module from DeepWalk: Online Learning of Social Representations

For a graph, it learns node representations from scratch by maximizing the similarity of node pairs that are nearby (positive node pairs) and minimizing the similarity of other random node pairs (negative node pairs).

Parameters:
  • g (DGLGraph) – Graph for learning node embeddings

  • emb_dim (int, optional) – Size of each embedding vector. Default: 128

  • walk_length (int, optional) – Number of nodes in a random walk sequence. Default: 40

  • window_size (int, optional) – In a random walk w, a node w[j] is considered close to a node w[i] if i - window_size <= j <= i + window_size. Default: 5

  • neg_weight (float, optional) – Weight of the loss term for negative samples in the total loss. Default: 1.0

  • negative_size (int, optional) – Number of negative samples to use for each positive sample. Default: 5

  • fast_neg (bool, optional) – If True, sample negative node pairs within a batch of random walks. Default: True

  • sparse (bool, optional) – If True, gradients with respect to the learnable weights will be sparse. Default: True
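To illustrate how walk_length and window_size interact, the following is a minimal sketch (plain Python, using a hypothetical helper that is not part of the DGL API) of how positive pairs are derived from a single walk: every pair (w[i], w[j]) with 0 < |i - j| <= window_size counts as a positive pair.

```python
def context_pairs(walk, window_size):
    """Yield (center, context) positive pairs from one random walk,
    pairing each node with all neighbors within window_size steps."""
    pairs = []
    for i, center in enumerate(walk):
        lo = max(0, i - window_size)
        hi = min(len(walk), i + window_size + 1)
        for j in range(lo, hi):
            if j != i:
                pairs.append((center, walk[j]))
    return pairs

# A toy walk of 4 node IDs with window_size=1: only adjacent nodes pair up.
print(context_pairs([0, 5, 2, 7], window_size=1))
# [(0, 5), (5, 0), (5, 2), (2, 5), (2, 7), (7, 2)]
```

A larger window_size trades off context breadth against noise: distant nodes in the walk are weaker signals of proximity in the graph.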

node_embed

Embedding table of the nodes

Type:

nn.Embedding

Examples

>>> import torch
>>> from dgl.data import CoraGraphDataset
>>> from dgl.nn import DeepWalk
>>> from torch.optim import SparseAdam
>>> from torch.utils.data import DataLoader
>>> from sklearn.linear_model import LogisticRegression
>>> dataset = CoraGraphDataset()
>>> g = dataset[0]
>>> model = DeepWalk(g)
>>> dataloader = DataLoader(torch.arange(g.num_nodes()), batch_size=128,
...                         shuffle=True, collate_fn=model.sample)
>>> optimizer = SparseAdam(model.parameters(), lr=0.01)
>>> num_epochs = 5
>>> for epoch in range(num_epochs):
...     for batch_walk in dataloader:
...         loss = model(batch_walk)
...         optimizer.zero_grad()
...         loss.backward()
...         optimizer.step()
>>> train_mask = g.ndata['train_mask']
>>> test_mask = g.ndata['test_mask']
>>> X = model.node_embed.weight.detach()
>>> y = g.ndata['label']
>>> clf = LogisticRegression().fit(X[train_mask].numpy(), y[train_mask].numpy())
>>> clf.score(X[test_mask].numpy(), y[test_mask].numpy())
forward(batch_walk)[source]

Compute the loss for a batch of random walks

Parameters:

batch_walk (torch.Tensor) – Random walks in the form of node ID sequences. The tensor is of shape (batch_size, walk_length)

Returns:

Loss value

Return type:

torch.Tensor
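Conceptually, the loss follows the skip-gram objective with negative sampling: for a positive pair it maximizes log σ(e_u · e_v), and for each negative pair it maximizes log σ(-e_u · e_n), with the negative terms scaled by neg_weight. The sketch below (plain Python, not DGL's actual implementation) shows this objective for one positive pair:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def sgns_loss(emb, center, context, negatives, neg_weight=1.0):
    """Skip-gram loss with negative sampling for one positive pair.

    Pushes the (center, context) embeddings together and the
    (center, negative) embeddings apart, weighting the negative
    terms by neg_weight.
    """
    pos = -math.log(sigmoid(dot(emb[center], emb[context])))
    neg = -sum(math.log(sigmoid(-dot(emb[center], emb[n]))) for n in negatives)
    return pos + neg_weight * neg

# Toy 2-d embeddings: node 1 is similar to node 0, node 2 is dissimilar.
emb = {0: [1.0, 0.0], 1: [0.9, 0.1], 2: [-1.0, 0.2]}
loss = sgns_loss(emb, center=0, context=1, negatives=[2], neg_weight=1.0)
print(round(loss, 4))  # 0.6544
```

In the actual module, this loss is computed in batch over all positive pairs within window_size in batch_walk, with negatives drawn per positive pair.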

reset_parameters()[source]

Reinitialize the learnable parameters