dgl.DGLGraph.apply_edges
- DGLGraph.apply_edges(func, edges='__ALL__', etype=None)[source]
通过提供的函数更新指定边的特征。
- Parameters:
func (dgl.function.BuiltinFunction 或 可调用) – 用于生成新边特征的函数。它必须是 DGL 内置函数 或 用户自定义函数。
edges (edges) –
要更新特征的边。允许的输入格式有:
int
: 单个边ID。Int Tensor: 每个元素是一个边ID。张量必须具有与图相同的设备类型和ID数据类型。
iterable[int]: 每个元素是一个边ID。
(Tensor, Tensor): 节点张量格式,其中两个张量的第i个元素指定一条边。
(iterable[int], iterable[int]): 类似于节点张量格式,但将边端点存储在python可迭代对象中。
默认值指定图中的所有边。
etype (str or (str, str, str), optional) –
The type name of the edges. The allowed type name formats are:
(str, str, str)
for source node type, edge type and destination node type.or one
str
edge type name if the name can uniquely identify a triplet format in the graph.
Can be omitted if the graph has only one type of edges.
注释
DGL 建议使用 DGL 的内置函数作为
func
参数, 因为在这种情况下,DGL 将调用高效的内核,避免将节点特征复制到边特征。示例
以下示例使用PyTorch后端。
>>> import dgl >>> import torch
同构图
>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4])) >>> g.ndata['h'] = torch.ones(5, 2) >>> g.apply_edges(lambda edges: {'x' : edges.src['h'] + edges.dst['h']}) >>> g.edata['x'] tensor([[2., 2.], [2., 2.], [2., 2.], [2., 2.]])
使用内置函数
>>> import dgl.function as fn >>> g.apply_edges(fn.u_add_v('h', 'h', 'x')) >>> g.edata['x'] tensor([[2., 2.], [2., 2.], [2., 2.], [2., 2.]])
异构图
>>> g = dgl.heterograph({('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1])}) >>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5) >>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2}) >>> g.edges[('user', 'plays', 'game')].data['h'] tensor([[2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.]])
另请参阅