HeteroSubgraphX

class dgl.nn.pytorch.explain.HeteroSubgraphX(model, num_hops, coef=10.0, high2low=True, num_child=12, num_rollouts=20, node_min=3, shapley_steps=100, log=False)

Bases: Module

SubgraphX from "On Explainability of Graph Neural Networks via Subgraph Explorations", adapted for heterogeneous graphs

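It identifies the most important subgraph of the original graph, that is, the subgraph that plays a critical role in GNN-based graph classification.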

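It employs Monte Carlo tree search (MCTS) to efficiently explore different subgraphs for explanation, and uses Shapley values as the measure of subgraph importance.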
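The Shapley-value step can be illustrated with a minimal, self-contained sketch. It follows our reading of the SubgraphX paper: the candidate subgraph is treated as a single player, and the nodes around it (for example, those within num_hops hops) are the other players. Random orderings of the players are sampled, and the subgraph's marginal contribution is averaged over shapley_steps samples. The names mc_shapley and value_fn are hypothetical. value_fn stands in for "the model's score for the target class when only the given nodes are kept". This is not the code DGL runs internally.

```python
import random
from typing import Callable, FrozenSet, Hashable, Iterable


def mc_shapley(
    subgraph: Iterable[Hashable],
    neighbors: Iterable[Hashable],
    value_fn: Callable[[FrozenSet[Hashable]], float],
    steps: int = 100,
    seed: int = 0,
) -> float:
    """Monte Carlo estimate of the Shapley value of ``subgraph`` (sketch).

    The whole subgraph acts as one player; each node in ``neighbors`` that is
    not part of the subgraph is an individual player.
    """
    rng = random.Random(seed)
    sub = frozenset(subgraph)
    players = [n for n in neighbors if n not in sub]
    token = object()  # placeholder that stands for the subgraph player
    total = 0.0
    for _ in range(steps):
        order = players + [token]
        rng.shuffle(order)
        # Coalition = players that come before the subgraph in this ordering.
        coalition = frozenset(order[: order.index(token)])
        # Marginal contribution of adding the subgraph to that coalition.
        total += value_fn(coalition | sub) - value_fn(coalition)
    return total / steps


if __name__ == "__main__":
    # Hypothetical additive "model score": each node contributes a fixed weight.
    weights = {0: 0.5, 1: 0.25, 2: 1.0, 3: -0.5}
    score = lambda nodes: sum(weights[n] for n in nodes)
    print(mc_shapley({0, 1}, [2, 3], score))  # 0.75
```

The toy score is additive, so every sampled marginal contribution is identical and the estimate is exact. With a real GNN score the estimate is noisy. A larger shapley_steps reduces the variance but costs more model evaluations.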

Parameters:
  • model (nn.Module) –

    The GNN model to explain, which tackles multiclass graph classification

    • Its forward function must have the form forward(self, graph, nfeat).

    • The output of its forward function is the logits.

  • num_hops (int) – Number of message passing layers in the model

  • coef (float, optional) – This hyperparameter controls the trade-off between exploration and exploitation. A higher value encourages the algorithm to explore relatively unvisited nodes. Default: 10.0

  • high2low (bool, optional) – If True, it uses the “High2low” strategy for pruning actions, expanding child nodes from high degree to low degree when extending the child nodes in the search tree. Otherwise, it uses the “Low2high” strategy. Default: True

  • num_child (int, optional) – This is the number of child nodes to expand when extending the child nodes in the search tree. Default: 12

  • num_rollouts (int, optional) – This is the number of rollouts for MCTS. Default: 20

  • node_min (int, optional) – This is the threshold to define a leaf node based on the number of nodes in a subgraph. Default: 3

  • shapley_steps (int, optional) – This is the number of steps for Monte Carlo sampling in estimating Shapley values. Default: 100

  • log (bool, optional) – If True, it will log the progress. Default: False
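The following hedged sketch of one search step makes coef, high2low and num_child more concrete. The selection score is modeled on the rule described in the SubgraphX paper: the average rollout reward, plus an exploration bonus that scales with coef and with the child's immediate (Shapley) reward, and shrinks as the child's visit count grows. TreeNode, select_child and pruning_candidates are hypothetical names. DGL's internal implementation may differ in details such as tie-breaking or how candidates are generated.

```python
import math
from dataclasses import dataclass, field
from typing import Dict, Iterable, List


@dataclass
class TreeNode:
    """Hypothetical MCTS node: one candidate subgraph (a set of node IDs)."""
    nodes: frozenset
    immediate_reward: float = 0.0  # e.g. the Shapley value of this subgraph
    total_reward: float = 0.0      # accumulated reward from rollouts
    num_visits: int = 0            # how often this node was visited
    children: List["TreeNode"] = field(default_factory=list)


def select_child(parent: TreeNode, coef: float = 10.0) -> TreeNode:
    """Pick the child that maximizes exploitation + coef-weighted exploration."""
    sum_visits = sum(c.num_visits for c in parent.children)

    def score(c: TreeNode) -> float:
        exploit = c.total_reward / c.num_visits if c.num_visits else 0.0
        explore = coef * c.immediate_reward * math.sqrt(sum_visits) / (1 + c.num_visits)
        return exploit + explore

    return max(parent.children, key=score)


def pruning_candidates(
    nodes: Iterable[int],
    degree: Dict[int, int],
    high2low: bool = True,
    num_child: int = 12,
) -> List[int]:
    """Order nodes to prune by degree (descending if high2low), keep num_child."""
    return sorted(nodes, key=lambda n: degree[n], reverse=high2low)[:num_child]


if __name__ == "__main__":
    a = TreeNode(frozenset({0, 1}), immediate_reward=0.2, total_reward=2.0, num_visits=4)
    b = TreeNode(frozenset({0, 2}), immediate_reward=0.2)  # never visited yet
    root = TreeNode(frozenset({0, 1, 2}), children=[a, b])
    print(select_child(root, coef=0.0) is a)   # True: pure exploitation
    print(select_child(root, coef=10.0) is b)  # True: exploration wins
```

With coef=0.0 the search only exploits the child with the best average reward. A large coef pushes it toward the unvisited child, which matches the trade-off described for coef.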

explain_graph(graph, feat, target_class, **kwargs)

Find the most important subgraph of the original graph for the model to classify the graph into the target class.

Parameters:
  • graph (DGLGraph) – A heterogeneous graph

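  • feat (dict[str, Tensor]) – A dict mapping each node type (key) present in the graph to its input node features (value). The input features of node type \(t\) have shape \((N_t, D_t)\), where \(N_t\) is the number of nodes of type \(t\) and \(D_t\) is the feature size of type \(t\).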

  • target_class (int) – The target class to explain

  • kwargs (dict) – Additional arguments passed to the GNN model

Returns:

A dict mapping node types (keys) to tensors of node IDs (values), which together represent the most important subgraph

Return type:

dict[str, Tensor]
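The returned dict only lists the selected node IDs for each node type. The hypothetical pure-Python helper below illustrates what that dict denotes: it keeps the edges whose two endpoints are both selected, which gives the node-induced subgraph. With a DGL graph you would normally call dgl.node_subgraph(g, explanation), which accepts such a per-type dict. The plain edge-list representation here mirrors the dgl.heterograph input format and is an assumption made for the sketch. To use the helper on the real return value, convert the tensors with .tolist() first.

```python
from typing import Dict, List, Sequence, Tuple

CanonicalEType = Tuple[str, str, str]  # (src_type, edge_type, dst_type)


def induced_edges(
    edges: Dict[CanonicalEType, Tuple[Sequence[int], Sequence[int]]],
    keep: Dict[str, Sequence[int]],
) -> Dict[CanonicalEType, List[Tuple[int, int]]]:
    """Keep only edges whose source and destination nodes are both in ``keep``."""
    kept = {ntype: set(ids) for ntype, ids in keep.items()}
    out = {}
    for (src_t, etype, dst_t), (src, dst) in edges.items():
        src_ok = kept.get(src_t, set())  # node types absent from keep: nothing kept
        dst_ok = kept.get(dst_t, set())
        out[(src_t, etype, dst_t)] = [
            (u, v) for u, v in zip(src, dst) if u in src_ok and v in dst_ok
        ]
    return out


if __name__ == "__main__":
    # The example graph below, plus its reverse edges ("rev_plays" chosen here
    # for illustration).
    edges = {
        ("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1]),
        ("game", "rev_plays", "user"): ([0, 0, 1, 1], [0, 1, 1, 2]),
    }
    explanation = {"game": [0, 1], "user": [1, 2]}
    print(induced_edges(edges, explanation)[("user", "plays", "game")])
    # [(1, 0), (1, 1), (2, 1)]
```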

Examples

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroSubgraphX
>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict(
...             {
...                 "_".join(c_etype): nn.Linear(in_dim, num_classes)
...                 for c_etype in canonical_etypes
...             }
...         )
...
...     def forward(self, graph, feat):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights["_".join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f"h_{c_etype}"] = wh
...                 c_etype_func_dict[c_etype] = (
...                     fn.copy_u(f"h_{c_etype}", "m"),
...                     fn.mean("m", "h"),
...                 )
...             graph.multi_update_all(c_etype_func_dict, "sum")
...             hg = 0
...             for ntype in graph.ntypes:
...                 if graph.num_nodes(ntype):
...                     hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)
...             return hg
>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata["h"]
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)
...     loss = F.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain for the graph
>>> explainer = HeteroSubgraphX(model, num_hops=1)
>>> explainer.explain_graph(g, feat, target_class=1)
{'game': tensor([0, 1]), 'user': tensor([1, 2])}

forward(*input: Any) → None

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
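A toy imitation shows the difference between calling the instance and calling forward directly. ToyModule below is not PyTorch's implementation. It only mirrors the idea that nn.Module.__call__ runs forward and then the registered forward hooks, while calling forward directly bypasses them. The hook arguments follow those of PyTorch's register_forward_hook: module, inputs, and output, with an optional replacement output as the return value.

```python
from typing import Any, Callable, List

Hook = Callable[["ToyModule", tuple, Any], Any]


class ToyModule:
    """Toy imitation of the __call__/forward split; not PyTorch's code."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError  # subclasses define the computation

    def __call__(self, *args: Any) -> Any:
        out = self.forward(*args)
        for hook in self._forward_hooks:
            result = hook(self, args, out)
            if result is not None:  # a hook may replace the output
                out = result
        return out


class Double(ToyModule):
    def forward(self, x):
        return 2 * x


if __name__ == "__main__":
    m = Double()
    m.register_forward_hook(lambda mod, inputs, out: out + 1)
    print(m(3))          # 7: forward ran, then the hook added 1
    print(m.forward(3))  # 6: the hook was silently skipped
```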