基于Torch的特征
- class dgl.graphbolt.TorchBasedFeature(torch_feature: Tensor, metadata: Dict | None = None)[source]
基础类:
Feature
一个基于pytorch的特征包装器。
通过一个torch特征初始化一个基于torch的特征存储。 请注意,该特征可以在内存中或磁盘上。
- Parameters:
torch_feature (torch.Tensor) – 火炬特征。 请注意,张量的维度应大于1。
示例
>>> import torch >>> from dgl import graphbolt as gb
该功能在内存中。
>>> torch_feat = torch.arange(10).reshape(2, -1) >>> feature = gb.TorchBasedFeature(torch_feat) >>> feature.read() tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> feature.read(torch.tensor([0])) tensor([[0, 1, 2, 3, 4]]) >>> feature.update(torch.tensor([[1 for _ in range(5)]]), ... torch.tensor([1])) >>> feature.read(torch.tensor([0, 1])) tensor([[0, 1, 2, 3, 4], [1, 1, 1, 1, 1]]) >>> feature.size() torch.Size([5])
该功能在磁盘上。
>>> import numpy as np >>> arr = np.array([[1, 2], [3, 4]]) >>> np.save("/tmp/arr.npy", arr) >>> torch_feat = torch.from_numpy(np.load("/tmp/arr.npy", mmap_mode="r+")) >>> feature = gb.TorchBasedFeature(torch_feat) >>> feature.read() tensor([[1, 2], [3, 4]]) >>> feature.read(torch.tensor([0])) tensor([[1, 2]])
固定的CPU特性。
>>> torch_feat = torch.arange(10).reshape(2, -1).pin_memory() >>> feature = gb.TorchBasedFeature(torch_feat) >>> feature.read().device device(type='cuda', index=0) >>> feature.read(torch.tensor([0]).cuda()).device device(type='cuda', index=0)
- read(ids: Tensor | None = None)[source]
通过索引读取特征。
如果特征位于固定的CPU内存中,并且ids位于GPU或固定的CPU内存中,它将被GPU读取,返回的张量将位于GPU上。否则,返回的张量将位于CPU上。
- Parameters:
ids (torch.Tensor, optional) – The index of the feature. If specified, only the specified indices of the feature are read. If None, the entire feature is returned.
- Returns:
读取功能。
- Return type:
torch.Tensor
- update(value: Tensor, ids: Tensor | None = None)[source]
更新特征存储。
- Parameters:
value (torch.Tensor) – The updated value of the feature.
ids (torch.Tensor, optional) – The indices of the feature to update. If specified, only the specified indices of the feature will be updated. For the feature, the ids[i] row is updated to value[i]. So the indices and value must have the same length. If None, the entire feature will be updated.