项目集
- class dgl.graphbolt.ItemSet(items: int | Tensor | Tuple[Tensor], names: str | Tuple[str] | None = None)[source]
Bases:
object
张量或张量元组的包装器。
- Parameters:
items (Union[int, torch.Tensor, Tuple[torch.Tensor]]) –
要包装的张量。 - 如果它是一个单一的标量(一个整数或一个包含单个值的张量),该值将被视为由torch.arange创建的range_tensor。
值),该值将被视为由torch.arange创建的range_tensor。
如果它是一个多维张量,索引将沿着第一个维度进行。
如果它是一个元组,元组中的每个项必须是一个张量。
names (Union[str, Tuple[str]], optional) – 项目的名称。如果是一个元组,每个名称必须对应于items参数中的一个项目。命名是任意的,但在一般实践中,名称应从[‘labels’, ‘seeds’, ‘indexes’]中选择,以与类dgl.graphbolt.MiniBatch的属性对齐。
示例
>>> import torch >>> from dgl import graphbolt as gb
整数:节点数量。
>>> num = 10 >>> item_set = gb.ItemSet(num, names="seeds") >>> list(item_set) [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)] >>> item_set[:] tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) >>> item_set.names ('seeds',)
Torch标量:节点数量。与Integer相比,可自定义dtype。
>>> num = torch.tensor(10, dtype=torch.int32) >>> item_set = gb.ItemSet(num, names="seeds") >>> list(item_set) [tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32), tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32), tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32), tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32), tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)] >>> item_set[:] tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32) >>> item_set.names ('seeds',)
单个张量:种子节点。
>>> node_ids = torch.arange(0, 5) >>> item_set = gb.ItemSet(node_ids, names="seeds") >>> list(item_set) [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)] >>> item_set[:] tensor([0, 1, 2, 3, 4]) >>> item_set.names ('seeds',)
具有相同形状的张量元组:种子节点和标签。
>>> node_ids = torch.arange(0, 5) >>> labels = torch.arange(5, 10) >>> item_set = gb.ItemSet( ... (node_ids, labels), names=("seeds", "labels")) >>> list(item_set) [(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)), (tensor(3), tensor(8)), (tensor(4), tensor(9))] >>> item_set[:] (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])) >>> item_set.names ('seeds', 'labels')
具有不同形状的张量元组:种子和标签。
>>> seeds = torch.arange(0, 10).reshape(-1, 2) >>> labels = torch.tensor([1, 1, 0, 0, 0]) >>> item_set = gb.ItemSet( ... (seeds, labels), names=("seeds", "lables")) >>> list(item_set) [(tensor([0, 1]), tensor([1])), (tensor([2, 3]), tensor([1])), (tensor([4, 5]), tensor([0])), (tensor([6, 7]), tensor([0])), (tensor([8, 9]), tensor([0]))] >>> item_set[:] (tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]), tensor([1, 1, 0, 0, 0])) >>> item_set.names ('seeds', 'labels')
具有不同形状的张量元组:超链接和标签。
>>> seeds = torch.arange(0, 10).reshape(-1, 5) >>> labels = torch.tensor([1, 0]) >>> item_set = gb.ItemSet( ... (seeds, labels), names=("seeds", "lables")) >>> list(item_set) [(tensor([0, 1, 2, 3, 4]), tensor([1])), (tensor([5, 6, 7, 8, 9]), tensor([0]))] >>> item_set[:] (tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), tensor([1, 0])) >>> item_set.names ('seeds', 'labels')