torch_geometric.transforms.AddMetaPaths

class AddMetaPaths(metapaths: List[List[Tuple[str, str, str]]], drop_orig_edge_types: bool = False, keep_same_node_type: bool = False, drop_unconnected_node_types: bool = False, max_sample: Optional[int] = None, weighted: bool = False, **kwargs: bool)

Bases: BaseTransform

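Adds additional edge types to a HeteroData object between the source node type and the destination node type of a given metapath, as described in the "Heterogeneous Graph Attention Networks" paper (functional name: add_metapaths).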

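Metapath-based neighbors can exploit different aspects of the structural information in a heterogeneous graph. Formally, a metapath is a path of the form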

\[\mathcal{V}_1 \xrightarrow{R_1} \mathcal{V}_2 \xrightarrow{R_2} \ldots \xrightarrow{R_{\ell-1}} \mathcal{V}_{\ell}\]

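where \(\mathcal{V}_i\) denotes a node type and \(R_j\) denotes the edge type connecting two node types. The added edge type is given by the sequential multiplication of the adjacency matrices along the metapath, and is added to the HeteroData object as the edge type (src_node_type, "metapath_*", dst_node_type), where src_node_type and dst_node_type denote \(\mathcal{V}_1\) and \(\mathcal{V}_{\ell}\), respectively.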

In addition, a metapath_dict object is added to the HeteroData object, which maps each metapath-based edge type to its original metapath.
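The adjacency-matrix product described above can be illustrated with a small dependency-free sketch. It is not the library's implementation: the helper names `matmul` and `metapath_edges` and the toy graph are hypothetical and chosen only for illustration. Each nonzero entry of the product corresponds to one metapath-based edge, and its value is the number of distinct paths between the two endpoints.

```python
from typing import Dict, List, Tuple

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply two dense integer matrices given as nested lists."""
    inner, cols = len(b), len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for row in a]


def metapath_edges(adjs: List[Matrix]) -> Dict[Tuple[int, int], int]:
    """Compose the adjacency matrices along a metapath.

    Returns a mapping (src, dst) -> number of distinct paths from src to dst.
    """
    product = adjs[0]
    for adj in adjs[1:]:
        product = matmul(product, adj)
    return {(i, j): count
            for i, row in enumerate(product)
            for j, count in enumerate(row) if count > 0}


# Toy graph (hypothetical data): 2 authors, 3 papers, 2 conferences.
author_paper = [[1, 1, 0],   # author 0 wrote papers 0 and 1
                [0, 1, 1]]   # author 1 wrote papers 1 and 2
paper_conference = [[1, 0],  # paper 0 appeared at conference 0
                    [1, 0],  # paper 1 appeared at conference 0
                    [0, 1]]  # paper 2 appeared at conference 1

edges = metapath_edges([author_paper, paper_conference])
print(edges)  # {(0, 0): 2, (1, 0): 1, (1, 1): 1}
```

Author 0 reaches conference 0 through two different papers, so that entry is 2. These path counts are what the `weighted` option described under Parameters stores as edge weights.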

from torch_geometric.datasets import DBLP
from torch_geometric.data import HeteroData
from torch_geometric.transforms import AddMetaPaths

root = "data/DBLP"  # any local directory; the dataset is downloaded here on first use
data = DBLP(root)[0]
# 4 node types: "paper", "author", "conference", and "term"
# 6 edge types: ("paper", "author"), ("author", "paper"),
#               ("paper", "term"), ("paper", "conference"),
#               ("term", "paper"), ("conference", "paper")

# Add two metapaths:
# 1. From "paper" to "paper" through "conference"
# 2. From "author" to "conference" through "paper"
metapaths = [[("paper", "conference"), ("conference", "paper")],
             [("author", "paper"), ("paper", "conference")]]
data = AddMetaPaths(metapaths)(data)

print(data.edge_types)
# [("author", "to", "paper"), ("paper", "to", "author"),
#  ("paper", "to", "term"), ("paper", "to", "conference"),
#  ("term", "to", "paper"), ("conference", "to", "paper"),
#  ("paper", "metapath_0", "paper"),
#  ("author", "metapath_1", "conference")]

print(data.metapath_dict)
# {("paper", "metapath_0", "paper"): [("paper", "conference"),
#                                     ("conference", "paper")],
#  ("author", "metapath_1", "conference"): [("author", "paper"),
#                                           ("paper", "conference")]}

Parameters:
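  • metapaths (List[List[Tuple[str, str, str]]]) – The metapaths, described by a list of lists of (src_node_type, rel_type, dst_node_type) tuples.

  • drop_orig_edge_types (bool, optional) – If set to True, existing edge types will be dropped. (default: False)

  • keep_same_node_type (bool, optional) – If set to True, existing edge types between the same node type are not dropped, even if drop_orig_edge_types is set to True. (default: False)

  • drop_unconnected_node_types (bool, optional) – If set to True, node types that are not connected by any edge type will be dropped. (default: False)

  • max_sample (int, optional) – If set, at most max_sample neighbors will be sampled within metapaths. This is useful for handling very dense metapath edges. (default: None)

  • weighted (bool, optional) – If set to True, computes a weight for each metapath edge and stores it in edge_weight. The weight of a metapath edge is the number of metapaths from the start to the end of that edge. (default: False)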