torch_geometric.nn

class Sequential(input_args: str, modules: List[Union[Tuple[Callable, str], Callable]])[source]

torch.nn.Sequential 容器的扩展,用于定义一个顺序的 GNN 模型。

由于GNN操作符接受多个输入参数, torch_geometric.nn.Sequential 还期望有全局输入参数和各个操作符的函数头定义。 如果省略,中间模块将操作其前一个模块的输出

from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, GCNConv

model = Sequential('x, edge_index', [
    (GCNConv(in_channels, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(64, out_channels),
])

在这里,'x, edge_index' 定义了 model 的输入参数, 而 'x, edge_index -> x' 定义了函数头, 输入 参数 返回类型 GCNConv

特别是,这也允许创建更复杂的模型,例如利用JumpingKnowledge

from torch.nn import Linear, ReLU, Dropout
from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
from torch_geometric.nn import global_mean_pool

model = Sequential('x, edge_index, batch', [
    (Dropout(p=0.5), 'x -> x'),
    (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    ReLU(inplace=True),
    (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
    (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
    (global_mean_pool, 'x, batch -> x'),
    Linear(2 * 64, dataset.num_classes),
])
Parameters:
  • input_args (str) – 模型的输入参数。

  • 模块 ([(Callable, str) 或 Callable]) – 一个模块列表(带有可选的函数头定义)。或者,可以传递一个OrderedDict的模块(和函数头定义)。

class Linear(in_channels: int, out_channels: int, bias: bool = True, weight_initializer: Optional[str] = None, bias_initializer: Optional[str] = None)[source]

对输入数据应用线性变换。

\[\mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}\]

torch.nn.Linear相比,它支持延迟初始化以及可定制的权重和偏置初始化。

Parameters:
  • in_channels (int) – 每个输入样本的大小。如果给定为 -1,则会延迟初始化。

  • out_channels (int) – Size of each output sample.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • weight_initializer (str, optional) – 权重矩阵的初始化器 ("glorot", "uniform", "kaiming_uniform"None). 如果设置为 None,将匹配 torch.nn.Linear 的默认权重初始化。(默认值: None)

  • bias_initializer (str, optional) – 偏置向量的初始化器 ("zeros"None)。 如果设置为 None,将匹配 torch.nn.Linear 的默认偏置初始化。(默认值: None)

Shapes:
  • 输入: 特征 \((*, F_{in})\)

  • 输出: 特征 \((*, F_{out})\)

reset_parameters()[source]

重置模块的所有可学习参数。

forward(x: Tensor) Tensor[source]

前向传播。

Parameters:

x (torch.Tensor) – The input features.

Return type:

Tensor

class HeteroLinear(in_channels: int, out_channels: int, num_types: int, is_sorted: bool = False, **kwargs)[source]

根据类型对传入的数据应用单独的线性变换。

对于类型 \(\kappa\),它计算

\[\mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}.\]

它支持延迟初始化和可定制的权重和偏置初始化。

Parameters:
  • in_channels (int) – 每个输入样本的大小。如果给定为 -1,则会延迟初始化。

  • out_channels (int) – Size of each output sample.

  • num_types (int) – The number of types.

  • is_sorted (bool, optional) – 如果设置为 True,则假定 type_vec 已排序。这可以避免内部重新排序数据,从而提高运行时和内存效率。 (默认值: False)

  • **kwargs (可选) – torch_geometric.nn.Linear 的额外参数。

Shapes:
  • 输入: 特征 \((*, F_{in})\), 类型向量 \((*)\)

  • output: features \((*, F_{out})\)

reset_parameters()[source]

重置模块的所有可学习参数。

forward(x: Tensor, type_vec: Tensor) Tensor[source]

前向传播。

Parameters:
Return type:

Tensor

class HeteroDictLinear(in_channels: Union[int, Dict[Any, int]], out_channels: int, types: Optional[Any] = None, **kwargs)[source]

对传入的数据字典应用单独的线性变换。

对于键 \(\kappa\),它计算

\[\mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}.\]

它支持延迟初始化和可定制的权重和偏置初始化。

Parameters:
  • in_channels (intDict[Any, int]) – 每个输入样本的大小。如果传递了一个整数,types 将是一个必需的参数。如果给定为 -1,则会延迟初始化。

  • out_channels (int) – Size of each output sample.

  • types (List[Any], 可选) – 输入字典的键。 (默认: None)

  • **kwargs (可选) – torch_geometric.nn.Linear 的额外参数。

reset_parameters()[source]

重置模块的所有可学习参数。

forward(x_dict: Dict[str, Tensor]) Dict[str, Tensor][source]

前向传播。

Parameters:

x_dict (Dict[Any, torch.Tensor]) – 一个字典,包含每种类型的输入特征。

Return type:

Dict[str, Tensor]

卷积层

MessagePassing

用于创建消息传递层的基类。

SimpleConv

一个简单的消息传递操作符,执行(不可训练的)传播。

GCNConv

来自"Semi-supervised Classification with Graph Convolutional Networks"论文的图卷积操作符。

ChebConv

来自"Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering"论文的切比雪夫谱图卷积算子。

SAGEConv

来自"Inductive Representation Learning on Large Graphs"论文的GraphSAGE操作符。

CuGraphSAGEConv

来自"Inductive Representation Learning on Large Graphs"论文的GraphSAGE操作符。

GraphConv

来自"Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks"论文的图神经网络操作符。

GravNetConv

来自"Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks"论文的GravNet操作符,其中图是使用最近邻动态构建的。

GatedGraphConv

来自"Gated Graph Sequence Neural Networks"论文的门控图卷积操作符。

ResGatedGraphConv

来自"Residual Gated Graph ConvNets"论文的残差门控图卷积算子。

GATConv

来自"Graph Attention Networks"论文的图注意力操作符。

CuGraphGATConv

来自"Graph Attention Networks"论文的图注意力操作符。

FusedGATConv

来自"理解GNN计算图:从计算、IO和内存的协调视角"论文的融合图注意力操作符。

GATv2Conv

来自"How Attentive are Graph Attention Networks?"论文的GATv2操作符,它修复了标准GATConv层的静态注意力问题。

TransformerConv

来自"Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification"论文的图变换器操作符。

AGNNConv

来自"基于注意力的图神经网络用于半监督学习"论文的图注意力传播层。

TAGConv

来自"Topology Adaptive Graph Convolutional Networks"论文的拓扑自适应图卷积网络算子。

GINConv

来自"图神经网络有多强大?"论文的图同构算子。

GINEConv

修改后的GINConv操作符来自"Strategies for Pre-training Graph Neural Networks"论文。

ARMAConv

来自"Graph Neural Networks with Convolutional ARMA Filters"论文的ARMA图卷积算子。

SGConv

来自"Simplifying Graph Convolutional Networks"论文的简单图卷积操作符。

SSGConv

来自"Simple Spectral Graph Convolution"论文的简单谱图卷积算子。

APPNP

来自"Predict then Propagate: Graph Neural Networks meet Personalized PageRank"论文的近似个性化传播神经预测层。

MFConv

来自"Convolutional Networks on Graphs for Learning Molecular Fingerprints"论文的图神经网络操作符。

RGCNConv

来自"使用图卷积网络建模关系数据"论文的关系图卷积操作符。

FastRGCNConv

See RGCNConv.

CuGraphRGCNConv

来自"使用图卷积网络建模关系数据"论文的关系图卷积操作符。

RGATConv

来自"Relational Graph Attention Networks"论文的关系图注意力操作符。

SignedConv

来自"Signed Graph Convolutional Network"论文的带符号图卷积操作符。

DNAConv

来自"Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks"论文的动态邻域聚合操作符。

PointNetConv

来自"PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation""PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space"论文的PointNet集合层。

GMMConv

来自"Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs"论文的高斯混合模型卷积算子。

SplineConv

来自"SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels"论文的基于样条的卷积算子。

NNConv

来自"Neural Message Passing for Quantum Chemistry"论文的基于连续核的卷积操作符。

CGConv

来自"Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties"论文的晶体图卷积算子。

EdgeConv

边缘卷积算子来自"Dynamic Graph CNN for Learning on Point Clouds"论文。

DynamicEdgeConv

来自"Dynamic Graph CNN for Learning on Point Clouds"论文的动态边缘卷积算子(参见torch_geometric.nn.conv.EdgeConv),其中图是使用特征空间中的最近邻动态构建的。

XConv

卷积运算符作用于来自"PointCNN: Convolution On X-Transformed Points"论文的\(\mathcal{X}\)变换点。

PPFConv

来自"PPFNet: Global Context Aware Local Features for Robust 3D Point Matching"论文的PPFNet操作符。

FeaStConv

来自"FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis"论文的(平移不变)特征导向卷积算子。

PointTransformerConv

来自"Point Transformer"论文的Point Transformer层。

HypergraphConv

来自"超图卷积和超图注意力"论文的超图卷积算子。

LEConv

来自"ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations"论文的局部极值图神经网络算子。

PNAConv

来自"Principal Neighbourhood Aggregation for Graph Nets"论文的主邻域聚合图卷积算子。

ClusterGCNConv

来自"Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks"论文的ClusterGCN图卷积算子。

GENConv

来自"DeeperGCN: All You Need to Train Deeper GCNs"论文的广义图卷积(GENConv)。

GCN2Conv

来自"Simple and Deep Graph Convolutional Networks"论文的具有初始残差连接和恒等映射的图卷积算子(GCNII)。

PANConv

基于路径积分的卷积算子来自"Path Integral Based Convolution and Pooling for Graph Neural Networks"论文。

WLConv

来自"A Reduction of a Graph to a Canonical Form and an Algebra Arising During this Reduction"论文的Weisfeiler Lehman (WL)算子。

WLConvContinuous

来自"Wasserstein Weisfeiler-Lehman Graph Kernels"论文的Weisfeiler Lehman算子。

FiLMConv

来自"GNN-FiLM: 具有特征线性调制的图神经网络"论文的FiLM图卷积操作符。

SuperGATConv

来自"如何找到你的友好邻居:自监督图注意力设计"论文的自监督图注意力操作符。

FAConv

来自"Beyond Low-Frequency Information in Graph Convolutional Networks"论文的频率自适应图卷积算子。

EGConv

来自"Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions"论文的高效图卷积。

PDNConv

来自"Pathfinder Discovery Networks for Neural Message Passing"论文的路径发现网络卷积算子。

GeneralConv

一个通用的GNN层,改编自"图神经网络的设计空间"论文。

HGTConv

来自"异构图变换器"论文的异构图变换器(HGT)操作符。

HEATConv

来自"Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction"论文的异构边缘增强图注意力操作符。

HeteroConv

用于在异质图上计算图卷积的通用包装器。

HANConv

来自"异构图注意力网络"论文的异构图注意力操作符。

LGConv

来自"LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation"论文的轻量图卷积(LGC)操作符。

PointGNNConv

来自"Point-GNN: Graph Neural Network for 3D Object Detection in a Point Cloud"论文的PointGNN操作符。

GPSConv

来自"Recipe for a General, Powerful, Scalable Graph Transformer"论文的通用、强大、可扩展(GPS)图变换器层。

AntiSymmetricConv

来自"Anti-Symmetric DGN: a stable architecture for Deep Graph Networks"论文的反对称图卷积算子。

DirGNNConv

一个通用的包装器,用于在有向图上计算图卷积,如"Edge Directionality Improves Learning on Heterophilic Graphs"论文中所述。

MixHopConv

来自"MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing"论文的Mix-Hop图卷积操作符。

聚合操作符

聚合函数在图神经网络的消息传递框架和读出函数中扮演着重要角色。 具体来说,许多文献中的工作(Hamilton et al. (2017), Xu et al. (2018), Corso et al. (2020), Li et al. (2020), Tailor et al. (2021))表明,聚合函数的选择对模型的表示能力和性能有显著影响。 例如,均值聚合捕捉元素的分布(或比例),最大聚合被证明在识别代表性元素方面具有优势,而求和聚合则能够学习图的结构属性(Xu et al. (2018))。 最近的研究还表明,使用多重聚合Corso et al. (2020), Tailor et al. (2021))和可学习聚合Li et al. (2020))可能会带来显著的改进。 另一项研究则探讨了基于优化和隐式定义的聚合(Bartunov et al. (2022))。 此外,一个有趣的讨论涉及表示能力(通常通过实现为神经网络的可学习函数获得)与排列不变性的形式属性之间的权衡(Buterez et al. (2022))。

