HeteroPGExplainer

class dgl.nn.pytorch.explain.HeteroPGExplainer(model, num_features, num_hops=None, explain_graph=True, coff_budget=0.01, coff_connect=0.0005, sample_bias=0.0)[source]

Bases: PGExplainer

PGExplainer from Parameterized Explainer for Graph Neural Network, adapted for heterogeneous graphs.

PGExplainer adopts a deep neural network (the explanation network) to parameterize the generation process of explanations, which enables it to explain multiple instances collectively. PGExplainer models the underlying structure as an edge distribution, from which the explanatory graph is sampled.
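Conceptually, the explanation network produces a latent score per edge, and a relaxed Bernoulli ("concrete") sample turns each score into a soft mask value. Below is a minimal sketch of that sampling step, with illustrative names rather than the library's internal API:

>>> import torch as th
>>> def concrete_sample(log_alpha, temperature=1.0, bias=0.0, training=True):
...     # During training, add (optionally biased) logistic noise so the mask
...     # stays stochastic yet differentiable; at inference, take the mean.
...     if training:
...         noise = bias + (1 - 2 * bias) * th.rand_like(log_alpha)
...         gate = th.log(noise) - th.log(1 - noise)
...         return th.sigmoid((gate + log_alpha) / temperature)
...     return th.sigmoid(log_alpha)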

Parameters:
  • model (nn.Module) –

    The GNN model to explain that tackles multiclass graph classification (or multiclass node classification when explain_graph=False)

    • Its forward function must have the form forward(self, graph, nfeat, embed, edge_weight).

    • The output of its forward function is the logits if embed=False else the intermediate node embeddings.

  • num_features (int) – Node embedding size used by model.

  • num_hops (int, optional) – The number of message-passing hops in model, used to extract the local subgraph around target nodes for node-level explanations. Default: None.

  • explain_graph (bool, optional) – Whether the explainer targets graph-level (True) or node-level (False) predictions. Default: True.

  • coff_budget (float, optional) – Size regularization coefficient to constrain the explanation size; see the loss sketch below. Default: 0.01.

  • coff_connect (float, optional) – Entropy regularization coefficient to constrain the connectivity of the explanation; see the loss sketch below. Default: 5e-4.

  • sample_bias (float, optional) – Bias applied when sampling the edge mask, making some edges systematically more likely to be selected than others. Default: 0.0.
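The two regularization coefficients enter the explainer's training objective roughly as follows; this is a sketch of the loss structure, not the exact implementation:

>>> import torch as th
>>> def explainer_loss(pred_loss, edge_mask, coff_budget=0.01, coff_connect=5e-4):
...     # Size term: penalize large masks so explanations stay sparse.
...     size_loss = coff_budget * th.sum(edge_mask)
...     # Entropy term: push mask values toward 0 or 1 for crisp explanations.
...     eps = 1e-6
...     ent = -edge_mask * th.log(edge_mask + eps) - (
...         1 - edge_mask
...     ) * th.log(1 - edge_mask + eps)
...     return pred_loss + size_loss + coff_connect * th.mean(ent)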

explain_graph(graph, feat, temperature=1.0, training=False, **kwargs)[source]

Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for the graph. Also, return the prediction made with the edges chosen based on the edge mask.

Parameters:
  • graph (DGLGraph) – A heterogeneous graph.

  • feat (dict[str, Tensor]) – A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes of node type \(t\) and \(D_t\) is the feature size of node type \(t\).

  • temperature (float) – The temperature parameter fed to the sampling procedure.

  • training (bool) – Training the explanation network.

  • kwargs (dict) – Additional arguments passed to the GNN model.

Returns:

  • Tensor – Classification probabilities given the masked graph. It is a tensor of shape \((B, L)\), where \(L\) is the number of different label types in the dataset and \(B\) is the batch size.

  • dict[str, Tensor] – A dict mapping edge types (keys) to edge tensors (values) of shape \((E_t)\), where \(E_t\) is the number of edges of type \(t\) in the graph. A higher weight indicates a larger contribution of the edge.

Example

>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
...         super().__init__()
...         self.conv = dgl.nn.HeteroGraphConv(
...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
...             aggregate="sum",
...         )
...         self.fc = nn.Linear(hid_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         if edge_weight:
...             mod_kwargs = {
...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
...             }
...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
...         else:
...             h = self.conv(g, h)
...
...         if embed:
...             return h
...
...         with g.local_scope():
...             g.ndata["h"] = h
...             hg = 0
...             for ntype in g.ntypes:
...                 hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
...             return self.fc(hg)
>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, g.ndata["h"])
...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(model, hidden_dim)
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     loss = explainer.train_step(g, g.ndata["h"], tmp)
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()
>>> # Explain the graph
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask = explainer.explain_graph(g, feat)
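The returned mask can then be ranked per edge type, for instance (illustrative post-processing, assuming the mask dict is keyed by canonical edge types):

>>> # Pick the most influential "plays" edges according to the mask
>>> mask = edge_mask[("user", "plays", "game")]
>>> print(th.topk(mask, k=2).indices)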
explain_node(nodes, graph, feat, temperature=1.0, training=False, **kwargs)[source]

Learn and return an edge mask that plays a crucial role in explaining the predictions made by the GNN for the provided set of node IDs. Also, return the predictions made with the batched graph and edge mask.

Parameters:
  • nodes (dict[str, Iterable[int]]) – A dict mapping node types (keys) to iterable collections of node IDs (values).

  • graph (DGLGraph) – A heterogeneous graph.

  • feat (dict[str, Tensor]) – A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes of node type \(t\) and \(D_t\) is the feature size of node type \(t\).

  • temperature (float) – The temperature parameter fed to the sampling procedure.

  • training (bool) – Training the explanation network.

  • kwargs (dict) – Additional arguments passed to the GNN model.

Returns:

  • dict[str, Tensor] – A dict mapping node types (keys) to classification probabilities for node labels (values). The values are tensors of shape \((N_t, L)\), where \(L\) is the number of different node label types in the dataset and \(N_t\) is the number of nodes of type \(t\) in the graph.

  • dict[str, Tensor] – A dict mapping edge types (keys) to edge tensors (values) of shape \((E_t)\), where \(E_t\) is the number of edges of type \(t\) in the graph. A higher weight indicates a larger contribution of the edge.

  • DGLGraph – The batched set of subgraphs induced on the k-hop in-neighborhoods of the input center nodes.

  • dict[str, Tensor] – A dict mapping node types (keys) to tensors of node IDs (values) corresponding to the center nodes of the subgraphs.

Example

>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
...         super().__init__()
...         self.conv = dgl.nn.HeteroGraphConv(
...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
...             aggregate="sum",
...         )
...         self.fc = nn.Linear(hid_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         if edge_weight:
...             mod_kwargs = {
...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
...             }
...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
...         else:
...             h = self.conv(g, h)
...
...         if embed:
...             return h
...
...         # Apply the classification head to get per-node logits
...         return {ntype: self.fc(emb) for ntype, emb in h.items()}
>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, g.ndata["h"])['user']
...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1,1,1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(
...     model, hidden_dim, num_hops=2, explain_graph=False
... )
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     loss = explainer.train_step_node(
...         { ntype: g.nodes(ntype) for ntype in g.ntypes },
...         g, g.ndata["h"], tmp
...     )
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()
>>> # Explain the graph
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask, bg, inverse_indices = explainer.explain_node(
...     { "user": [0] }, g, feat
... )
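The returned indices locate the center nodes inside the batched subgraph, so their predicted probabilities can be read off directly (a sketch, assuming inverse_indices indexes into the batched predictions):

>>> print(probs["user"][inverse_indices["user"]])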
forward(*input: Any) → None

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

train_step(graph, feat, temperature, **kwargs)[source]

Compute the loss of the explanation network for graph classification.

Parameters:
  • graph (DGLGraph) – The input batched heterogeneous graph.

  • feat (dict[str, Tensor]) – A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes of node type \(t\) and \(D_t\) is the feature size of node type \(t\).

  • temperature (float) – The temperature parameter fed to the sampling procedure.

  • kwargs (dict) – Additional arguments passed to the GNN model.

Returns:

A scalar tensor representing the loss.

Return type:

Tensor
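A minimal usage sketch, reusing the explainer, graph, and optimizer from the explain_graph() example above (assuming node features are still stored under g.ndata["h"]):

>>> loss = explainer.train_step(g, g.ndata["h"], temperature=1.0)
>>> optimizer_exp.zero_grad()
>>> loss.backward()
>>> optimizer_exp.step()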

train_step_node(nodes, graph, feat, temperature, **kwargs)[source]

Compute the loss of the explanation network for node classification.

Parameters:
  • nodes (dict[str, Iterable[int]]) – A dict mapping node types (keys) to iterable collections of node IDs (values).

  • graph (DGLGraph) – The input heterogeneous graph.

  • feat (dict[str, Tensor]) – A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes of node type \(t\) and \(D_t\) is the feature size of node type \(t\).

  • temperature (float) – The temperature parameter fed to the sampling procedure.

  • kwargs (dict) – Additional arguments passed to the GNN model.

Returns:

A scalar tensor representing the loss.

Return type:

Tensor
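Analogous usage for node classification, additionally passing the target nodes per type (a sketch following the explain_node() example above):

>>> loss = explainer.train_step_node({"user": [0]}, g, g.ndata["h"], 1.0)
>>> optimizer_exp.zero_grad()
>>> loss.backward()
>>> optimizer_exp.step()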