CuGraphRelGraphConv
- class dgl.nn.pytorch.conv.CuGraphRelGraphConv(in_feat, out_feat, num_rels, regularizer=None, num_bases=None, bias=True, self_loop=True, dropout=0.0, apply_norm=False)[source]
Bases:
CuGraphBaseConv
一个加速的关系图卷积层,来自使用图卷积网络建模关系数据,它利用了cugraph-ops中高度优化的聚合原语。
请参阅
dgl.nn.pytorch.conv.RelGraphConv
以了解数学模型。This module depends on
pylibcugraphops
package, which can be installed viaconda install -c nvidia pylibcugraphops=23.04
.pylibcugraphops
23.04 requires python 3.8.x or 3.10.x.注意
This is an experimental feature.
- Parameters:
in_feat (int) – 输入特征大小。
out_feat (int) – 输出特征大小。
num_rels (int) – 关系的数量。
regularizer (str, optional) –
- 使用哪种权重正则化器(“basis” 或
None
): “basis” 用于基础分解。
None
表示不应用正则化。
默认值:
None
。- 使用哪种权重正则化器(“basis” 或
num_bases (int, optional) – 基础数量。当应用正则化器时生效。 默认值:
None
.偏差 (bool, 可选) – 如果添加了偏差则为True。默认值:
True
。self_loop (bool, 可选) – 如果为True,则包含自环消息。默认值:
True
。dropout (float, optional) – 丢弃率。默认值:
0.0
。apply_norm (bool, optional) – 如果为True,则按每条边类型的目标节点的入度对聚合输出进行归一化,即 \(|\mathcal{N}^r_i|\)。默认值:
True
。
示例
>>> import dgl >>> import torch >>> from dgl.nn import CuGraphRelGraphConv ... >>> device = 'cuda' >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device) >>> feat = torch.ones(6, 10).to(device) >>> conv = CuGraphRelGraphConv( ... 10, 2, 3, regularizer='basis', num_bases=2).to(device) >>> etype = torch.tensor([0,1,2,0,1,2]).to(device) >>> res = conv(g, feat, etype) >>> res tensor([[-1.7774, -2.0184], [-1.4335, -2.3758], [-1.7774, -2.0184], [-0.4698, -3.0876], [-1.4335, -2.3758], [-1.4331, -2.3295]], device='cuda:0', grad_fn=<AddBackward0>)
- forward(g, feat, etypes, max_in_degree=None)[source]
前向计算。
- Parameters:
g (DGLGraph) – The graph.
特征 (torch.Tensor) – 一个2D张量的节点特征。形状:\((|V|, D_{in})\)。
etypes (torch.Tensor) – 一个一维整数张量,表示边的类型。形状为:\((|E|,)\)。 请注意,cugraph-ops 只接受 int32 类型的边类型张量, 因此任何其他整数类型的输入都将被强制转换为 int32, 从而引入一些开销。为了获得最佳性能,请直接传入 int32 张量。
max_in_degree (int, optional) – 目标节点的最大入度。只有当
g
是DGLBlock
时,即二分图时,才有效。当g
是从邻居采样器生成时,该值应设置为相应的fanout
。如果未给出,max_in_degree
将动态计算。
- Returns:
新节点特征。形状:\((|V|, D_{out})\)。
- Return type:
torch.Tensor