为了促进进一步的实验并统一GNN中聚合的概念,无论是在MessagePassing还是全局读取中,我们已将Aggregation的概念作为中的首要原则。 截至目前,提供了对各种聚合的支持——从较为简单的(例如meanmaxsum),到高级的(例如medianvarstd),可学习的(例如SoftmaxAggregationPowerMeanAggregationSetTransformerAggregation),以及奇特的(例如MLPAggregationLSTMAggregationSortAggregationEquilibriumAggregation):

from torch_geometric.nn import aggr

# Simple aggregations:
mean_aggr = aggr.MeanAggregation()
max_aggr = aggr.MaxAggregation()

# Advanced aggregations:
median_aggr = aggr.MedianAggregation()

# Learnable aggregations:
softmax_aggr = aggr.SoftmaxAggregation(learn=True)
powermean_aggr = aggr.PowerMeanAggregation(learn=True)

# Exotic aggregations:
lstm_aggr = aggr.LSTMAggregation(in_channels=..., out_channels=...)
sort_aggr = aggr.SortAggregation(k=4)

然后,我们可以轻松地将这些聚合应用于一批可能大小不同的集合。 为此,一个index向量定义了从输入元素到它们在输出中位置的映射:

# Feature matrix holding 1000 elements with 64 features each:
x = torch.randn(1000, 64)

