dgl.function

这个子包包含了DGL提供的所有内置函数。内置函数是DGL推荐的表达不同类型第2章:消息传递计算的方式(即通过update_all())或从节点特征计算边特征(即通过apply_edges())。内置函数以符号方式描述节点和边的计算,而不进行任何实际计算,因此DGL可以分析并将它们映射到高效的低级内核。以下是一些示例:

import dgl
import dgl.function as fn
import torch as th
g = ... # create a DGLGraph
g.ndata['h'] = th.randn((g.num_nodes(), 10)) # each node has feature size 10
g.edata['w'] = th.randn((g.num_edges(), 1))  # each edge has feature size 1
# collect features from source nodes and aggregate them in destination nodes
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
# multiply source node features with edge weights and aggregate them in destination nodes
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))
# compute edge embedding by multiplying source and destination node embeddings
g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))

fn.copy_u, fn.u_mul_e, fn.u_mul_v 是内置的消息函数,而 fn.sumfn.max 是内置的归约函数。DGL 的惯例是使用 u, ve 分别表示源节点、目标节点和边。 例如,copy_u 告诉 DGL 将源节点数据复制为消息; u_mul_e 告诉 DGL 将源节点特征与边特征相乘。

要定义一个一元消息函数(例如 copy_u),请指定一个输入特征名称和一个输出消息名称。要定义一个二元消息函数(例如 u_mul_e),请指定两个输入特征名称和一个输出消息名称。在计算过程中,消息函数将读取给定名称下的数据,执行计算,并使用输出名称返回输出。例如,上面的 fn.u_mul_e('h', 'w', 'm') 与以下用户定义的函数相同:

def udf_u_mul_e(edges):
   return {'m' : edges.src['h'] * edges.data['w']}

要定义一个reduce函数,需要指定一个输入消息名称和一个输出节点特征名称。例如,上面的fn.max('m', 'h_max')与以下用户定义的函数相同:

def udf_max(nodes):
   return {'h_max' : th.max(nodes.mailbox['m'], 1)[0]}

所有二进制消息函数都支持广播,这是一种将元素级操作扩展到具有不同形状的张量输入的机制。DGL通常遵循NumPyPyTorch的标准广播语义。以下是一些示例:

import dgl
import dgl.function as fn
import torch as th
g = ... # create a DGLGraph

# case 1
g.ndata['h'] = th.randn((g.num_nodes(), 10))
g.edata['w'] = th.randn((g.num_edges(), 1))
# OK, valid broadcasting between feature shapes (10,) and (1,)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
g.ndata['h_new']  # shape: (g.num_nodes(), 10)

# case 2
g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
g.edata['w'] = th.randn((g.num_edges(), 10))
# OK, valid broadcasting between feature shapes (5, 10) and (10,)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
g.ndata['h_new']  # shape: (g.num_nodes(), 5, 10)

# case 3
g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
g.edata['w'] = th.randn((g.num_edges(), 5))
# NOT OK, invalid broadcasting between feature shapes (5, 10) and (5,)
# shapes are aligned from right
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))

# case 3
g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10))
g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1))
# OK, valid broadcasting between feature shapes (1, 10) and (5, 1)
g.apply_edges(fn.u_add_v('h1', 'h2', 'x'))  # apply_edges also supports broadcasting
g.edata['x']  # shape: (g.num_edges(), 5, 10)

# case 4
g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10, 128))
g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1, 128))
# OK, u_dot_v supports broadcasting but requires the last dimension to match
g.apply_edges(fn.u_dot_v('h1', 'h2', 'x'))
g.edata['x']  # shape: (g.num_edges(), 5, 10, 1)

DGL 内置函数

以下是所有DGL内置函数的速查表。

类别

函数

备忘录

一元消息函数

copy_u

copy_e

二进制消息函数

u_add_v, u_sub_v, u_mul_v, u_div_v, u_dot_v

u_add_e, u_sub_e, u_mul_e, u_div_e, u_dot_e

v_add_u, v_sub_u, v_mul_u, v_div_u, v_dot_u

v_add_e, v_sub_e, v_mul_e, v_div_e, v_dot_e

e_add_u, e_sub_u, e_mul_u, e_div_u, e_dot_u

e_add_v, e_sub_v, e_mul_v, e_div_v, e_dot_v

Reduce 函数

max

min

sum

mean

消息函数

copy_u(u, out)

内置消息函数,使用源节点特征计算消息。

copy_e(e, out)

内置消息函数,使用边缘特征计算消息。

u_add_v(lhs_field, rhs_field, out)

