热核

class dgl.transforms.HeatKernel(t=2.0, eweight_name='w', eps=None, avg_degree=5)[source]

Bases: BaseTransform

将热核应用于输入图以进行扩散,如图和其他离散结构上的扩散核中所介绍。

扩散后将应用于加权邻接矩阵的稀疏化。 具体来说,权重低于阈值的边将被删除。

该模块仅适用于同构图。

Parameters:
  • t (float, optional) – 扩散时间,通常位于 \([2, 10]\)

  • eweight_name (str, optional) – edata 名称,用于检索和存储边的权重。如果在输入图中不存在,此模块会为所有边初始化权重为1。边的权重应该是一个形状为 \((E)\) 的张量,其中 E 是边的数量。

  • eps (float, optional) – 扩散后稀疏化过程中保留边缘的阈值。权重小于 eps 的边缘将被丢弃。

  • avg_degree (int, optional) – 结果图的期望平均节点度数。这是控制结果图稀疏性的另一种方式,只有在未提供eps时才会生效。

示例

>>> import dgl
>>> import torch
>>> from dgl import HeatKernel
>>> transform = HeatKernel(avg_degree=2)
>>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))
>>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
>>> new_g = transform(g)
>>> print(new_g.edata['w'])
tensor([0.1353, 0.1353, 0.1353, 0.0541, 0.0406, 0.1353, 0.1353, 0.0812, 0.1353,
        0.1083, 0.0541, 0.1353])