Paddle 后端示例:匹配同构图

这个例子是对pygmtools的介绍,展示了如何匹配同构图。 同构图指的是结构相同但节点对应关系未知的图。

# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#         Qi Liu <purewhite@sjtu.edu.cn>
#
# License: Mulan PSL v2 License

注意

以下求解器支持QAP公式,并包含在此示例中:

import paddle # paddle backend
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
import warnings
warnings.filterwarnings("ignore")
pygm.set_backend('paddle') # set default backend for pygmtools

paddle.device.set_device('cpu')
_ = paddle.seed(1) # fix random seed

生成两个同构图

num_nodes = 10
X_gt = paddle.zeros((num_nodes, num_nodes))
X_gt[paddle.arange(0, num_nodes, dtype=paddle.int64), paddle.randperm(num_nodes)] = 1
A1 = paddle.rand((num_nodes, num_nodes))
A1 = (A1 + A1.t() > 1.) / 2 * (A1 + A1.t())
A1[paddle.arange(A1.shape[0]), paddle.arange(A1.shape[1])] = 0  # paddle.diagonal(A1)[:] = 0
A2 = paddle.mm(paddle.mm(X_gt.t(), A1), X_gt)
n1 = paddle.to_tensor([num_nodes])
n2 = paddle.to_tensor([num_nodes])

可视化图表

plt.figure(figsize=(8, 4))
G1 = nx.from_numpy_array(A1.numpy())
G2 = nx.from_numpy_array(A2.numpy())
pos1 = nx.spring_layout(G1)
pos2 = nx.spring_layout(G2)
plt.subplot(1, 2, 1)
plt.title('Graph 1')
nx.draw_networkx(G1, pos=pos1)
plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2)
Graph 1, Graph 2

这两个图看起来不同,因为它们没有对齐。然后我们通过图匹配来对齐这两个图。

构建亲和矩阵

为了通过图匹配来匹配同构图,我们遵循二次分配问题(QAP)的公式:

\[\begin{split}&\max_{\mathbf{X}} \ \texttt{vec}(\mathbf{X})^\top \mathbf{K} \texttt{vec}(\mathbf{X})\\ s.t. \quad &\mathbf{X} \in \{0, 1\}^{n_1\times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}\end{split}\]

其中第一步是构建亲和矩阵(\(\mathbf{K}\)

conn1, edge1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2 = pygm.utils.dense_to_sparse(A2)
import functools
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=.1) # set affinity function
K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)

亲和矩阵的可视化。对于具有\(N\)个节点的图匹配问题,亲和矩阵有\(N^2\times N^2\)个元素,因为每个图中有\(N^2\)条边。

注意

亲和矩阵的对角线元素为空,因为在这个示例中没有节点特征。

plt.figure(figsize=(4, 4))
plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
plt.imshow(K.numpy(), cmap='Blues')
Affinity Matrix (size: 100$\times$100)
<matplotlib.image.AxesImage object at 0x7feb92da5180>

使用RRWM求解器解决图匹配问题

请参阅rrwm()以获取API参考。

X = pygm.rrwm(K, n1, n2)

RRWM 的输出是一个软匹配矩阵。可视化:

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('RRWM Soft Matching Matrix')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')
RRWM Soft Matching Matrix, Ground Truth Matching Matrix
<matplotlib.image.AxesImage object at 0x7feb9296e230>

获取离散匹配矩阵

然后采用匈牙利算法来达到一个离散的匹配矩阵

X = pygm.hungarian(X)

离散匹配矩阵的可视化:

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'RRWM Matching Matrix (acc={((X * X_gt).sum()/ X_gt.sum()).item():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')
RRWM Matching Matrix (acc=1.00), Ground Truth Matching Matrix
<matplotlib.image.AxesImage object at 0x7feb9296fd90>

对齐原始图形

绘制匹配(绿色线条表示正确匹配,红色线条表示错误匹配):

plt.figure(figsize=(8, 4))
ax1 = plt.subplot(1, 2, 1)
plt.title('Graph 1')
nx.draw_networkx(G1, pos=pos1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2)
for i in range(num_nodes):
    j = paddle.argmax(X[i]).item()
    con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="green" if X_gt[i, j] else "red")
    plt.gca().add_artist(con)
Graph 1, Graph 2

对齐节点:

align_A2 = paddle.mm(paddle.mm(X, A2), X.t())
plt.figure(figsize=(8, 4))
ax1 = plt.subplot(1, 2, 1)
plt.title('Graph 1')
nx.draw_networkx(G1, pos=pos1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Aligned Graph 2')
align_pos2 = {}
for i in range(num_nodes):
    j = paddle.argmax(X[i]).item()
    align_pos2[j] = pos1[i]
    con = ConnectionPatch(xyA=pos1[i], xyB=align_pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="green" if X_gt[i, j] else "red")
    plt.gca().add_artist(con)
nx.draw_networkx(G2, pos=align_pos2)
Graph 1, Aligned Graph 2

其他求解器也可用

经典IPFP求解器

请参阅ipfp()以获取API参考。

X = pygm.ipfp(K, n1, n2)

IPFP匹配结果的可视化:

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'IPFP Matching Matrix (acc={((X * X_gt).sum()/ X_gt.sum()).item():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')
IPFP Matching Matrix (acc=1.00), Ground Truth Matching Matrix
<matplotlib.image.AxesImage object at 0x7feb927e6890>

经典SM求解器

请参阅sm()以获取API参考。

X = pygm.sm(K, n1, n2)
X = pygm.hungarian(X)

SM匹配结果的可视化:

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'SM Matching Matrix (acc={((X * X_gt).sum()/ X_gt.sum()).item():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')
SM Matching Matrix (acc=1.00), Ground Truth Matching Matrix
<matplotlib.image.AxesImage object at 0x7feb926af8e0>

NGM神经网络求解器

请参阅ngm()的API参考。

with paddle.set_grad_enabled(False):
    X = pygm.ngm(K, n1, n2, pretrain='voc')
    X = pygm.hungarian(X)

NGM匹配结果的可视化:

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'NGM Matching Matrix (acc={((X * X_gt).sum()/ X_gt.sum()).item():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')
NGM Matching Matrix (acc=1.00), Ground Truth Matching Matrix
<matplotlib.image.AxesImage object at 0x7feb9259aa10>

脚本总运行时间: (0 分钟 1.179 秒)

Gallery generated by Sphinx-Gallery