torch_geometric.Index
- class Index(data: Any, *args: Any, dim_size: Optional[int] = None, is_sorted: bool = False, **kwargs: Any)[source]
Bases:
Tensor一个一维的
index张量,附加了额外的(元)数据。Index是一个torch.Tensor,它保存了形状为[num_indices]的索引。虽然
Index是torch.Tensor的一个子类,但它可以保存额外的(元)数据,即:此外,
Index通过indptr缓存数据,以便在其表示已排序的情况下快速进行 CSR 转换。 缓存是根据需求填充的(例如,当调用Index.get_indptr()时),或者通过Index.fill_cache_()明确请求时,并且在其生命周期内进行维护和调整。This representation ensures optimal computation in GNN message passing schemes, while preserving the ease-of-use of regular COO-based PyG workflows.
from torch_geometric import Index index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True) assert index.dim_size == 3 assert index.is_sorted # Flipping order: index.flip(0) >>> Index([[2, 1, 1, 0], dim_size=3) assert not index.is_sorted # Filtering: mask = torch.tensor([True, True, True, False]) index[:, mask] >>> Index([[0, 1, 1], dim_size=3, is_sorted=True) assert index.is_sorted