# Randomly assign elements to 100 sets:
index = torch.randint(0, 100, (1000, ))

output = mean_aggr(x, index)  #  Output shape: [100, 64]

值得注意的是,所有聚合操作共享同一组前向参数,如torch_geometric.nn.aggr.Aggregation基类中详细描述的那样。

提供的每个聚合都可以在MessagePassing中使用,也可以用于层次/全局池化以获得图级表示:

import torch
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing):
    def __init__(self, ...):
        # Use a learnable softmax neighborhood aggregation:
        super().__init__(aggr=aggr.SoftmaxAggregation(learn=True))

   def forward(self, x, edge_index):
       ....


class MyGNN(torch.nn.Module)
    def __init__(self, ...):
        super().__init__()

        self.conv = MyConv(...)
        # Use a global sort aggregation:
        self.global_pool = aggr.SortAggregation(k=4)
        self.classifier = torch.nn.Linear(...)

     def foward(self, x, edge_index, batch):
         x = self.conv(x, edge_index).relu()
         x = self.global_pool(x, batch)
         x = self.classifier(x)
         return x

此外,的聚合包引入了两个新概念: 首先,聚合可以通过查找表从纯字符串解析,遵循class-resolver库的设计原则,例如,只需将"median"传递给MessagePassing模块。 这将自动解析为MedianAggregation类:

class MyConv(MessagePassing):
    def __init__(self, ...):
        super().__init__(aggr="median")

其次,可以通过MultiAggregation模块将多个聚合组合和堆叠,以增强图神经网络(GNNs)的表示能力(Corso et al. (2020), Tailor et al. (2021)):

class MyConv(MessagePassing):
    def __init__(self, ...):
        # Combines a set of aggregations and concatenates their results,
        # i.e. its output will be `[num_nodes, 3 * out_channels]` here.
        # Note that the interface also supports automatic resolution.
        super().__init__(aggr=aggr.MultiAggregation(
            ['mean', 'std', aggr.SoftmaxAggregation(learn=True)]))

重要的是,MultiAggregation 提供了多种选项来组合其底层聚合的输出(例如,使用连接、求和、注意力等)通过其 mode 参数。 默认的 mode 执行连接操作("cat")。 要通过注意力进行组合,我们还需要指定 in_channelsout_channelsnum_heads

multi_aggr = aggr.MultiAggregation(
    aggrs=['mean', 'std'],
    mode='attn',
    mode_kwargs=dict(in_channels=64, out_channels=64, num_heads=4),
)

如果聚合以列表形式给出,它们将自动解析为MultiAggregation例如aggr=['mean', 'std', 'median']

最后,我们为SAGEConv层添加了完全支持的自定义聚合功能 — 只需覆盖其aggr参数,并利用聚合的力量在您的GNN中

注意

你可以在这篇博客文章中了解更多关于torch_geometric.nn.aggr包的信息。

Aggregation

用于实现自定义聚合的抽象基类。

MultiAggregation

执行一个或多个聚合器的聚合操作,并组合聚合结果,如"Principal Neighbourhood Aggregation for Graph Nets""Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions"论文中所述。

SumAggregation

一个聚合操作符,用于对一组元素中的特征进行求和。

MeanAggregation

一个聚合操作符,用于计算一组元素的特征平均值。

MaxAggregation

一个聚合操作符,它在一组元素中取特征方向的最大值。

MinAggregation

一个聚合操作符,它在一组元素中取特征方向的最小值。

MulAggregation

一个聚合操作符,用于在一组元素中乘以特征。

VarAggregation

一个聚合操作符,用于计算一组元素的特征方差。

StdAggregation

一个聚合操作符,用于计算一组元素的特征标准偏差。

SoftmaxAggregation

基于温度项的softmax聚合操作符,如"DeeperGCN: All You Need to Train Deeper GCNs"论文中所述。

PowerMeanAggregation

基于幂项的幂均值聚合操作符,如"DeeperGCN: All You Need to Train Deeper GCNs"论文中所述。

MedianAggregation

一个聚合操作符,返回一组数据的特征中位数。

QuantileAggregation

一个聚合操作符,返回集合 \(\mathcal{X}\) 的特征方向的 \(q\) 分位数。

LSTMAggregation

执行LSTM风格的聚合,其中要聚合的元素被解释为一个序列,如"Inductive Representation Learning on Large Graphs"论文中所述。

GRUAggregation

执行GRU聚合,其中要聚合的元素被解释为一个序列,如"具有自适应读数的图神经网络"论文中所述。

Set2Set

基于迭代内容注意力的Set2Set聚合操作符,如"顺序重要:集合的序列到序列"论文中所述。

DegreeScalerAggregation

结合一个或多个聚合器,并使用一个或多个缩放器转换其输出,如"Principal Neighbourhood Aggregation for Graph Nets"论文中介绍的那样。

SortAggregation

来自"An End-to-End Deep Learning Architecture for Graph Classification"论文的池化操作符,其中节点特征根据其最后一个特征通道按降序排序。

GraphMultisetTransformer

来自"Accurate Learning of Graph Representations with Graph Multiset Pooling"论文的图多集变换器池化操作符。

AttentionalAggregation

来自"Graph Matching Networks for Learning the Similarity of Graph Structured Objects"论文的软注意力聚合层。

EquilibriumAggregation

来自"Equilibrium Aggregation: Encoding Sets via Optimization"论文的平衡聚合层。

MLPAggregation

执行MLP聚合,其中要聚合的元素被展平为单个向量表示,然后通过多层感知器(MLP)进行处理,如"具有自适应读数的图神经网络"论文中所述。

DeepSetsAggregation

