用户自定义函数

用户定义的函数(UDFs)允许在消息传递中进行任意计算(参见第2章:消息传递),并使用apply_edges()更新边特征。当dgl.function无法实现所需的计算时,它们提供了更多的灵活性。

边级别的用户自定义函数

可以在消息传递中使用边缘用户定义的函数作为消息函数,或者在apply_edges()中应用一个函数。它接受一批边作为输入,并返回每条边的消息(在消息传递中)或特征(在apply_edges()中)。该函数可以在计算中结合边及其端节点的特征。

形式上,它采用以下形式

def edge_udf(edges):
    """
    Parameters
    ----------
    edges : EdgeBatch
        A batch of edges.

    Returns
    -------
    dict[str, tensor]
        The messages or edge features generated. It maps a message/feature name to the
        corresponding messages/features of all edges in the batch. The order of the
        messages/features is the same as the order of the edges in the input argument.
    """

DGL 在内部生成 EdgeBatch 实例,这些实例暴露了以下接口用于定义 edge_udf

EdgeBatch.src

返回批次中边的源节点特征的视图。

EdgeBatch.dst

返回批次中边的目标节点特征的视图。

EdgeBatch.data

返回批次中边的特征的视图。

EdgeBatch.edges()

返回批次中的边。

EdgeBatch.batch_size()

返回批次中的边数。

节点级别的用户自定义函数

可以使用节点级别的用户定义函数作为消息传递中的归约函数。它以一批节点作为输入,并返回每个节点的更新特征。它可以结合当前节点特征和节点接收到的消息。形式上,它采用以下形式

def node_udf(nodes):
    """
    Parameters
    ----------
    nodes : NodeBatch
        A batch of nodes.

    Returns
    -------
    dict[str, tensor]
        The updated node features. It maps a feature name to the corresponding features of
        all nodes in the batch. The order of the nodes is the same as the order of the nodes
        in the input argument.
    """

DGL 在内部生成 NodeBatch 实例,这些实例暴露了以下接口用于定义 node_udf

NodeBatch.data

返回批次中节点的节点特征的视图。

NodeBatch.mailbox

返回接收到的消息的视图。

NodeBatch.nodes()

返回批次中的节点。

NodeBatch.batch_size()

返回批次中的节点数量。

使用用户定义函数进行消息传递的度数分桶

DGL 使用一种基于度数的分桶机制来进行带有用户定义函数(UDFs)的消息传递。它将具有相同入度的节点分组,并为每组节点调用消息传递。因此,不应假设 NodeBatch 实例的批量大小。

对于一批节点,DGL将每个节点的传入消息沿着第二维度堆叠,按边ID排序。示例如下:

>>> import dgl
>>> import torch
>>> import dgl.function as fn
>>> g = dgl.graph(([1, 3, 5, 0, 4, 2, 3, 3, 4, 5], [1, 1, 0, 0, 1, 2, 2, 0, 3, 3]))
>>> g.edata['eid'] = torch.arange(10)
>>> def reducer(nodes):
...     print(nodes.mailbox['eid'])
...     return {'n': nodes.mailbox['eid'].sum(1)}
>>> g.update_all(fn.copy_e('eid', 'eid'), reducer)
tensor([[5, 6],
        [8, 9]])
tensor([[3, 7, 2],
        [0, 1, 4]])

本质上,节点#2和节点#3被分组到一个入度为2的桶中,而节点#0和节点#1被分组到一个入度为3的桶中。在每个桶内,边按每个节点的边ID排序。