dgl.graphbolt.compact_csc_format
- dgl.graphbolt.compact_csc_format(csc_formats: CSCFormatBase | Dict[str, CSCFormatBase], dst_nodes: Tensor | Dict[str, Tensor], dst_timestamps: Tensor | Dict[str, Tensor] | None = None)[source]
将csc格式中的行(源)ID重新标记为从0开始的连续范围,并返回每种类型的原始行节点ID。
请注意 1. 列(目的地)ID包含在重新标记的行ID中。 2. 如果有重复的行ID,它们不会被唯一化,并将被视为不同的节点。 3. 如果给出了dst_timestamps,每个目的地节点的时间戳将广播到其对应的源节点。
- Parameters:
csc_formats (Union[CSCFormatBase, Dict[str, CSCFormatBase]]) – 表示源-目标边的CSC格式。 - 如果 csc_formats 是 CSCFormatBase:表示图是同质的。此外,其中的 indptr 和 indice 应该是 torch.tensor,表示 csc 格式中的源和目标对。并且其中的 ID 是同质 ID。 - 如果 csc_formats 是 Dict[str, CSCFormatBase]:键应该是边类型,值应该是 csc 格式的节点对。并且其中的 ID 是异质 ID。
dst_nodes (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 节点对中所有目标节点的节点。 - 如果 dst_nodes 是一个张量:表示图是同质的。 - 如果 dst_nodes 是一个字典:键是节点类型,值是对应的节点。内部的ID是异质ID。
dst_timestamps (可选[联合[torch.Tensor, 字典[str, torch.Tensor]]]) – 所有目标节点在csc格式中的时间戳。 如果提供,每个目标节点的时间戳将广播到其对应的源节点。
- Returns:
输入中所有节点的原始行节点ID(按类型)的张量。 压缩的CSC格式,其中节点ID被替换为从0到N的映射节点ID。 如果给出了dst_timestamps,则输入中所有节点的源时间戳(按类型)。
- Return type:
元组[原始行节点ID,压缩的CSC格式,…]
示例
>>> import dgl.graphbolt as gb >>> csc_formats = { ... "n2:e2:n1": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6]) ... ), ... "n1:e1:n1": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3]) ... ), ... } >>> dst_nodes = {"n1": torch.LongTensor([2, 4])} >>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format( ... csc_formats, dst_nodes ... ) >>> original_row_node_ids {'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])} >>> compacted_csc_formats {'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]), indices=tensor([0, 1, 2]), ), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]), indices=tensor([2, 3, 4]), )}
>>> csc_formats = { ... "n2:e2:n1": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6]) ... ), ... "n1:e1:n1": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3]) ... ), ... } >>> dst_nodes = {"n1": torch.LongTensor([2, 4])} >>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format( ... csc_formats, dst_nodes ... ) >>> original_row_node_ids {'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])} >>> compacted_csc_formats {'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]), indices=tensor([0, 1, 2]), ), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]), indices=tensor([2, 3, 4]), )}
>>> dst_timestamps = {"n1": torch.LongTensor([10, 20])} >>> ( ... original_row_node_ids, ... compacted_csc_formats, ... src_timestamps, ... ) = gb.compact_csc_format(csc_formats, dst_nodes, dst_timestamps) >>> src_timestamps {'n1': tensor([10, 20, 10, 20, 20]), 'n2': tensor([10, 20, 20])}