执行深度集合聚合,其中要聚合的元素首先通过多层感知器(MLP)\(\phi_{\mathbf{\Theta}}\)进行转换,求和,然后通过另一个MLP \(\rho_{\mathbf{\Theta}}\)进行转换,如"具有自适应读数的图神经网络"论文中所建议的。

SetTransformerAggregation

执行“Set Transformer”聚合,其中要聚合的元素通过多头注意力块进行处理,如“具有自适应读数的图神经网络”论文中所述。

LCMAggregation

来自"Learnable Commutative Monoids for Graph Neural Networks"论文的可交换幺半群聚合,其中元素使用具有\(\mathcal{O}(\log |\mathcal{V}|)\)深度的二叉树归约进行聚合。

VariancePreservingAggregation

执行来自"GNN-VPA: A Variance-Preserving Aggregation Strategy for Graph Neural Networks"论文的方差保持聚合(VPA)。

PatchTransformerAggregation

执行补丁变换器聚合,其中要聚合的元素通过跨补丁的多头注意力块进行处理,如"Simplifying Temporal Heterogeneous Network for Continuous-Time Link Prediction"论文中所述。

归一化层

BatchNorm

应用批量归一化处理一批特征,如"批量归一化:通过减少内部协变量偏移加速深度网络训练"论文中所述。

HeteroBatchNorm

应用批量归一化处理一批异质特征,如"批量归一化:通过减少内部协变量偏移加速深度网络训练"论文中所述。

InstanceNorm

对一批节点特征中的每个单独示例应用实例归一化,如"实例归一化:快速风格化的缺失成分"论文中所述。

LayerNorm

在一批特征中对每个单独的示例应用层归一化,如"Layer Normalization"论文中所述。

HeteroLayerNorm

"Layer Normalization"论文中所述,对一批异质特征中的每个单独示例应用层归一化。

GraphNorm

应用图规范化于单个图,如"GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training"论文中所述。

GraphSizeNorm

对一批节点特征中的每个单独图应用图大小归一化,如"基准测试图神经网络"论文中所述。

PairNorm

应用节点特征上的对归一化,如"PairNorm: Tackling Oversmoothing in GNNs"论文中所述。

MeanSubtractionNorm

应用层归一化,通过从输入中减去均值,如"Revisiting 'Over-smoothing' in Deep GCNs"论文中所述。

MessageNorm

应用消息归一化于聚合的消息,如"DeeperGCNs: All You Need to Train Deeper GCNs"论文中所述。

DiffGroupNorm

来自"Towards Deeper Graph Neural Networks with Differentiable Group Normalization"论文的可微分组归一化层,该层通过可学习的软聚类分配对节点特征进行分组归一化。

池化层

global_add_pool

通过跨节点维度添加节点特征,返回批处理级别的图级输出。

global_mean_pool

通过跨节点维度平均节点特征,返回批处理级别的图级输出。

global_max_pool

通过跨节点维度的通道方向最大值返回批处理级别的图级输出。

KNNIndex

一个基类,用于通过faiss库执行快速的\(k\)-最近邻搜索(\(k\)-NN)。

L2KNNIndex

通过faiss库基于\(L_2\)度量执行快速的\(k\)-最近邻搜索(\(k\)-NN)。

MIPSKNNIndex

通过faiss库执行基于最大内积的快速\(k\)-最近邻搜索(\(k\)-NN)。

ApproxL2KNNIndex

通过faiss库执行基于\(L_2\)度量的快速近似\(k\)-最近邻搜索(\(k\)-NN)。

ApproxMIPSKNNIndex

通过faiss库执行基于最大内积的快速近似\(k\)-最近邻搜索(\(k\)-NN)。

TopKPooling

\(\mathrm{top}_k\) 池化操作符来自 "Graph U-Nets", "Towards Sparse Hierarchical Graph Classifiers""Understanding Attention and Generalization in Graph Neural Networks" 论文。

SAGPooling

自注意力池化操作符来自"Self-Attention Graph Pooling""Understanding Attention and Generalization in Graph Neural Networks"论文。

EdgePooling

边缘池化操作符来自"Towards Graph Pooling by Edge Contraction""Edge Contraction Pooling for Graph Neural Networks"论文。

ClusterPooling

来自"基于边的图组件池化"论文的集群池化操作符。

ASAPooling

来自"ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations"论文的自适应结构感知池化操作符。

PANPooling

基于路径积分的池化操作符,来自"基于路径积分的图神经网络卷积和池化"论文。

MemPooling

基于内存的池化层来自"基于内存的图网络"论文,该层基于软聚类分配学习粗化的图表示。

max_pool

根据cluster中定义的聚类,对由torch_geometric.data.Data对象给出的图进行池化和粗化。

avg_pool

根据cluster中定义的聚类,对由torch_geometric.data.Data对象给出的图进行池化和粗化。

max_pool_x

根据cluster中定义的聚类,对节点特征进行最大池化。

max_pool_neighbor_x

最大池化相邻节点特征,其中data.x中的每个特征被替换为中心节点及其邻居中具有最大值的特征值。

avg_pool_x

根据cluster中定义的聚类,平均池化节点特征。

avg_pool_neighbor_x

平均池化相邻节点的特征,其中data.x中的每个特征被替换为中心节点及其邻居的平均特征值。

graclus

一种来自"Weighted Graph Cuts without Eigenvectors: A Multilevel Approach"论文的贪心聚类算法,该算法选择一个未标记的顶点并将其与其未标记的邻居之一(最大化其边权重)进行匹配。

voxel_grid

体素网格池化来自,例如Dynamic Edge-Conditioned Filters in Convolutional Networks on Graphs 论文,它在点云上覆盖一个用户定义大小的规则网格,并将同一体素内的所有点聚类。

fps

一种采样算法,来自"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space"论文,该算法迭代地采样相对于其余点最远的点。

knn

y中的每个元素在x中找到k个最近的点。

knn_graph

Computes graph edges to the nearest k points.

approx_knn

y中的每个元素在x中找到k个近似最近的点。

approx_knn_graph

Computes graph edges to the nearest approximated k points.

radius

y中的每个元素找到x中距离r以内的所有点。

radius_graph

计算到给定距离内所有点的图边。

nearest

y中的每个元素找到x中的k个最近点。

反池化层

knn_interpolate

来自"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space"论文的k-NN插值。

Models

MLP

一个多层感知器(MLP)模型。

GCN

来自"Semi-supervised Classification with Graph Convolutional Networks"论文的图神经网络,使用GCNConv操作符进行消息传递。

GraphSAGE

