异质嵌入

class dgl.nn.pytorch.HeteroEmbedding(num_embeddings, embedding_dim)[source]

Bases: Module

创建一个异构嵌入表。

它内部包含多个具有不同字典大小的torch.nn.Embedding

Parameters:
  • num_embeddings (dict[key, int]) – 字典的大小。键可以是字符串或字符串的元组。

  • embedding_dim (int) – 每个嵌入向量的大小。

示例

>>> import dgl
>>> import torch
>>> from dgl.nn import HeteroEmbedding
>>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
>>> # Get the heterogeneous embedding table
>>> embeds = layer.weight
>>> print(embeds['user'].shape)
torch.Size([2, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([3, 4])
>>> # Get the embeddings for a subset
>>> input_ids = {'user': torch.LongTensor([0]),
...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}
>>> embeds = layer(input_ids)
>>> print(embeds['user'].shape)
torch.Size([1, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([2, 4])
forward(input_ids)[source]

前向函数

Parameters:

input_ids (dict[key, Tensor]) – 用于检索嵌入的行ID。它将一个键映射到特定键的ID。

Returns:

检索到的嵌入。

Return type:

dict[key, Tensor]

reset_parameters()[source]

使用nn.init模块中的xavier方法使参数均匀分布