内置消息函数,通过在u和v的特征之间执行元素级加法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到一个新的形状并执行元素级操作。

u_sub_v(lhs_field, rhs_field, out)

内置消息函数,如果特征具有相同的形状,则通过对u和v的特征执行逐元素减法来计算边上的消息;否则,它首先将特征广播到新形状并执行逐元素操作。

u_mul_v(lhs_field, rhs_field, out)

内置消息函数,通过执行u和v特征之间的元素乘法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到一个新的形状并执行元素操作。

u_div_v(lhs_field, rhs_field, out)

内置消息函数,通过执行u和v特征之间的元素除法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

u_add_e(lhs_field, rhs_field, out)

内置消息函数,通过在u和e的特征之间执行元素级加法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行元素级操作。

u_sub_e(lhs_field, rhs_field, out)

内置消息函数,通过执行u和e特征之间的元素减法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

u_mul_e(lhs_field, rhs_field, out)

内置消息函数,通过在u和e的特征之间执行元素乘法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

u_div_e(lhs_field, rhs_field, out)

内置消息函数,通过执行u和e特征之间的逐元素除法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行逐元素操作。

v_add_u(lhs_field, rhs_field, out)

内置消息函数,通过在v和u的特征之间执行元素级加法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行元素级操作。

v_sub_u(lhs_field, rhs_field, out)

内置消息函数,通过在v和u的特征之间执行元素减法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

v_mul_u(lhs_field, rhs_field, out)

内置消息函数,通过执行v和u特征之间的元素乘法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

v_div_u(lhs_field, rhs_field, out)

内置消息函数,通过在v和u的特征之间执行元素级除法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行元素级操作。

v_add_e(lhs_field, rhs_field, out)

内置消息函数,通过在v和e的特征之间执行元素级加法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到一个新的形状并执行元素级操作。

v_sub_e(lhs_field, rhs_field, out)

内置消息函数,通过在v和e的特征之间执行元素减法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

v_mul_e(lhs_field, rhs_field, out)

内置消息函数,通过执行v和e特征之间的元素乘法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

v_div_e(lhs_field, rhs_field, out)

内置消息函数,通过执行v和e特征之间的逐元素除法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行逐元素操作。

e_add_u(lhs_field, rhs_field, out)

内置消息函数,通过在e和u的特征之间执行元素级加法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行元素级操作。

e_sub_u(lhs_field, rhs_field, out)

内置消息函数,通过执行e和u特征之间的元素减法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

e_mul_u(lhs_field, rhs_field, out)

内置消息函数,通过执行e和u特征之间的元素乘法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

e_div_u(lhs_field, rhs_field, out)

内置消息函数,通过执行e和u特征之间的逐元素除法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行逐元素操作。

e_add_v(lhs_field, rhs_field, out)

内置消息函数,通过执行e和v特征之间的元素加法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

e_sub_v(lhs_field, rhs_field, out)

内置消息函数,通过执行e和v特征之间的元素减法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行元素操作。

e_mul_v(lhs_field, rhs_field, out)

内置消息函数,通过执行e和v特征之间的元素乘法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素操作。

e_div_v(lhs_field, rhs_field, out)

内置消息函数,通过执行e和v特征之间的逐元素除法来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行逐元素操作。

u_dot_v(lhs_field, rhs_field, out)

内置消息函数,通过执行u和v特征之间的逐元素点积来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行逐元素操作。

u_dot_e(lhs_field, rhs_field, out)

内置消息函数,通过执行u和e特征之间的逐元素点积来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行逐元素操作。

v_dot_e(lhs_field, rhs_field, out)

内置消息函数,通过在v和e的特征之间执行元素级点积来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行元素级操作。

v_dot_u(lhs_field, rhs_field, out)

内置消息函数,通过在v和u的特征之间执行元素级点积来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行元素级操作。

e_dot_u(lhs_field, rhs_field, out)

内置消息函数,通过执行e和u特征之间的逐元素点积来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新形状并执行逐元素操作。

e_dot_v(lhs_field, rhs_field, out)

内置消息函数,通过执行e和v特征之间的逐元素点积来计算边上的消息,如果特征具有相同的形状;否则,它首先将特征广播到新的形状并执行逐元素操作。

Reduce函数

sum(msg, out)

内置的reduce函数,通过求和来聚合消息。

max(msg, out)

内置的reduce函数,通过最大值聚合消息。

min(msg, out)

内置的reduce函数,用于通过最小值聚合消息。

mean(msg, out)

内置的reduce函数,通过平均值聚合消息。