来自"Inductive Representation Learning on Large Graphs"论文的图神经网络,使用SAGEConv操作符进行消息传递。

GIN

来自"图神经网络有多强大?"论文的图神经网络,使用GINConv操作符进行消息传递。

GAT

来自"Graph Attention Networks""How Attentive are Graph Attention Networks?"论文的图神经网络,分别使用GATConvGATv2Conv操作符进行消息传递。

PNA

来自"Principal Neighbourhood Aggregation for Graph Nets"论文的图神经网络,使用PNAConv操作符进行消息传递。

EdgeCNN

来自"Dynamic Graph CNN for Learning on Point Clouds"论文的图神经网络,使用EdgeConv操作符进行消息传递。

JumpingKnowledge

来自"Representation Learning on Graphs with Jumping Knowledge Networks"论文的跳跃知识层聚合模块。

HeteroJumpingKnowledge

A heterogeneous version of the JumpingKnowledge module.

MetaLayer

一个用于构建任何类型图网络的元层,灵感来自于"关系归纳偏差、深度学习与图网络"论文。

Node2Vec

来自"node2vec: Scalable Feature Learning for Networks"论文的Node2Vec模型,其中在给定图中采样长度为walk_length的随机游走,并通过负采样优化学习节点嵌入。

DeepGraphInfomax

来自"Deep Graph Infomax"论文的深度图信息最大化模型,基于用户定义的编码器和摘要模型\(\mathcal{E}\)\(\mathcal{R}\),以及一个损坏函数\(\mathcal{C}\)

InnerProductDecoder

来自"变分图自编码器"论文的内积解码器。

GAE

来自"变分图自动编码器"论文的图自动编码器模型,基于用户定义的编码器和解码器模型。

VGAE

来自"变分图自编码器"论文的变分图自编码器模型。

ARGA

来自"Adversarially Regularized Graph Autoencoder for Graph Embedding"论文的对抗性正则化图自动编码器模型。

ARGVA

来自"Adversarially Regularized Graph Autoencoder for Graph Embedding"论文的对抗性正则化变分图自编码器模型。

SignedGCN

来自"Signed Graph Convolutional Network"论文的签名图卷积网络模型。

RENet

来自"Recurrent Event Network for Reasoning over Temporal Knowledge Graphs"论文的循环事件网络模型。

GraphUNet

来自"Graph U-Nets"论文的Graph U-Net模型,该模型实现了类似U-Net的架构,包含图池化和反池化操作。

SchNet

来自"SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions"论文的连续滤波器卷积神经网络SchNet,它使用了以下形式的交互块。

DimeNet

来自"分子图的方向性消息传递"论文的方向性消息传递神经网络(DimeNet)。

DimeNetPlusPlus

来自"快速且不确定性感知的方向性消息传递用于非平衡分子"论文的DimeNet++。

to_captum_model

将模型转换为可用于Captum归因方法的模型。

to_captum_input

给定 x, edge_indexmask_type,将其转换为可在 Captum 归因方法中使用的格式。

captum_output_to_dicts

Captum归因方法的输出转换为两个字典,其中包含节点和边的归因张量。

MetaPath2Vec

来自"metapath2vec: Scalable Representation Learning for Heterogeneous Networks"论文的MetaPath2Vec模型,其中在异质图中基于给定的metapath进行随机游走采样,并通过负采样优化学习节点嵌入。

DeepGCNLayer

来自"DeepGCNs: Can GCNs Go as Deep as CNNs?""All You Need to Train Deeper GCNs"论文的跳跃连接操作。

TGNMemory

来自"Temporal Graph Networks for Deep Learning on Dynamic Graphs"论文的时间图网络(TGN)记忆模型。

LabelPropagation

标签传播操作符,首次在"从带标签和无标签数据中学习标签传播"论文中引入。

CorrectAndSmooth

来自"Combining Label Propagation And Simple Models Out-performs Graph Neural Networks"论文的正确且平滑(C&S)后处理模型,其中软预测\(\mathbf{Z}\)(从简单的基础预测器获得)首先基于真实训练标签信息\(\mathbf{Y}\)和残差传播进行校正。

AttentiveFP

Attentive FP模型用于分子表示学习,源自"Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism"论文,基于图注意力机制。

RECT_L

RECT模型,其监督学习的RECT-L部分,来自"完全不平衡标签的网络嵌入"论文。

LINKX

来自"大规模非同质图学习:新基准和强大的简单方法"论文的LINKX模型。

LightGCN

来自"LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation"论文的LightGCN模型。

MaskLabel

来自"Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification"论文的标签嵌入和掩码层。

GroupAddRev

来自"Graph Neural Networks with 1000 Layers"论文的分组可逆GNN模块。

GNNFF

来自"Accurate and scalable graph neural network force field and molecular dynamics with direct force architecture"论文的图神经网络力场(GNNFF)。

PMLP

来自"图神经网络本质上是优秀的泛化者:通过桥接GNNs和MLPs的见解"论文的P(ropagational)MLP模型。

NeuralFingerprint

来自"Convolutional Networks on Graphs for Learning Molecular Fingerprints"论文的神经指纹模型,用于生成分子的指纹。

ViSNet

一个模块,实现了来自"Enhancing Geometric Representations for Molecules with Equivariant Vector-Scalar Interactive Message Passing"论文的等变向量-标量交互图神经网络(ViSNet)。

GRetriever

来自"G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and Question Answering"论文的G-Retriever模型。

GITMol

来自"GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text"论文的GITMol模型。

MoleculeGPT

来自"MoleculeGPT: 用于分子性质预测的指令跟随大型语言模型"论文的MoleculeGPT模型。

GLEM

这个GNN+LM联合训练模型基于来自"通过变分推理在大规模文本属性图上的学习"论文的GLEM。

KGE 模型

KGEModel

用于实现自定义KGE模型的抽象基类。

TransE

来自"Translating Embeddings for Modeling Multi-Relational Data"论文的TransE模型。

ComplEx

来自"Complex Embeddings for Simple Link Prediction"论文的ComplEx模型。

DistMult

来自"嵌入实体和关系以在知识库中进行学习和推理"论文的DistMult模型。

RotatE

来自"RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space"论文的RotatE模型。

编码

class PositionalEncoding(out_channels: int, base_freq: float = 0.0001, granularity: float = 1.0)[source]

来自“Attention Is All You Need”论文的位置编码方案。

\[ \begin{align}\begin{aligned}PE(x)_{2 \cdot i} &= \sin(x / 10000^{2 \cdot i / d})\\PE(x)_{2 \cdot i + 1} &= \cos(x / 10000^{2 \cdot i / d})\end{aligned}\end{align} \]

