torch_geometric.data.Database

class Database(schema: Union[Any, Dict[str, Any], Tuple[Any], List[Any]] = object)

Bases: ABC

Base class for inserting and retrieving data from a database.

A database acts as a persisted, out-of-memory, and index-based key/value store for tensor and custom data:

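import torch
from torch_geometric.data import Data

db = Database()  # Abstract: use a concrete subclass in practice.
db[0] = Data(x=torch.randn(5, 16), y=0, z='id_0')
print(db[0])
>>> Data(x=[5, 16], y=0, z='id_0')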
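Assuming the indexing syntax is sugar over the insert() and get() methods documented below, the example can be mirrored in plain Python. The sketch uses hypothetical classes (ToyDatabase, DictDatabase) that are not part of PyG; its dict-backed store is neither persistent nor out-of-memory and only illustrates how a subclass fills in the abstract methods.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class ToyDatabase(ABC):
    """Hypothetical stand-in mirroring the interface described here."""

    def __init__(self) -> None:
        self.connect()  # Databases connect automatically on instantiation.

    def connect(self) -> None:
        pass

    def close(self) -> None:
        pass

    @abstractmethod
    def insert(self, index: int, data: Any) -> None: ...

    @abstractmethod
    def get(self, index: int) -> Any: ...

    # Indexing syntax dispatches to insert() and get().
    def __setitem__(self, index: int, data: Any) -> None:
        self.insert(index, data)

    def __getitem__(self, index: int) -> Any:
        return self.get(index)


class DictDatabase(ToyDatabase):
    """Keeps everything in a Python dict (illustration only, not persisted)."""

    def connect(self) -> None:
        self._store: Dict[int, Any] = {}

    def insert(self, index: int, data: Any) -> None:
        self._store[index] = data

    def get(self, index: int) -> Any:
        return self._store[index]


db = DictDatabase()
db[0] = {'y': 0, 'z': 'id_0'}
print(db[0])  # {'y': 0, 'z': 'id_0'}
```

Because ToyDatabase leaves insert() and get() abstract, calling ToyDatabase() directly raises a TypeError, just as the abstract base class does.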

For improved efficiency, it is recommended to specify the underlying schema of the data:

db = Database(schema={  # Custom schema:
    # Tensor information can be specified through a dictionary:
    'x': dict(dtype=torch.float, size=(-1, 16)),
    'y': int,
    'z': str,
})
db[0] = dict(x=torch.randn(5, 16), y=0, z='id_0')
print(db[0])
>>> {'x': torch.tensor(...), 'y': 0, 'z': 'id_0'}

In addition, databases support batch-wise insert and get, and support syntactic sugar known from indexing Python lists, e.g.:

db = Database()
db[2:5] = torch.randn(3, 16)
print(db[torch.tensor([2, 3])])
>>> [torch.tensor(...), torch.tensor(...)]
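The batch syntax presumably resolves each supported index type (slice, range, list, or tensor) into a list of integer keys before dispatching to multi_insert() or multi_get(); in the example, db[2:5] would pair keys 2, 3, and 4 with the three rows of the tensor. Below is a minimal sketch of that normalization step using a hypothetical helper, to_index_list(). It rejects open-ended slices, since resolving them would require knowing the database length.

```python
from typing import List, Sequence, Union


def to_index_list(indices: Union[Sequence[int], slice, range]) -> List[int]:
    """Hypothetical helper: normalize the index types accepted by
    multi_insert()/multi_get() into a plain list of ints."""
    if isinstance(indices, slice):
        if indices.stop is None:
            raise ValueError("open-ended slices need an explicit stop")
        return list(range(indices.start or 0, indices.stop, indices.step or 1))
    # range, list, tuple, and 1-D tensors all support iteration;
    # int() also unwraps the 0-dim elements yielded by a tensor.
    return [int(i) for i in indices]


print(to_index_list(slice(2, 5)))     # [2, 3, 4]
print(to_index_list(range(0, 6, 2)))  # [0, 2, 4]
print(to_index_list([7, 1]))          # [7, 1]
```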

Parameters:

schema (Any or Tuple[Any] or Dict[str, Any], optional) – The schema of the input data. Can take int, float, str, object, or a dictionary with dtype and size keys (for specifying tensor data) as input, and can be nested as a tuple or dictionary. Specifying the schema will improve efficiency, since by default the database will use Python pickling for serializing and deserializing. (default: object)
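One way a schema could avoid pickling is to map each declared type to a fixed binary encoding and to fall back to pickle only for object. The sketch below assumes that design; encode() and decode() are hypothetical helpers, not PyG's serializer, and tensor specs (dictionaries with dtype and size keys) are deliberately rejected to keep the example free of a torch dependency.

```python
import pickle
import struct
from typing import Any


def _reject_tensor_spec(schema: Any) -> None:
    # Tensor specs like dict(dtype=..., size=...) are out of scope here.
    if isinstance(schema, dict) and set(schema) == {'dtype', 'size'}:
        raise NotImplementedError('tensor specs are omitted in this sketch')


def encode(value: Any, schema: Any) -> Any:
    """Hypothetical schema-driven encoder; schema=object falls back to pickle."""
    _reject_tensor_spec(schema)
    if schema is int:
        return struct.pack('<q', value)  # fixed 8-byte signed integer
    if schema is float:
        return struct.pack('<d', value)  # fixed 8-byte double
    if schema is str:
        return value.encode('utf-8')
    if isinstance(schema, dict):  # nested dictionary schema
        return {key: encode(value[key], sub) for key, sub in schema.items()}
    if isinstance(schema, tuple):  # nested tuple schema
        return tuple(encode(v, s) for v, s in zip(value, schema))
    return pickle.dumps(value)  # generic pickling for object


def decode(raw: Any, schema: Any) -> Any:
    """Inverse of encode() under the same schema."""
    _reject_tensor_spec(schema)
    if schema is int:
        return struct.unpack('<q', raw)[0]
    if schema is float:
        return struct.unpack('<d', raw)[0]
    if schema is str:
        return raw.decode('utf-8')
    if isinstance(schema, dict):
        return {key: decode(raw[key], sub) for key, sub in schema.items()}
    if isinstance(schema, tuple):
        return tuple(decode(r, s) for r, s in zip(raw, schema))
    return pickle.loads(raw)


schema = {'y': int, 'z': str, 'meta': object}
row = {'y': 0, 'z': 'id_0', 'meta': [1, 2]}
assert decode(encode(row, schema), schema) == row
```

Typed fields carry no per-value type metadata, which is the kind of overhead that generic pickling cannot skip.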

connect() -> None

Connects to the database. Databases will automatically connect on instantiation.

Return type:

None

close() -> None

Closes the connection to the database.

Return type:

None

abstract insert(index: int, data: Any) -> None

Inserts data at the specified index.

Parameters:
  • index (int) – The index at which to insert.

  • data (Any) – The object to insert.

Return type:

None

multi_insert(indices: Union[Sequence[int], Tensor, slice, range], data_list: Sequence[Any], batch_size: Optional[int] = None, log: bool = False) -> None

Inserts a chunk of data at the specified indices.

Parameters:
  • indices (List[int] or torch.Tensor or range) – The indices at which to insert.

  • data_list (List[Any]) – The objects to insert.

  • batch_size (int, optional) – If specified, will insert the data to the database in batches of size batch_size. (default: None)

  • log (bool, optional) – If set to True, will log progress to the console. (default: False)

Return type:

None
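The batch_size and log options can be pictured as a loop that slices the indices and objects into equally sized chunks and hands each chunk to a bulk write. In this sketch, multi_insert_batched() and its insert_batch callback are hypothetical stand-ins for whatever a concrete backend provides, and batch_size=None is assumed to mean a single batch.

```python
from typing import Any, Callable, List, Optional, Sequence


def multi_insert_batched(
    insert_batch: Callable[[List[int], List[Any]], None],
    indices: Sequence[int],
    data_list: Sequence[Any],
    batch_size: Optional[int] = None,
    log: bool = False,
) -> None:
    """Hypothetical sketch of the batch_size/log semantics."""
    indices, data_list = list(indices), list(data_list)
    n = len(indices)
    if n != len(data_list):
        raise ValueError("indices and data_list must have the same length")
    if batch_size is None:
        batch_size = max(n, 1)  # None: one batch containing everything
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        insert_batch(indices[start:end], data_list[start:end])
        if log:
            print(f"Inserted {end}/{n} items")


store = {}
calls = []


def insert_batch(idx: List[int], items: List[Any]) -> None:
    calls.append(len(idx))  # record the size of every bulk write
    store.update(zip(idx, items))


multi_insert_batched(insert_batch, range(5), list('abcde'), batch_size=2, log=True)
# Prints "Inserted 2/5 items", "Inserted 4/5 items", "Inserted 5/5 items"
```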

abstract get(index: int) -> Any

Gets data from the specified index.

Parameters:

index (int) – The index to query.

Return type:

Any
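A concrete subclass has to supply insert() and get(), and typically connect() and close() as well. The following sketch shows one way to back those four methods with Python's built-in sqlite3 module, pickling every object as with the default schema=object. PickleSQLiteStore is a hypothetical name and is not one of the backends shipped with PyG; because the table name is interpolated into SQL, it must be a trusted identifier.

```python
import pickle
import sqlite3
from typing import Any, Optional


class PickleSQLiteStore:
    """Hypothetical backend: connect()/close() plus insert()/get() on int keys."""

    def __init__(self, path: str = ':memory:', name: str = 'data') -> None:
        self.path, self.name = path, name
        self._conn: Optional[sqlite3.Connection] = None
        self.connect()  # Connect automatically on instantiation.

    def connect(self) -> None:
        self._conn = sqlite3.connect(self.path)
        self._conn.execute(
            f'CREATE TABLE IF NOT EXISTS {self.name} '
            '(id INTEGER PRIMARY KEY, value BLOB NOT NULL)')

    def close(self) -> None:
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    def insert(self, index: int, data: Any) -> None:
        # schema=object behavior: serialize arbitrary objects with pickle.
        self._conn.execute(
            f'INSERT OR REPLACE INTO {self.name} (id, value) VALUES (?, ?)',
            (index, pickle.dumps(data)))
        self._conn.commit()

    def get(self, index: int) -> Any:
        row = self._conn.execute(
            f'SELECT value FROM {self.name} WHERE id = ?', (index,)).fetchone()
        if row is None:
            raise KeyError(index)
        return pickle.loads(row[0])


db = PickleSQLiteStore()
db.insert(0, {'y': 0, 'z': 'id_0'})
print(db.get(0))  # {'y': 0, 'z': 'id_0'}
db.close()
```

Passing a file path instead of ':memory:' would make the store persistent across runs.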

multi_get(indices: Union[Sequence[int], Tensor, slice, range], batch_size: Optional[int] = None) -> List[Any]

Gets a chunk of data from the specified indices.

Parameters:
  • indices (List[int] or torch.Tensor or range) – The indices to query.

  • batch_size (int, optional) – If specified, will request the data from the database in batches of size batch_size. (default: None)

Return type:

List[Any]
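For backends that can answer several keys per query, batching multi_get() plausibly reduces round trips while bounding query size (SQLite, for instance, caps the number of bound parameters per statement). The hypothetical helper multi_get_sqlite() below assumes a table with an id column and a pickled value column. Because an IN (...) query does not guarantee row order, it re-orders the results to match the requested indices.

```python
import pickle
import sqlite3
from typing import Any, List, Optional, Sequence


def multi_get_sqlite(conn: sqlite3.Connection, table: str,
                     indices: Sequence[int],
                     batch_size: Optional[int] = None) -> List[Any]:
    """Hypothetical batched lookup returning objects in request order."""
    indices = [int(i) for i in indices]
    if batch_size is None:
        batch_size = max(len(indices), 1)  # None: a single query
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    out: List[Any] = []
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        marks = ', '.join('?' * len(chunk))  # e.g. "?, ?" for two keys
        rows = conn.execute(
            f'SELECT id, value FROM {table} WHERE id IN ({marks})', chunk)
        found = {row_id: blob for row_id, blob in rows}
        # Re-order by request; a missing key raises KeyError.
        out.extend(pickle.loads(found[i]) for i in chunk)
    return out


conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE data (id INTEGER PRIMARY KEY, value BLOB)')
conn.executemany('INSERT INTO data VALUES (?, ?)',
                 [(i, pickle.dumps(f'item_{i}')) for i in range(10)])
print(multi_get_sqlite(conn, 'data', [7, 2, 5], batch_size=2))
# ['item_7', 'item_2', 'item_5']
```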