SetTransformerDecoder

class dgl.nn.pytorch.glob.SetTransformerDecoder(d_model, num_heads, d_head, d_ff, n_layers, k, dropouth=0.0, dropouta=0.0)[source]

Bases: Module

The Decoder module from Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks

Parameters:
  • d_model (int) – The hidden size of the model.

  • num_heads (int) – The number of heads.

  • d_head (int) – The hidden size of each head.

  • d_ff (int) – The kernel size in the FFN (Positionwise Feed-Forward Network) layer.

  • n_layers (int) – The number of layers.

  • k (int) – The number of seed vectors in the PMA (Pooling by Multihead Attention) layer.

  • dropouth (float) – The dropout rate of each sublayer.

  • dropouta (float) – The dropout rate of the attention heads.

Examples

>>> import dgl
>>> import torch as th
>>> from dgl.nn import SetTransformerDecoder
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> set_trans_dec = SetTransformerDecoder(5, 4, 4, 20, 1, 3)  # define the layer

Case 1: Input a single graph

>>> set_trans_dec(g1, g1_node_feats)
tensor([[-0.5538,  1.8726, -1.0470,  0.0276, -0.2994, -0.6317,  1.6754, -1.3189,
          0.2291,  0.0461, -0.4042,  0.8387, -1.7091,  1.0845,  0.1902]],
       grad_fn=<ViewBackward>)

Case 2: Input a batch of graphs

Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.

>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats])
>>>
>>> set_trans_dec(batch_g, batch_f)
tensor([[-0.5538,  1.8726, -1.0470,  0.0276, -0.2994, -0.6317,  1.6754, -1.3189,
          0.2291,  0.0461, -0.4042,  0.8387, -1.7091,  1.0845,  0.1902],
        [-0.5511,  1.8869, -1.0156,  0.0028, -0.3231, -0.6305,  1.6845, -1.3105,
          0.2136,  0.0428, -0.3820,  0.8043, -1.7138,  1.1126,  0.1789]],
       grad_fn=<ViewBackward>)

See also

SetTransformerEncoder
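Note that the rows in the example outputs above are 15-dimensional rather than d_model = 5-dimensional: the PMA layer pools each graph's node set into k seed vectors of size d_model, which are then flattened into one row per graph (3 * 5 = 15 here). A minimal sanity check of that shape arithmetic, assuming only this flattening behavior:

```python
# Per-graph output width of the decoder in the example above:
# k seed vectors of size d_model each, flattened into a single row.
d_model, k = 5, 3          # values passed to SetTransformerDecoder above
out_width = k * d_model    # width of each row in the example outputs

print(out_width)           # 15, matching the 15-dim rows shown above
```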

forward(graph, feat)[source]

Compute the decoder part of Set Transformer.

Parameters:
  • graph (DGLGraph) – The input graph.

  • feat (torch.Tensor) – The input feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph and \(D\) is the size of the features.

Returns:

The output feature with shape \((B, D)\), where \(B\) refers to the batch size.

Return type:

torch.Tensor