其中 \(x\) 是位置,\(i\) 是维度。

Parameters:
  • out_channels (int) – 每个输出样本的大小 \(d\)

  • base_freq (float, optional) – 正弦函数的基频。(默认值: 1e-4)

  • granularity (float, optional) – 位置的粒度。如果设置为较小的值,编码器将捕捉到更细粒度的位置变化。(默认值:1.0

reset_parameters()[source]
forward(x: Tensor) Tensor[source]
Return type:

Tensor

class TemporalEncoding(out_channels: int)[source]

来自“我们真的需要复杂的时间网络模型架构吗?”论文的时间编码函数。

它首先将每个条目映射到一个具有指数递减值的向量, 然后使用余弦函数将所有值投影到范围 \([-1, 1]\)

\[y_{i} = \cos \left(x \cdot \sqrt{d}^{-(i - 1)/\sqrt{d}} \right)\]

其中 \(d\) 定义了输出特征的维度,且 \(1 \leq i \leq d\)

Parameters:

out_channels (int) – 每个输出样本的大小 \(d\)

reset_parameters()[source]
forward(x: Tensor) Tensor[source]
Return type:

Tensor

Functional

bro

来自"Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity"论文的批量表示正交性惩罚。

gini

来自"Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity"论文的基尼系数。

密集卷积层

DenseGCNConv

See torch_geometric.nn.conv.GCNConv.

DenseGINConv

See torch_geometric.nn.conv.GINConv.

DenseGraphConv

See torch_geometric.nn.conv.GraphConv.

DenseSAGEConv

See torch_geometric.nn.conv.SAGEConv.

DenseGATConv

See torch_geometric.nn.conv.GATConv.

密集池化层

dense_diff_pool

来自"Hierarchical Graph Representation Learning with Differentiable Pooling"论文的可微分池化操作符。

dense_mincut_pool

来自"图神经网络中的谱聚类用于图池化"论文的MinCut池化操作符。

DMoNPooling

来自"Graph Clustering with Graph Neural Networks"论文的谱模块化池化操作符。

模型转换

class Transformer(module: Module, input_map: Optional[Dict[str, str]] = None, debug: bool = False)[source]

一个 Transformer 逐节点执行 FX 图,对每个节点应用转换,并生成一个新的 torch.nn.Module。 它暴露了一个 transform() 方法,该方法返回转换后的 ModuleTransformer 完全以符号方式工作。

Transformer 类中的方法可以被重写以自定义转换行为。

transform()
    +-- Iterate over each node in the graph
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- call_message_passing_module()
        +-- call_global_pooling_module()
        +-- output()
    +-- Erase unused nodes in the graph
    +-- Iterate over each children module
        +-- init_submodule()

torch.fx.Transformer类相比, Transformer暴露了额外的功能:

  1. 它将call_module()细分为调用常规torch.nn.Module的节点(call_module()),调用MessagePassing模块的节点(call_message_passing_module()),或调用GlobalPooling模块的节点(call_global_pooling_module())。

  2. 它允许通过 init_submodule() 自定义或初始化新的子模块

  3. 它允许通过is_edge_level()推断一个节点返回的是节点级别还是边级别的信息。

Parameters:
  • module (torch.nn.Module) – 要转换的模块。

  • input_map (Dict[str, str], optional) – 一个字典,用于保存关于module.forward输入参数类型的信息。 例如,如果arg是一个节点级别的参数,那么 input_map['arg'] = 'node',否则 input_map['arg'] = 'edge'。 如果input_map没有进一步指定,将尝试自动确定输入参数的正确类型。 (默认值: None)

  • debug (bool, 可选) – 如果设置为 True,将在调试模式下执行转换。(默认值:False

placeholder(node: Node, target: Any, name: str)[source]
get_attr(node: Node, target: Any, name: str)[source]
call_message_passing_module(node: Node, target: Any, name: str)[source]
call_global_pooling_module(node: Node, target: Any, name: str)[source]
call_module(node: Node, target: Any, name: str)[source]
call_method(node: Node, target: Any, name: str)[source]
call_function(node: Node, target: Any, name: str)[source]
output(node: Node, target: Any, name: str)[source]
init_submodule(module: Module, target: str) Module[source]
Return type:

Module

transform() GraphModule[source]

转换 self.module 并返回一个转换后的 torch.fx.GraphModule

Return type:

GraphModule

is_node_level(node: Node) bool[source]
Return type:

bool 翻译后的内容: bool 在这个例子中,`bool` 是一个Python函数名称,根据翻译规则1,不需要翻译。因此,翻译后的内容保持不变。

is_edge_level(node: Node) bool[source]
Return type:

bool 翻译后的内容: bool 在这个例子中,`bool` 是一个Python函数名称,根据翻译规则1,不需要翻译。因此,翻译后的内容保持不变。

is_graph_level(node: Node) bool[source]
Return type:

bool 翻译后的内容: bool 在这个例子中,`bool` 是一个Python函数名称,根据翻译规则1,不需要翻译。因此,翻译后的内容保持不变。

has_node_level_arg(node: Node) bool[source]
Return type:

bool 翻译后的内容: bool 在这个例子中,`bool` 是一个Python函数名称,根据翻译规则1,不需要翻译。因此,翻译后的内容保持不变。

has_edge_level_arg(node: Node) bool[source]
Return type:

bool 翻译后的内容: bool 在这个例子中,`bool` 是一个Python函数名称,根据翻译规则1,不需要翻译。因此,翻译后的内容保持不变。

has_graph_level_arg(node: Node) bool[source]
Return type:

bool 翻译后的内容: bool 在这个例子中,`bool` 是一个Python函数名称,根据翻译规则1,不需要翻译。因此,翻译后的内容保持不变。

replace_all_uses_with(to_replace: Node, replace_with: Node)[source]
to_hetero(module: Module, metadata: Tuple[List[str], List[Tuple[str, str, str]]], aggr: str = 'sum', input_map: Optional[Dict[str, str]] = None, debug: bool = False) GraphModule[source]

将同质的GNN模型转换为其异质等价模型,其中为metadata[0]中的每个节点类型学习节点表示,并在metadata[1]中的每个边类型之间交换消息,如“使用图卷积网络建模关系数据”论文中所描述。

import torch
from torch_geometric.nn import SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), 32)
        self.conv2 = SAGEConv((32, 32), 32)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return x

model = GNN()

