torch_geometric.transforms.AddMetaPaths
- class AddMetaPaths(metapaths: List[List[Tuple[str, str, str]]], drop_orig_edge_types: bool = False, keep_same_node_type: bool = False, drop_unconnected_node_types: bool = False, max_sample: Optional[int] = None, weighted: bool = False, **kwargs: bool)
Bases: BaseTransform

Adds additional edge types to a HeteroData object between the source node type and the destination node type of a given metapath, as described in the "Heterogeneous Graph Attention Network" paper (functional name: add_metapaths).

Metapath-based neighbors can exploit different aspects of the structural information in heterogeneous graphs. Formally, a metapath is a path of the form

\[\mathcal{V}_1 \xrightarrow{R_1} \mathcal{V}_2 \xrightarrow{R_2} \ldots \xrightarrow{R_{\ell-1}} \mathcal{V}_{\ell}\]

where \(\mathcal{V}_i\) denotes a node type and \(R_j\) denotes the edge type connecting two node types.

The added edge types are given by the sequential multiplication of the adjacency matrices along the metapath. They are added to the HeteroData object as edge types of the form (src_node_type, "metapath_*", dst_node_type), where src_node_type and dst_node_type denote \(\mathcal{V}_1\) and \(\mathcal{V}_{\ell}\), respectively.

In addition, a metapath_dict object is added to the HeteroData object, which maps each metapath-based edge type to its original metapath.

```python
from torch_geometric.datasets import DBLP
from torch_geometric.transforms import AddMetaPaths

root = "./data/DBLP"  # placeholder download directory
data = DBLP(root)[0]
# 4 node types: "paper", "author", "conference", and "term"
# 6 edge types: ("paper", "author"), ("author", "paper"),
#               ("paper", "term"), ("paper", "conference"),
#               ("term", "paper"), ("conference", "paper")

# Add two metapaths:
# 1. From "paper" to "paper" through "conference"
# 2. From "author" to "conference" through "paper"
metapaths = [[("paper", "conference"), ("conference", "paper")],
             [("author", "paper"), ("paper", "conference")]]
data = AddMetaPaths(metapaths)(data)

print(data.edge_types)
# [("author", "to", "paper"), ("paper", "to", "author"),
#  ("paper", "to", "term"), ("paper", "to", "conference"),
#  ("term", "to", "paper"), ("conference", "to", "paper"),
#  ("paper", "metapath_0", "paper"),
#  ("author", "metapath_1", "conference")]

print(data.metapath_dict)
# {("paper", "metapath_0", "paper"): [("paper", "conference"),
#                                     ("conference", "paper")],
#  ("author", "metapath_1", "conference"): [("author", "paper"),
#                                           ("paper", "conference")]}
```
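The adjacency-matrix product that defines a metapath edge type can be illustrated without the library. The following is a minimal plain-Python sketch on a made-up toy graph (3 authors, 3 papers, 2 conferences); the matrices `A_ap` and `A_pc` and the helper `matmul` are assumptions for illustration and are not part of torch_geometric, which operates on sparse edge indices rather than dense lists.

```python
# Toy graph (hypothetical data): 3 authors, 3 papers, 2 conferences.
# A_ap[i][j] == 1 if author i wrote paper j.
A_ap = [
    [1, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
]
# A_pc[j][k] == 1 if paper j appeared at conference k.
A_pc = [
    [1, 0],
    [1, 0],
    [0, 1],
]


def matmul(a, b):
    """Multiply two dense matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


# Product along the metapath author -> paper -> conference.
counts = matmul(A_ap, A_pc)

# Every nonzero entry corresponds to one metapath-based edge between an
# author and a conference; the entry itself counts the distinct paths
# that connect the two endpoints.
edges = [(i, k) for i, row in enumerate(counts)
         for k, c in enumerate(row) if c > 0]
path_counts = [counts[i][k] for i, k in edges]

print(edges)        # [(0, 0), (1, 0), (2, 1)]
print(path_counts)  # [2, 1, 1]
```

Author 0 reaches conference 0 through two different papers, so that entry of the product is 2. Longer metapaths repeat the same multiplication once per additional edge type.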
- Parameters:
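metapaths (List[List[Tuple[str, str, str]]]) – The metapaths, described by a list of lists of (src_node_type, rel_type, dst_node_type) tuples.
drop_orig_edge_types (bool, optional) – If set to True, existing edge types will be dropped. (default: False)
keep_same_node_type (bool, optional) – If set to True, existing edge types between the same node type are not dropped, even if drop_orig_edge_types is set to True. (default: False)
drop_unconnected_node_types (bool, optional) – If set to True, node types that are not connected by any edge type will be dropped. (default: False)
max_sample (int, optional) – If set, at most max_sample neighbors are sampled within metapaths. This is useful for handling very dense metapath edges. (default: None)
weighted (bool, optional) – If set to True, a weight is computed for each metapath edge and stored in edge_weight. The weight of a metapath edge is the number of metapaths from its start node to its end node. (default: False)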
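How the three drop flags interact is easier to see on plain lists of type names. The sketch below mirrors only the behavior described in the parameter list; `remaining_types` is a hypothetical helper written for illustration, not a torch_geometric function, and it works on type names alone rather than on a HeteroData object.

```python
def remaining_types(node_types, orig_edge_types, metapath_edge_types,
                    drop_orig_edge_types=False,
                    keep_same_node_type=False,
                    drop_unconnected_node_types=False):
    """Return the (node_types, edge_types) left after the documented flags."""
    kept = []
    for edge_type in orig_edge_types:
        same_type = edge_type[0] == edge_type[-1]
        # Original edge types are dropped unless they connect a node type
        # to itself and keep_same_node_type asks to preserve those.
        if drop_orig_edge_types and not (keep_same_node_type and same_type):
            continue
        kept.append(edge_type)
    # Metapath-based edge types are always kept.
    kept = kept + list(metapath_edge_types)

    if drop_unconnected_node_types:
        connected = {et[0] for et in kept} | {et[-1] for et in kept}
        node_types = [nt for nt in node_types if nt in connected]
    return list(node_types), kept


# Type names taken from the DBLP example.
node_types = ["author", "paper", "term", "conference"]
orig_edge_types = [
    ("author", "to", "paper"), ("paper", "to", "author"),
    ("paper", "to", "term"), ("paper", "to", "conference"),
    ("term", "to", "paper"), ("conference", "to", "paper"),
]
metapath_edge_types = [
    ("paper", "metapath_0", "paper"),
    ("author", "metapath_1", "conference"),
]

nodes, edges = remaining_types(
    node_types, orig_edge_types, metapath_edge_types,
    drop_orig_edge_types=True, drop_unconnected_node_types=True)

print(edges)  # only the two metapath edge types remain
print(nodes)  # ['author', 'paper', 'conference']
```

With both flags set, "term" disappears because no remaining edge type touches it. Leaving drop_unconnected_node_types at False would keep all four node types even though only the metapath edge types survive.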