torch_geometric.data.Batch

class Batch(*args: Any, **kwargs: Any)[source]

基础类:object

描述一批图作为一个大(不连通)图的数据对象。 继承自 torch_geometric.data.Datatorch_geometric.data.HeteroData。 此外,单个图可以通过分配向量 batch 来识别,该向量将每个节点映射到其相应的图标识符。

允许通过重写 __inc__()__cat_dim__() 功能来修改底层批处理过程。 __inc__() 方法定义了两个连续图属性之间的增量计数。 默认情况下, 在属性名称包含子字符串 index 时(由于历史原因),会按节点数增加属性,这对于诸如 edge_indexnode_index 等属性非常有用。 然而,请注意,这可能会导致名称包含子字符串 index 但不应该增加的属性出现意外行为。 为了确保,最佳实践是始终仔细检查批处理的输出。 此外,__cat_dim__() 定义了相同属性的图张量应在哪个维度上连接在一起。

classmethod from_data_list(data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None) Self[source]

DataHeteroData对象的列表中构造一个Batch对象。 分配向量batch是动态创建的。 此外,为follow_batch中的每个键创建分配向量。 将排除exclude_keys中给出的任何键。

Return type:

Self

get_example(idx: int) BaseData[source]

获取索引 idx 处的 DataHeteroData 对象。 Batch 对象必须通过 from_data_list() 创建,以便能够重建初始对象。

Return type:

BaseData

index_select(idx: Union[slice, Tensor, ndarray, Sequence]) List[BaseData][source]

从指定的索引 idx 创建 DataHeteroData 对象的子集。 索引 idx 可以是一个切片对象,例如[2:5],一个 列表,一个元组,或者一个 torch.Tensornp.ndarray 类型 为 long 或 bool。 Batch 对象必须通过 from_data_list() 创建,以便能够重建 初始对象。

Return type:

List[BaseData]

to_data_list() List[BaseData][source]

Batch对象中重建DataHeteroData对象的列表。 Batch对象必须通过from_data_list()创建,以便能够重建初始对象。

Return type:

List[BaseData]

property num_graphs: int

返回批次中的图的数量。

Return type:

int

property batch_size: int

num_graphs 的别名。

Return type:

int