node_types = ['paper', 'author']
edge_types = [
    ('paper', 'cites', 'paper'),
    ('paper', 'written_by', 'author'),
    ('author', 'writes', 'paper'),
]
metadata = (node_types, edge_types)

model = to_hetero(model, metadata)
model(x_dict, edge_index_dict)

其中 x_dictedge_index_dict 表示分别保存每种节点类型和边类型的节点特征和边连接信息的字典。

下图显示了左侧同质模型的原始计算图,以及右侧新获得的异质模型的计算图:

../_images/to_hetero.svg

通过to_hetero()转换模型。

在这里,每个MessagePassing实例 \(f_{\theta}^{(\ell)}\) 被复制并存储在一个集合 \(\{ f_{\theta}^{(\ell, r)} : r \in \mathcal{R} \}\) 中(每个关系在 \(\mathcal{R}\) 中有一个实例),并且第 \(\ell\) 层的消息传递是通过以下方式进行的:

\[\mathbf{h}^{(\ell)}_v = \bigoplus_{r \in \mathcal{R}} f_{\theta}^{(\ell, r)} ( \mathbf{h}^{(\ell - 1)}_v, \{ \mathbf{h}^{(\ell - 1)}_w : w \in \mathcal{N}^{(r)}(v) \}),\]

其中 \(\mathcal{N}^{(r)}(v)\) 表示在关系 \(r \in \mathcal{R}\) 下节点 \(v \in \mathcal{V}\) 的邻域,\(\bigoplus\) 表示用于聚合由不同关系生成的节点嵌入的聚合方案 aggr"sum""mean""min""max""mul")。

