高级小批量处理
创建小批量对于让深度学习模型的训练扩展到大量数据至关重要。
与逐个处理示例不同,小批量将一组示例分组为一个统一的表示,从而可以高效地并行处理。
在图像或语言领域,通常通过将每个示例重新缩放或填充为一组相同大小的形状来实现此过程,然后将示例在额外的维度中分组。
此维度的长度等于小批量中分组的示例数量,通常称为batch_size。
由于图是能够包含任何数量节点或边的最通用的数据结构之一,上述两种方法要么不可行,要么可能导致大量不必要的内存消耗。 在PyG中,我们选择了另一种方法来实现跨多个示例的并行化。 在这里,邻接矩阵以对角线方式堆叠(创建一个包含多个孤立子图的大图),节点和目标特征在节点维度上简单地连接,即
此过程与其他批处理过程相比具有一些关键优势:
依赖消息传递方案的GNN操作符不需要修改,因为属于不同图的两个节点之间仍然无法交换消息。
没有计算或内存开销。 例如,这个批处理过程完全不需要对节点或边特征进行任何填充。 请注意,由于邻接矩阵以稀疏方式保存,仅保存非零条目,即边,因此没有额外的内存开销。
PyG 自动通过 torch_geometric.loader.DataLoader 类将多个图批量处理为一个巨大的图。
在内部,DataLoader 只是一个普通的 PyTorch torch.utils.data.DataLoader,它覆盖了其 collate() 功能,即,定义了如何将一组示例分组在一起。
因此,所有可以传递给 PyTorch DataLoader 的参数也可以传递给 PyG DataLoader,例如,工作线程数 num_workers。
在最一般的形式中,PyG DataLoader 会自动增加 edge_index 张量,增加的量是当前处理的图之前所有图的累积节点数,并且会在第二维度上连接 edge_index 张量(形状为 [2, num_edges])。
对于 face 张量也是如此,即,网格中的面索引。
所有其他张量将在第一维度上连接,而不会进一步增加它们的值。
然而,有一些特殊的用例(如下所述)用户希望根据自己的需求主动修改这种行为。
PyG 允许通过覆盖 torch_geometric.data.Data.__inc__() 和 torch_geometric.data.Data.__cat_dim__() 功能来修改底层批处理过程。
在没有任何修改的情况下,这些功能在 Data 类中定义如下:
def __inc__(self, key, value, *args, **kwargs):
if 'index' in key:
return self.num_nodes
else:
return 0
def __cat_dim__(self, key, value, *args, **kwargs):
if 'index' in key:
return 1
else:
return 0
我们可以看到__inc__()定义了两个连续图属性之间的增量计数。
默认情况下,PyG会在属性名称包含子字符串index时(由于历史原因),按节点数递增属性,这对于诸如edge_index或node_index等属性非常有用。
然而,请注意,这可能会导致名称包含子字符串index但不应该递增的属性的意外行为。
为了确保,最佳实践是始终仔细检查批处理的输出。
此外,__cat_dim__()定义了相同属性的图张量应在哪个维度上连接在一起。
这两个函数都会为存储在Data类中的每个属性调用,并传递它们特定的key和值item作为参数。
接下来,我们展示一些用例,在这些用例中,修改__inc__()和__cat_dim__()可能是绝对必要的。
图形对
如果你想在单个Data对象中存储多个图,例如,用于图匹配等应用,你需要确保所有这些图的正确批处理行为。
例如,考虑在Data中存储两个图,一个源图\(\mathcal{G}_s\)和一个目标图\(\mathcal{G}_t\),例如:
from torch_geometric.data import Data
class PairData(Data):
pass
data = PairData(x_s=x_s, edge_index_s=edge_index_s, # Source graph.
x_t=x_t, edge_index_t=edge_index_t) # Target graph.
在这种情况下,edge_index_s 应该增加源图 \(\mathcal{G}_s\) 中的节点数量,例如,x_s.size(0),而 edge_index_t 应该增加目标图 \(\mathcal{G}_t\) 中的节点数量,例如,x_t.size(0):
class PairData(Data):
def __inc__(self, key, value, *args, **kwargs):
if key == 'edge_index_s':
return self.x_s.size(0)
if key == 'edge_index_t':
return self.x_t.size(0)
return super().__inc__(key, value, *args, **kwargs)
我们可以通过设置一个简单的测试脚本来测试我们的PairData批处理行为:
from torch_geometric.loader import DataLoader
x_s = torch.randn(5, 16) # 5 nodes.
edge_index_s = torch.tensor([
[0, 0, 0, 0],
[1, 2, 3, 4],
])
x_t = torch.randn(4, 16) # 4 nodes.
edge_index_t = torch.tensor([
[0, 0, 0],
[1, 2, 3],
])
data = PairData(x_s=x_s, edge_index_s=edge_index_s,
x_t=x_t, edge_index_t=edge_index_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))
print(batch)
>>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8],
x_t=[8, 16], edge_index_t=[2, 6])
print(batch.edge_index_s)
>>> tensor([[0, 0, 0, 0, 5, 5, 5, 5],
[1, 2, 3, 4, 6, 7, 8, 9]])
print(batch.edge_index_t)
>>> tensor([[0, 0, 0, 4, 4, 4],
[1, 2, 3, 5, 6, 7]])
到目前为止一切看起来都很好!
edge_index_s 和 edge_index_t 即使在使用不同数量的节点时也能正确地批量处理在一起,即使对于 \(\mathcal{G}_s\) 和 \(\mathcal{G}_t\) 也是如此。
然而,由于 PyG 无法识别 PairData 对象中的实际图,因此缺少 batch 属性(该属性将每个节点映射到其相应的图)。
这就是 DataLoader 的 follow_batch 参数发挥作用的地方。
在这里,我们可以指定我们希望为哪些属性维护批量信息:
loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t'])
batch = next(iter(loader))
print(batch)
>>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8], x_s_batch=[10],
x_t=[8, 16], edge_index_t=[2, 6], x_t_batch=[8])
print(batch.x_s_batch)
>>> tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
print(batch.x_t_batch)
>>> tensor([0, 0, 0, 0, 1, 1, 1, 1])
正如我们所看到的,follow_batch=['x_s', 'x_t'] 现在成功地创建了节点特征 x_s 和 x_t 的分配向量 x_s_batch 和 x_t_batch。
这些信息现在可以用于在单个 Batch 对象中对多个图执行归约操作,例如全局池化。
二分图
二分图的邻接矩阵定义了两种不同节点类型之间的关系。
通常情况下,每种节点类型的节点数量不需要匹配,导致邻接矩阵的形状为非二次的\(\mathbf{A} \in \{ 0, 1 \}^{N \times M}\),其中\(N \neq M\)是可能的。
在二分图的小批量处理过程中,edge_index中边的源节点应该以不同于edge_index中边的目标节点的方式增加。
为了实现这一点,考虑一个具有相应节点特征x_s和x_t的两种节点类型之间的二分图:
from torch_geometric.data import Data
class BipartiteData(Data):
pass
data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index)
为了在二分图中正确进行小批量处理,我们需要告诉PyG,它应该独立地增加edge_index中边的源节点和目标节点:
class BipartiteData(Data):
def __inc__(self, key, value, *args, **kwargs):
if key == 'edge_index':
return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
return super().__inc__(key, value, *args, **kwargs)
在这里,edge_index[0](边的源节点)增加了x_s.size(0),而edge_index[1](边的目标节点)增加了x_t.size(0)。
我们可以通过运行一个简单的测试脚本来再次测试我们的实现:
from torch_geometric.loader import DataLoader
x_s = torch.randn(2, 16) # 2 nodes.
x_t = torch.randn(3, 16) # 3 nodes.
edge_index = torch.tensor([
[0, 0, 1, 1],
[0, 1, 1, 2],
])
data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))
print(batch)
>>> BipartiteDataBatch(x_s=[4, 16], x_t=[6, 16], edge_index=[2, 8])
print(batch.edge_index)
>>> tensor([[0, 0, 1, 1, 2, 2, 3, 3],
[0, 1, 1, 2, 3, 4, 4, 5]])
再次强调,这正是我们追求的行为!
沿新维度批处理
有时,data 对象的属性应该通过获得一个新的批次维度来进行批处理(如在经典的小批量处理中),例如,对于图级别的属性或目标。
具体来说,形状为 [num_features] 的属性列表应该返回为 [num_examples, num_features] 而不是 [num_examples * num_features]。
PyG 通过在 __cat_dim__() 中返回一个 None 的拼接维度来实现这一点:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
class MyData(Data):
def __cat_dim__(self, key, value, *args, **kwargs):
if key == 'foo':
return None
return super().__cat_dim__(key, value, *args, **kwargs)
edge_index = torch.tensor([
[0, 1, 1, 2],
[1, 0, 2, 1],
])
foo = torch.randn(16)
data = MyData(num_nodes=3, edge_index=edge_index, foo=foo)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))
print(batch)
>>> MyDataBatch(num_nodes=6, edge_index=[2, 8], foo=[2, 16])
如所期望的,batch.foo 现在由两个维度描述:批次维度和特征维度。