dgl.udf.NodeBatch.nodes
- NodeBatch.nodes()
Return the nodes in the batch.
- Returns:
NID – The IDs of the nodes in the batch. \(NID[i]\) gives the ID of the i-th node.
- Return type:
Tensor
Examples
The following example uses the PyTorch backend.
>>> import dgl
>>> import torch

>>> # Instantiate a graph and set a feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1)

>>> # Define a UDF that computes the sum of the messages received and
>>> # the original ID for each node.
>>> def node_udf(nodes):
...     # nodes.nodes() is a tensor of shape (N),
...     # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
...     # where N is the number of nodes in the batch, D is the number
...     # of messages received per node for this node batch.
...     return {'h': nodes.nodes().unsqueeze(-1).float()
...             + nodes.mailbox['m'].sum(1)}

>>> # Use node UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h']
tensor([[1.],
        [3.]])
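To make the arithmetic behind this output explicit, here is a hypothetical pure-Python model of what `update_all` computes in the example (no DGL required). It assumes, as the comments in `node_udf` imply, that destination nodes are grouped by in-degree D so every batch has a mailbox of uniform shape. The function `simulate_update_all` is illustrative only and is not part of DGL.

```python
from collections import defaultdict


def simulate_update_all(src, dst, h):
    """Hypothetical model of g.update_all(fn.copy_u('h', 'm'), node_udf).

    src, dst: lists of edge endpoints; h: list of scalar node features.
    Returns (new features of nodes that received messages, degree buckets).
    """
    # Message phase (mimics copy_u): each edge copies its source feature
    # into the mailbox of its destination node.
    mailbox = defaultdict(list)
    for u, v in zip(src, dst):
        mailbox[v].append(h[u])

    # Bucketing (assumed): group destination nodes by in-degree D so
    # that each batch's mailbox has a uniform (N, D) shape.
    buckets = defaultdict(list)
    for v, msgs in mailbox.items():
        buckets[len(msgs)].append(v)

    new_h = {}
    for degree, batch_nodes in sorted(buckets.items()):
        # batch_nodes plays the role of nodes.nodes() for this batch;
        # the reduce step adds each node's own ID to its message sum.
        for nid in batch_nodes:
            new_h[nid] = float(nid) + sum(mailbox[nid])
    return new_h, dict(buckets)


new_h, buckets = simulate_update_all([0, 1, 1], [1, 1, 0], [1.0, 1.0])
print(sorted(buckets.items()))
print(new_h)
```

Node 0 receives one message (from edge 1→0) and node 1 receives two (from 0→1 and the self-loop 1→1), so under this model they fall into separate batches: `nodes.nodes()` would be `[0]` in one call and `[1]` in the other, giving 0 + 1 = 1 and 1 + 2 = 3.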