项目集

class dgl.graphbolt.ItemSet(items: int | Tensor | Tuple[Tensor], names: str | Tuple[str] | None = None)[source]

Bases: object

张量或张量元组的包装器。

Parameters:
  • items (Union[int, torch.Tensor, Tuple[torch.Tensor]]) –

    要包装的张量。 - 如果它是一个单一的标量(一个整数或一个包含单个值的张量),该值将被视为由torch.arange创建的range_tensor。

    值),该值将被视为由torch.arange创建的range_tensor。

    • 如果它是一个多维张量,索引将沿着第一个维度进行。

    • 如果它是一个元组,元组中的每个项必须是一个张量。

  • names (Union[str, Tuple[str]], optional) – 项目的名称。如果是一个元组,每个名称必须对应于items参数中的一个项目。命名是任意的,但在一般实践中,名称应从[‘labels’, ‘seeds’, ‘indexes’]中选择,以与类dgl.graphbolt.MiniBatch的属性对齐。

示例

>>> import torch
>>> from dgl import graphbolt as gb
  1. 整数:节点数量。

>>> num = 10
>>> item_set = gb.ItemSet(num, names="seeds")
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
 tensor(6), tensor(7), tensor(8), tensor(9)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> item_set.names
('seeds',)
  1. Torch标量:节点数量。与Integer相比,可自定义dtype。

>>> num = torch.tensor(10, dtype=torch.int32)
>>> item_set = gb.ItemSet(num, names="seeds")
>>> list(item_set)
[tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32),
 tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32),
 tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32),
 tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32),
 tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)
>>> item_set.names
('seeds',)
  1. 单个张量:种子节点。

>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids, names="seeds")
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4])
>>> item_set.names
('seeds',)
  1. 具有相同形状的张量元组:种子节点和标签。

>>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10)
>>> item_set = gb.ItemSet(
...     (node_ids, labels), names=("seeds", "labels"))
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
 (tensor(3), tensor(8)), (tensor(4), tensor(9))]
>>> item_set[:]
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
>>> item_set.names
('seeds', 'labels')
  1. 具有不同形状的张量元组:种子和标签。

>>> seeds = torch.arange(0, 10).reshape(-1, 2)
>>> labels = torch.tensor([1, 1, 0, 0, 0])
>>> item_set = gb.ItemSet(
...     (seeds, labels), names=("seeds", "lables"))
>>> list(item_set)
[(tensor([0, 1]), tensor([1])),
 (tensor([2, 3]), tensor([1])),
 (tensor([4, 5]), tensor([0])),
 (tensor([6, 7]), tensor([0])),
 (tensor([8, 9]), tensor([0]))]
>>> item_set[:]
(tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]),
 tensor([1, 1, 0, 0, 0]))
>>> item_set.names
('seeds', 'labels')
  1. 具有不同形状的张量元组:超链接和标签。

>>> seeds = torch.arange(0, 10).reshape(-1, 5)
>>> labels = torch.tensor([1, 0])
>>> item_set = gb.ItemSet(
...     (seeds, labels), names=("seeds", "lables"))
>>> list(item_set)
[(tensor([0, 1, 2, 3, 4]), tensor([1])),
 (tensor([5, 6, 7, 8, 9]), tensor([0]))]
>>> item_set[:]
(tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
 tensor([1, 0]))
>>> item_set.names
('seeds', 'labels')
property names: Tuple[str]

返回项目的名称。

property num_items: int

返回项目的数量。