Parameters:
  • 模块 (torch.nn.Module) – 要转换的同类模型。

  • metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

  • aggr (str, 可选) – 用于分组节点嵌入的聚合方案,这些嵌入由不同关系生成 ("sum", "mean", "min", "max", "mul")。(默认: "sum")

  • input_map (Dict[str, str], optional) – 一个字典,用于保存关于module.forward输入参数类型的信息。 例如,如果arg是一个节点级别的参数,那么 input_map['arg'] = 'node',否则 input_map['arg'] = 'edge'。 如果input_map没有进一步指定,将尝试自动确定输入参数的正确类型。 (默认值:None

  • debug (bool, optional) – If set to True, will perform transformation in debug mode. (default: False)

Return type:

GraphModule

to_hetero_with_bases(module: Module, metadata: Tuple[List[str], List[Tuple[str, str, str]]], num_bases: int, in_channels: Optional[Dict[str, int]] = None, input_map: Optional[Dict[str, str]] = None, debug: bool = False) GraphModule[source]

通过“使用图卷积网络建模关系数据”论文中引入的基分解技术,将同质的GNN模型转换为其异质等价模型。

为此,异构图被映射到一个类型化的同构图中,其中其特征表示被对齐并分组为单一表示。模型内的所有GNN层随后将通过基础分解正则化执行消息传递。这种转换在高度多关系的数据中特别有用,使得参数的数量不再依赖于输入图的关系数量:

import torch
from torch_geometric.nn import SAGEConv, to_hetero_with_bases

class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv((16, 16), 32)
        self.conv2 = SAGEConv((32, 32), 32)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return x

model = GNN()

node_types = ['paper', 'author']
edge_types = [
    ('paper', 'cites', 'paper'),
    ('paper', 'written_by', 'author'),
    ('author', 'writes', 'paper'),
]
metadata = (node_types, edge_types)

model = to_hetero_with_bases(model, metadata, num_bases=3,
                             in_channels={'x': 16})
model(x_dict, edge_index_dict)

其中 x_dictedge_index_dict 表示分别保存每种节点类型和边类型的节点特征和边连接信息的字典。 如果为特定输入参数提供了 in_channels,则其异构特征信息首先会与给定的维度对齐。

下图左侧显示了同质模型的原始计算图,右侧显示了新获得的正则化异质模型的计算图:

../_images/to_hetero_with_bases.svg

通过 to_hetero_with_bases() 转换模型。

在这里,每个MessagePassing实例 \(f_{\theta}^{(\ell)}\)被复制num_bases次并 存储在一个集合\(\{ f_{\theta}^{(\ell, b)} : b \in \{ 1, \ldots, B \} \}\)中(每个基在 num_bases中有一个实例),并且在层\(\ell\)中的消息传递是通过

\[\mathbf{h}^{(\ell)}_v = \sum_{r \in \mathcal{R}} \sum_{b=1}^B f_{\theta}^{(\ell, b)} ( \mathbf{h}^{(\ell - 1)}_v, \{ a^{(\ell)}_{r, b} \cdot \mathbf{h}^{(\ell - 1)}_w : w \in \mathcal{N}^{(r)}(v) \}),\]

其中 \(\mathcal{N}^{(r)}(v)\) 表示在关系 \(r \in \mathcal{R}\)\(v \in \mathcal{V}\) 的邻域。值得注意的是,只有可训练的基础系数 \(a^{(\ell)}_{r, b}\) 依赖于 \(\mathcal{R}\) 中的关系。

Parameters:
  • module (torch.nn.Module) – The homogeneous model to transform.

  • metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

  • num_bases (int) – 要使用的基础数量。

  • in_channels (Dict[str, int], optional) – 一个字典,包含有关module.forward输入参数的期望输入特征维度的信息。 如果为特定输入参数提供了in_channels,则其异构特征信息首先会与给定的维度对齐。 这允许处理不同类型节点和边特征的不同特征维度。(默认值:None

  • input_map (Dict[str, str], optional) – A dictionary holding information about the type of input arguments of module.forward. For example, in case arg is a node-level argument, then input_map['arg'] = 'node', and input_map['arg'] = 'edge' otherwise. In case input_map is not further specified, will try to automatically determine the correct type of input arguments. (default: None)

  • debug (bool, optional) – If set to True, will perform transformation in debug mode. (default: False)

Return type:

GraphModule

DataParallel 层

class DataParallel(module, device_ids=None, output_device=None, follow_batch=None, exclude_keys=None)[source]

在模块级别实现数据并行。

这个容器通过将一系列torch_geometric.data.Data对象分割并将它们复制为torch_geometric.data.Batch对象到每个设备上,来并行化给定module的应用。 在前向传播过程中,模块在每个设备上被复制,每个副本处理一部分输入。 在反向传播过程中,来自每个副本的梯度被汇总到原始模块中。

批量大小应大于使用的GPU数量。

并行化的 module 必须将其参数和缓冲区放在 device_ids[0] 上。

注意

您需要使用torch_geometric.loader.DataListLoader来加载这个模块。

警告

建议使用 torch.nn.parallel.DistributedDataParallel 而不是 DataParallel 进行多GPU训练。 DataParallel 通常比 DistributedDataParallel 慢得多,即使在单台机器上也是如此。 查看 这里 以了解如何将 DistributedDataParallel 结合使用的示例。

Parameters:
  • module (Module) – 要并行化的模块。

  • device_ids (整数列表torch.device) – CUDA 设备。 (默认值:所有设备)

  • output_device (inttorch.device) – 输出的设备位置。 (默认: device_ids[0])

  • follow_batch (listtuple, 可选) – 为列表中的每个键创建分配批处理向量。(默认值: None)

  • exclude_keys (listtuple, 可选) – 将排除列表中的每个键。(默认值:None

模型中心

class PyGModelHubMixin(model_name: str, dataset_name: str, model_kwargs: Dict)[source]

一个用于保存和加载模型到 Huggingface Model Hub 的 mixin。

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Node2Vec
from torch_geometric.nn.model_hub import PyGModelHubMixin

# Define your class with the mixin:
class N2V(Node2Vec, PyGModelHubMixin):
    def __init__(self,model_name, dataset_name, model_kwargs):
        Node2Vec.__init__(self,**model_kwargs)
        PyGModelHubMixin.__init__(self, model_name,
            dataset_name, model_kwargs)

# Instantiate your model:
n2v = N2V(model_name='node2vec',
    dataset_name='Cora', model_kwargs=dict(
    edge_index=data.edge_index, embedding_dim=128,
    walk_length=20, context_size=10, walks_per_node=10,
    num_negative_samples=1, p=1, q=1, sparse=True))

# Train the model:
...

# Push to the HuggingFace hub:
repo_id = ...  # your repo id
n2v.save_pretrained(
    local_file_path,
    push_to_hub=True,
    repo_id=repo_id,
 )

# Load the model for inference:
# The required arguments are the repo id/local folder, and any model
# initialisation arguments that are not native python types (e.g
# Node2Vec requires the edge_index argument which is not stored in the
# model hub).
model = N2V.from_pretrained(
    repo_id,
    model_name='node2vec',
    dataset_name='Cora',
    edge_index=data.edge_index,
)
Parameters:
  • model_name (str) – 模型的名称。

  • dataset_name (str) – 模型训练所针对的数据集名称。

  • model_kwargs (Dict[str, Any]) – 初始化模型的参数。

save_pretrained(save_directory: Union[str, Path], push_to_hub: bool = False, repo_id: Optional[str] = None, **kwargs)[source]

将训练好的模型保存到本地目录或HuggingFace模型中心。

Parameters:
  • save_directory (str) – 保存权重的目录。

  • push_to_hub (bool, 可选) – 如果为 True,则将模型推送到 HuggingFace 模型中心。(默认值:False

  • repo_id (str, optional) – 仓库在中心中的名称。 如果未提供,将默认为你命名空间中的 save_directory 名称。(默认: None)

  • **kwargs – 传递给 huggingface_hub.ModelHubMixin.save_pretrained() 的额外关键字参数。

classmethod from_pretrained(pretrained_model_name_or_path: str, force_download: bool = False, resume_download: bool = False, proxies: Optional[Dict] = None, token: Optional[Union[str, bool]] = None, cache_dir: Optional[str] = None, local_files_only: bool = False, **model_kwargs) Any[source]

从HuggingFace中心下载并实例化一个模型。

Parameters:
  • pretrained_model_name_or_path (str) –

    可以是以下之一:

    • 托管在HuggingFace hub中的预训练模型的model_id

    • 您可以通过在model_id末尾添加@来加载特定版本的模型。

    • 包含保存的模型权重的目录路径。

    • 如果您同时提供了配置config和状态字典state_dict,则为None

  • force_download (bool, optional) – 是否强制 (重新)下载模型权重和配置文件, 覆盖缓存版本(如果存在)。 (默认值:False)

  • resume_download (bool, optional) – 是否删除未完全接收的文件。如果存在这样的文件,将尝试恢复下载。(默认值:False

  • proxies (Dict[str, str], optional) – 一个按协议或端点使用的代理服务器字典,例如{'http': 'foo.bar:3128', 'http://host': 'foo.bar:4012'}。 代理服务器在每个请求中使用。(默认值:None

  • token (strbool, 可选) – 用于远程文件的HTTP承载授权的token。如果设置为 True,将使用运行 transformers-cli login 时生成的token(存储在 huggingface 中)。如果你想使用私有模型,这是必需的。(默认值:None

  • cache_dir (str, optional) – 如果不使用标准缓存,则应缓存下载的模型配置的目录路径。(默认值: None)

  • local_files_only (bool, 可选) – 是否仅查看本地文件,不尝试下载模型。 (默认: False)

  • **model_kwargs – 在初始化期间传递给模型的额外关键字参数。

Return type:

Any

模型摘要

summary(model: Module, *args, max_depth: int = 3, leaf_module: Optional[Union[Module, List[Module]]] = 'MessagePassing', **kwargs) str[source]

总结给定的torch.nn.Module。 总结的信息包括(1)层名称,(2)输入和输出形状,以及(3)参数数量。

import torch
from torch_geometric.nn import GCN, summary

model = GCN(128, 64, num_layers=2, out_channels=32)
x = torch.randn(100, 128)
edge_index = torch.randint(100, size=(2, 20))

print(summary(model, x, edge_index))
+---------------------+---------------------+--------------+--------+
| Layer               | Input Shape         | Output Shape | #Param |
|---------------------+---------------------+--------------+--------|
| GCN                 | [100, 128], [2, 20] | [100, 32]    | 10,336 |
| ├─(act)ReLU         | [100, 64]           | [100, 64]    | --     |
| ├─(convs)ModuleList | --                  | --           | 10,336 |
| │    └─(0)GCNConv   | [100, 128], [2, 20] | [100, 64]    | 8,256  |
| │    └─(1)GCNConv   | [100, 64], [2, 20]  | [100, 32]    | 2,080  |
+---------------------+---------------------+--------------+--------+
Parameters:
  • model (torch.nn.Module) – 要总结的模型。

  • *argsmodel 的参数。

  • max_depth (int, optional) – 要显示的嵌套层深度。 任何超过此深度的层将不会显示在摘要中。(默认值:3

  • leaf_module (torch.nn.Module[torch.nn.Module], 可选) – 被视为叶子模块的模块,其子模块被排除在摘要之外。 (默认: MessagePassing)

  • **kwargsmodel 的额外参数。

Return type:

str