torch_geometric.data.Database
- class Database(schema: Union[Any, Dict[str, Any], Tuple[Any], List[Any]] = object)
Bases: ABC
Base class for inserting and retrieving data from a database.
A database acts as a persisted, out-of-memory and index-based key/value store for tensor and custom data:
db = Database()
db[0] = Data(x=torch.randn(5, 16), y=0, z='id_0')
print(db[0])
>>> Data(x=[5, 16], y=0, z='id_0')
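With the default schema (object), each value is serialized with Python pickling, so the store conceptually maps an integer index to a blob of bytes. The following is a minimal, standard-library-only sketch of that idea. `PickleStore` is a hypothetical name, not part of torch_geometric, and it keeps the bytes in a dict, whereas a real backend would write them to disk to make them persistent.

```python
import pickle
from typing import Any, Dict


class PickleStore:
    """Hypothetical toy store: each value is kept as pickled bytes under an int key.

    It only mimics the default (schema=object) serialization path; it is not
    PyG code and does not persist anything to disk.
    """

    def __init__(self) -> None:
        self._rows: Dict[int, bytes] = {}

    def __setitem__(self, index: int, data: Any) -> None:
        # Serialize on write, so the stored form is plain bytes.
        self._rows[index] = pickle.dumps(data)

    def __getitem__(self, index: int) -> Any:
        # Deserialize on read; every read returns a fresh copy of the object.
        return pickle.loads(self._rows[index])

    def __len__(self) -> int:
        return len(self._rows)


db = PickleStore()
db[0] = {'x': [[0.1] * 16 for _ in range(5)], 'y': 0, 'z': 'id_0'}
row = db[0]
print(row['y'], row['z'])  # 0 id_0
```

Because every read goes through `pickle.loads`, mutating a returned object does not change what is stored; pickling is flexible but comparatively slow, which is the cost a schema is meant to avoid.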
To improve efficiency, it is recommended to specify the underlying schema of the data:
db = Database(schema={  # Custom schema:
    # Tensor information can be specified through a dictionary:
    'x': dict(dtype=torch.float, size=(-1, 16)),
    'y': int,
    'z': str,
})
db[0] = dict(x=torch.randn(5, 16), y=0, z='id_0')
print(db[0])
>>> {'x': torch.tensor(...), 'y': 0, 'z': 'id_0'}
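A schema lets each field be written in a fixed binary layout instead of being pickled. The sketch below illustrates that idea with the standard library only, assuming a 2-D float field whose trailing dimension is fixed (size=(-1, 16)) while the number of rows is free. `encode`, `decode` and `SCHEMA` are hypothetical names, and the string 'float' stands in for torch.float; this is not how torch_geometric actually stores rows.

```python
import struct
from array import array
from typing import Any, Dict

# Hypothetical schema mirroring the example above. A plain string stands in
# for torch.float so that this sketch only needs the standard library.
SCHEMA: Dict[str, Any] = {
    'x': dict(dtype='float', size=(-1, 16)),  # -1: any number of rows
    'y': int,
    'z': str,
}


def encode(row: Dict[str, Any], schema: Dict[str, Any]) -> Dict[str, bytes]:
    """Turn each field into raw bytes, guided by the schema (no pickling)."""
    out: Dict[str, bytes] = {}
    for key, spec in schema.items():
        value = row[key]
        if isinstance(spec, dict):  # tensor-like field (assumed 2-D): flatten to float32
            num_cols = spec['size'][-1]
            if any(len(r) != num_cols for r in value):
                raise ValueError(f'{key!r}: expected {num_cols} columns')
            out[key] = array('f', (v for r in value for v in r)).tobytes()
        elif spec is int:
            out[key] = struct.pack('<q', value)  # 8-byte little-endian integer
        elif spec is str:
            out[key] = value.encode('utf-8')
        else:
            raise TypeError(f'unsupported spec for {key!r}: {spec!r}')
    return out


def decode(blobs: Dict[str, bytes], schema: Dict[str, Any]) -> Dict[str, Any]:
    """Invert encode(); the schema supplies the types and the row width."""
    row: Dict[str, Any] = {}
    for key, spec in schema.items():
        blob = blobs[key]
        if isinstance(spec, dict):
            flat = array('f')
            flat.frombytes(blob)
            num_cols = spec['size'][-1]
            # The -1 dimension is recovered from the buffer length.
            row[key] = [list(flat[i:i + num_cols])
                        for i in range(0, len(flat), num_cols)]
        elif spec is int:
            row[key] = struct.unpack('<q', blob)[0]
        elif spec is str:
            row[key] = blob.decode('utf-8')
        else:
            raise TypeError(f'unsupported spec for {key!r}: {spec!r}')
    return row


x = [[0.5] * 16 for _ in range(5)]
blobs = encode({'x': x, 'y': 0, 'z': 'id_0'}, SCHEMA)
print(len(blobs['x']))  # 320 bytes: 5 * 16 float32 values * 4 bytes
restored = decode(blobs, SCHEMA)
```

Only the trailing dimension has to be known in advance; the row count falls out of the byte length on decode, which is why -1 can be used for the first dimension.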
In addition, databases support batch-wise insertion and retrieval, and support the syntactic sugar known from indexing Python lists, e.g.:
db = Database()
db[2:5] = torch.randn(3, 16)
print(db[torch.tensor([2, 3])])
>>> [torch.tensor(...), torch.tensor(...)]
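List-style indexing can be layered on top of single-item and batch primitives by dispatching on the index type. The sketch below assumes hypothetical `insert`, `get` and `multi_get` methods next to `multi_insert`; `ToyDatabase` is illustrative only. It accepts ints, slices with an explicit stop, ranges and lists of ints, and leaves out torch.Tensor indices to stay dependency-free.

```python
from typing import Any, Dict, List, Sequence, Union

Index = Union[int, slice, range, Sequence[int]]


class ToyDatabase:
    """Hypothetical sketch of list-style indexing sugar on a key/value store.

    Integer keys map to single get/insert calls; slices, ranges and
    sequences of ints map to batch calls. Not PyG's implementation.
    """

    def __init__(self) -> None:
        self._store: Dict[int, Any] = {}

    # Single-item primitives (hypothetical).
    def insert(self, index: int, data: Any) -> None:
        self._store[index] = data

    def get(self, index: int) -> Any:
        return self._store[index]

    # Batch primitives built on top of the single-item ones.
    def multi_insert(self, indices: Sequence[int], data_list: Sequence[Any]) -> None:
        if len(indices) != len(data_list):
            raise ValueError('indices and data_list must have the same length')
        for i, d in zip(indices, data_list):
            self.insert(i, d)

    def multi_get(self, indices: Sequence[int]) -> List[Any]:
        return [self.get(i) for i in indices]

    @staticmethod
    def _to_indices(key: Index) -> Sequence[int]:
        if isinstance(key, slice):
            # Without a known length, this sketch requires an explicit stop.
            if key.stop is None:
                raise ValueError('slice needs an explicit stop')
            return range(key.start or 0, key.stop, key.step or 1)
        return key  # already a range or a sequence of ints

    def __setitem__(self, key: Index, value: Any) -> None:
        if isinstance(key, int):
            self.insert(key, value)
        else:
            self.multi_insert(self._to_indices(key), value)

    def __getitem__(self, key: Index) -> Any:
        if isinstance(key, int):
            return self.get(key)
        return self.multi_get(self._to_indices(key))


db = ToyDatabase()
db[2:5] = ['a', 'b', 'c']
print(db[[2, 3]])  # ['a', 'b']
```

Routing everything through the batch methods means a backend only has to make `multi_insert` and `multi_get` fast for the sugar to benefit as well.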
- Parameters:
schema (Any or Tuple[Any] or Dict[str, Any], optional) – The schema of the input data. Can take int, float, str, object, or a dictionary with dtype and size keys (for specifying tensor data) as input, and can be nested as a tuple or dictionary. Specifying the schema will improve efficiency, since by default the database will use Python pickling for serializing and deserializing. (default: object)
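Because schemas can be nested as tuples or dictionaries, checking a value against one is naturally recursive. The sketch below is a hypothetical validator (`matches`, `_shape` and `_matches_tensor` are illustrative names, not torch_geometric functions). It assumes tensor data is given as rectangular nested lists, treats -1 in size as "any length", ignores dtype, and treats any dict with exactly the keys dtype and size as a tensor spec.

```python
from typing import Any, Tuple


def _shape(value: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list; a non-list scalar has shape ()."""
    if not isinstance(value, list):
        return ()
    if not value:
        return (0,)
    inner = _shape(value[0])
    if any(_shape(v) != inner for v in value[1:]):
        raise ValueError('ragged nested list')
    return (len(value),) + inner


def _matches_tensor(value: Any, size: Tuple[int, ...]) -> bool:
    try:
        shape = _shape(value)
    except ValueError:
        return False
    # -1 acts as a wildcard for that dimension (dtype is not checked here).
    return len(shape) == len(size) and all(s == -1 or s == d for s, d in zip(size, shape))


def matches(value: Any, schema: Any) -> bool:
    """Hypothetical check that `value` conforms to a possibly nested schema."""
    if schema is object:
        return True  # anything goes; this is the pickling fallback
    if schema in (int, float, str):
        # bool subclasses int, so reject it explicitly.
        return isinstance(value, schema) and not isinstance(value, bool)
    if isinstance(schema, dict) and set(schema) == {'dtype', 'size'}:
        return _matches_tensor(value, schema['size'])
    if isinstance(schema, dict):
        return (isinstance(value, dict) and value.keys() == schema.keys()
                and all(matches(value[k], s) for k, s in schema.items()))
    if isinstance(schema, tuple):
        return (isinstance(value, tuple) and len(value) == len(schema)
                and all(matches(v, s) for v, s in zip(value, schema)))
    raise TypeError(f'unsupported schema: {schema!r}')


schema = {'x': dict(dtype='float', size=(-1, 16)), 'y': int, 'z': (str, object)}
ok = {'x': [[0.0] * 16 for _ in range(3)], 'y': 1, 'z': ('id_1', None)}
bad = {'x': [[0.0] * 8 for _ in range(3)], 'y': 1, 'z': ('id_1', None)}
print(matches(ok, schema), matches(bad, schema))  # True False
```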
- multi_insert(indices: Union[Sequence[int], Tensor, slice, range], data_list: Sequence[Any], batch_size: Optional[int] = None, log: bool = False) → None
Inserts a chunk of data at the given indices.
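The batch_size and log parameters listed below suggest a simple chunking loop. The following is a minimal sketch under that assumption: `batched_insert` and its `insert_batch` callback are hypothetical stand-ins for whatever bulk write a concrete backend provides, and the progress message format is invented for illustration.

```python
from typing import Any, Callable, Optional, Sequence


def batched_insert(
    indices: Sequence[int],
    data_list: Sequence[Any],
    insert_batch: Callable[[Sequence[int], Sequence[Any]], None],
    batch_size: Optional[int] = None,
    log: bool = False,
) -> None:
    """Hypothetical helper mirroring the batch_size/log semantics of multi_insert."""
    if len(indices) != len(data_list):
        raise ValueError('indices and data_list must have the same length')
    if batch_size is not None and batch_size <= 0:
        raise ValueError('batch_size must be positive')
    total = len(indices)
    # batch_size=None means a single batch covering everything.
    step = batch_size if batch_size is not None else max(total, 1)
    for start in range(0, total, step):
        end = min(start + step, total)
        insert_batch(indices[start:end], data_list[start:end])
        if log:
            print(f'Inserted {end}/{total} items')


calls = []
batched_insert(range(10), [i * i for i in range(10)],
               lambda idx, data: calls.append((list(idx), list(data))),
               batch_size=4)
print([len(idx) for idx, _ in calls])  # [4, 4, 2]
```

Smaller batches bound the amount of data held per write at the cost of more round trips to the backend.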
- Parameters:
indices (List[int] or torch.Tensor or range) – The indices at which to insert.
data_list (List[Any]) – The objects to insert.
batch_size (int, optional) – If specified, will insert the data into the database in batches of size batch_size. (default: None)
log (bool, optional) – If set to True, will log progress to the console. (default: False)
- Return type: