torch_geometric.data.Batch
- class Batch(*args: Any, **kwargs: Any)[source]
基础类:
object描述一批图作为一个大(不连通)图的数据对象。 继承自
torch_geometric.data.Data或torch_geometric.data.HeteroData。 此外,单个图可以通过分配向量batch来识别,该向量将每个节点映射到其相应的图标识符。PyG 允许通过重写
__inc__()和__cat_dim__()功能来修改底层批处理过程。__inc__()方法定义了两个连续图属性之间的增量计数。 默认情况下,PyG 在属性名称包含子字符串index时(由于历史原因),会按节点数增加属性,这对于诸如edge_index或node_index等属性非常有用。 然而,请注意,这可能会导致名称包含子字符串index但不应该增加的属性出现意外行为。 为了确保,最佳实践是始终仔细检查批处理的输出。 此外,__cat_dim__()定义了相同属性的图张量应在哪个维度上连接在一起。- classmethod from_data_list(data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None) Self[source]
从
Data或HeteroData对象的列表中构造一个Batch对象。 分配向量batch是动态创建的。 此外,为follow_batch中的每个键创建分配向量。 将排除exclude_keys中给出的任何键。- Return type:
Self
- get_example(idx: int) BaseData[source]
获取索引
idx处的Data或HeteroData对象。Batch对象必须通过from_data_list()创建,以便能够重建初始对象。- Return type:
BaseData
- index_select(idx: Union[slice, Tensor, ndarray, Sequence]) List[BaseData][source]
从指定的索引
idx创建Data或HeteroData对象的子集。 索引idx可以是一个切片对象,例如,[2:5],一个 列表,一个元组,或者一个torch.Tensor或np.ndarray类型 为 long 或 bool。Batch对象必须通过from_data_list()创建,以便能够重建 初始对象。- Return type:
List[BaseData]
- to_data_list() List[BaseData][source]
从
Batch对象中重建Data或HeteroData对象的列表。Batch对象必须通过from_data_list()创建,以便能够重建初始对象。- Return type:
List[BaseData]
- property batch_size: int
num_graphs的别名。- Return type: