torch_geometric.nn.models.MetaLayer

class MetaLayer(edge_model: Optional[Module] = None, node_model: Optional[Module] = None, global_model: Optional[Module] = None)[source]

Bases: Module

一个用于构建任何类型图网络的元层,灵感来自于 “关系归纳偏差、深度学习与图网络”论文。

图网络将图作为输入并返回更新后的图作为输出(具有相同的连接性)。 输入图具有节点特征 x、边特征 edge_attr 以及图级特征 u。 输出图具有相同的结构,但更新了特征。

通过调用模块 edge_modelnode_modelglobal_model,分别更新边特征、节点特征以及全局特征。

为了允许批量图处理,所有可调用函数都接受一个额外的参数 batch,它决定了边或节点到其特定图的分配。

Parameters:
  • edge_model (torch.nn.Module, 可选) – 一个可调用对象,用于根据其源节点和目标节点的特征、当前边的特征以及全局特征来更新图的边特征。 (默认: None)

  • node_model (torch.nn.Module, optional) – 一个可调用对象,用于根据当前节点特征、图连接性、边特征和全局特征更新图的节点特征。 (默认: None)

  • global_model (torch.nn.Module, optional) – 一个可调用对象,用于根据图的节点特征、图的连接性、边的特征和当前的全局特征来更新图的全局特征。 (默认: None)

from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.utils import scatter
from torch_geometric.nn import MetaLayer

class EdgeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, src, dst, edge_attr, u, batch):
        # src, dst: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dst, edge_attr, u[batch]], 1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
        self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter(out, col, dim=0, dim_size=x.size(0),
                      reduce='mean')
        out = torch.cat([x, out, u[batch]], dim=1)
        return self.node_mlp_2(out)

class GlobalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        out = torch.cat([
            u,
            scatter(x, batch, dim=0, reduce='mean'),
        ], dim=1)
        return self.global_mlp(out)

op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
forward(x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None, u: Optional[Tensor] = None, batch: Optional[Tensor] = None) Tuple[Tensor, Optional[Tensor], Optional[Tensor]][source]

前向传播。

Parameters:
  • x (torch.Tensor) – The node features.

  • edge_index (torch.Tensor) – 边的索引。

  • edge_attr (torch.Tensor, optional) – 边的特征。 (default: None)

  • u (torch.Tensor, optional) – 全局图特征。 (默认: None)

  • batch (torch.Tensor, optional) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个节点分配到一个特定的图中。(默认: None)

Return type:

Tuple[Tensor, Optional[Tensor], Optional[Tensor]]

reset_parameters()[source]

重置模块的所有可学习参数。