基于Torch的特征存储

class dgl.graphbolt.TorchBasedFeatureStore(feat_data: List[OnDiskFeatureData])[source]

基础类:BasicFeatureStore

一个用于管理多个基于pytorch的功能以供访问的存储。

特征存储由feat_data描述。feat_data是一个OnDiskFeatureData的列表。

对于一个特征存储,其格式必须为“pt”或“npy”,分别对应Pytorch或Numpy格式。如果格式为“pt”,则特征存储必须加载到内存中。如果格式为“npy”,则特征存储可以加载到内存中或磁盘上。

Parameters:

feat_data (List[OnDiskFeatureData]) – 特征存储的描述。

示例

>>> import torch
>>> import numpy as np
>>> from dgl import graphbolt as gb
>>> edge_label = torch.tensor([[1], [2], [3]])
>>> node_feat = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> torch.save(edge_label, "/tmp/edge_label.pt")
>>> np.save("/tmp/node_feat.npy", node_feat.numpy())
>>> feat_data = [
...     gb.OnDiskFeatureData(domain="edge", type="author:writes:paper",
...         name="label", format="torch", path="/tmp/edge_label.pt",
...         in_memory=True),
...     gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
...         format="numpy", path="/tmp/node_feat.npy", in_memory=False),
... ]
>>> feature_sotre = gb.TorchBasedFeatureStore(feat_data)
is_pinned()[source]

如果所有存储的特征都被固定,则返回True。

pin_memory_()[source]

就地操作将特征存储复制到固定内存。 返回就地修改的相同对象。

to(device)[source]

TorchBasedFeatureStore 复制到指定设备。