注意
转到末尾以下载完整的示例代码
PyTorch 后端示例:发现子图
此示例展示了如何将较小的图匹配到较大图的子集。
# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
import torch # pytorch backend
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.set_backend('pytorch') # set default backend for pygmtools
# Seed torch's RNG so the torch.rand-generated graph below is reproducible.
# NOTE(review): this does not seed numpy/networkx; layout randomness must be seeded separately.
_ = torch.manual_seed(1) # fix random seed
生成更大的图
# Random symmetric weighted graph 2: keep an edge only where the symmetrised
# weight exceeds 1, store half of that sum as the weight, and drop self-loops.
num_nodes2 = 10
A2 = torch.rand(num_nodes2, num_nodes2)
sym = A2 + A2.t()
A2 = (sym > 1.) * sym / 2
A2.fill_diagonal_(0)
n2 = torch.tensor([num_nodes2])
生成较小的图
num_nodes1 = 5
G2 = nx.from_numpy_array(A2.numpy())
# torch.manual_seed does not control networkx's RNG, so seed the layout
# explicitly; otherwise the subgraph picked below changes on every run.
pos2 = nx.spring_layout(G2, seed=1)
# Stack the per-node position arrays into one tensor (one row per node).
# Passing a Python list of ndarrays straight to torch.tensor is a slow path
# that emits a UserWarning.
pos2_t = torch.stack([torch.as_tensor(pos2[k]) for k in range(num_nodes2)])
selected = [0] # build G1 as a cluster in visualization
unselected = list(range(1, num_nodes2))
# Greedily grow a spatially compact cluster: at each step add the unselected
# node with the smallest total L1 distance to all already-selected nodes.
while len(selected) < num_nodes1:
    dist = torch.sum(torch.sum(torch.abs(pos2_t[selected].unsqueeze(1) - pos2_t[unselected].unsqueeze(0)), dim=-1), dim=0)
    select_id = unselected[torch.argmin(dist).item()] # find the closest node from unselected
    selected.append(select_id)
    unselected.remove(select_id)
selected.sort()
# Graph 1 is the induced subgraph of graph 2 on the selected nodes.
A1 = A2[selected, :][:, selected]
# Ground truth: row i is the one-hot indicator of the graph-2 node selected[i].
X_gt = torch.eye(num_nodes2)[selected, :]
n1 = torch.tensor([num_nodes1])
/home/wzever/pygmtools/examples/3.discovering_subgraphs/plot_subgraphs_pytorch.py:52: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:210.)
pos2_t = torch.tensor([pos2[_] for _ in range(num_nodes2)])
可视化两个图
G1 = nx.from_numpy_array(A1.numpy())
# Draw every subgraph node at the position of its counterpart in graph 2's
# layout, so the two drawings line up visually.
pos1 = dict((k, pos2[selected[k]]) for k in range(num_nodes1))
# Subgraph nodes share one highlight colour; other graph-2 nodes use another.
color1 = ['#FF5733'] * num_nodes1
in_subgraph = set(selected)
color2 = ['#FF5733' if v in in_subgraph else '#1f78b4' for v in range(num_nodes2)]

plt.figure(figsize=(8, 4))
plt.subplot(121)
plt.title('Subgraph 1')
plt.gca().margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
plt.subplot(122)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)

接下来我们展示如何通过图匹配自动发现这一匹配关系。
构建亲和矩阵
为了匹配较大的图和较小的图,我们遵循二次分配问题(QAP)的公式:
\[\begin{split}&\max_{\mathbf{X}} \ \texttt{vec}(\mathbf{X})^\top \mathbf{K} \texttt{vec}(\mathbf{X})\\
s.t. \quad &\mathbf{X} \in \{0, 1\}^{n_1\times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}\end{split}\]
第一步是构建亲和矩阵(\(\mathbf{K}\))。
import functools

# Convert each dense adjacency matrix into its connectivity and edge arrays.
conn1, edge1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2 = pygm.utils.dense_to_sparse(A2)

# Edge affinity: pygmtools' Gaussian affinity function with sigma fixed to 0.001.
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=.001)

# No node features exist in this example, so both node-feature slots are None.
K = pygm.utils.build_aff_mat(
    None, edge1, conn1,
    None, edge2, conn2,
    n1, None, n2, None,
    edge_aff_fn=gaussian_aff,
)
亲和矩阵的可视化。对于分别具有\(N_1\)和\(N_2\)个节点的两个图的匹配问题,亲和矩阵有\(N_1N_2\times N_1N_2\)个元素。这是因为两个图中分别有\(N_1^2\)和\(N_2^2\)条(按有序节点对计数的)边,而\(N_1N_2\times N_1N_2 = N_1^2 \times N_2^2\),即每一对边恰好对应亲和矩阵中的一个元素。
注意
亲和矩阵的对角线元素为零,因为在这个示例中没有节点特征。
# Render K as a heat map, with its dimensions reported in the title.
n_rows, n_cols = K.shape[0], K.shape[1]
plt.figure(figsize=(4, 4))
plt.title(f'Affinity Matrix (size: {n_rows}$\\times${n_cols})')
plt.imshow(K.numpy(), cmap='Blues')

<matplotlib.image.AxesImage object at 0x7fd306a0c280>
使用RRWM求解器解决图匹配问题
请参阅rrwm()以获取API参考。
# Solve the QAP defined by K with RRWM. Per the surrounding text, the result is a
# soft (continuous) matching matrix; later cells multiply it elementwise with X_gt,
# so it presumably has X_gt's shape (num_nodes1 x num_nodes2) — confirm.
X = pygm.rrwm(K, n1, n2)
RRWM 的输出是一个软匹配矩阵。可视化:
# Side by side: RRWM's soft matching matrix (left) and the ground truth (right).
soft_np, gt_np = X.numpy(), X_gt.numpy()
plt.figure(figsize=(8, 4))
plt.subplot(121)
plt.title('RRWM Soft Matching Matrix')
plt.imshow(soft_np, cmap='Blues')
plt.subplot(122)
plt.title('Ground Truth Matching Matrix')
plt.imshow(gt_np, cmap='Blues')

<matplotlib.image.AxesImage object at 0x7fd306ac0fa0>
获取离散匹配矩阵
然后使用匈牙利算法得到离散的匹配矩阵。
# Discretize the soft RRWM output with the Hungarian algorithm, replacing X with
# a discrete matching matrix.
X = pygm.hungarian(X)
离散匹配矩阵的可视化:
# Accuracy: share of ground-truth correspondences (ones in X_gt) also present in X.
acc = (X * X_gt).sum() / X_gt.sum()
plt.figure(figsize=(8, 4))
plt.subplot(121)
plt.title(f'RRWM Matching Matrix (acc={acc:.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(122)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

<matplotlib.image.AxesImage object at 0x7fd307ef1480>
匹配子图
绘制匹配:
# Accuracy: share of ground-truth correspondences (ones in X_gt) also present in X.
acc = (X * X_gt).sum() / X_gt.sum()
plt.figure(figsize=(8, 4))
plt.suptitle(f'RRWM Matching Result (acc={acc:.2f})')
ax1 = plt.subplot(121)
plt.title('Subgraph 1')
ax1.margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
ax2 = plt.subplot(122)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)
# One connector per subgraph node, pointing to the graph-2 node it was matched
# to: green when the match agrees with X_gt, red otherwise.
for src in range(num_nodes1):
    dst = int(X[src].argmax())
    link_color = "green" if X_gt[src, dst] == 1 else "red"
    link = ConnectionPatch(xyA=pos1[src], xyB=pos2[dst], coordsA="data", coordsB="data",
                           axesA=ax1, axesB=ax2, color=link_color)
    ax2.add_artist(link)

也可以使用其他求解器
经典IPFP求解器
请参阅ipfp()以获取API参考。
# Solve the same QAP with IPFP. Unlike the RRWM/SM/NGM cells, the result is not
# passed through pygm.hungarian before it is evaluated and plotted.
# NOTE(review): this presumably relies on IPFP already returning a discrete matrix — confirm.
X = pygm.ipfp(K, n1, n2)
IPFP匹配结果的可视化:
# Accuracy: share of ground-truth correspondences (ones in X_gt) also present in X.
acc = (X * X_gt).sum() / X_gt.sum()
plt.figure(figsize=(8, 4))
plt.suptitle(f'IPFP Matching Result (acc={acc:.2f})')
ax1 = plt.subplot(121)
plt.title('Subgraph 1')
ax1.margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
ax2 = plt.subplot(122)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)
# One connector per subgraph node, pointing to the graph-2 node it was matched
# to: green when the match agrees with X_gt, red otherwise.
for src in range(num_nodes1):
    dst = int(X[src].argmax())
    link_color = "green" if X_gt[src, dst] == 1 else "red"
    link = ConnectionPatch(xyA=pos1[src], xyB=pos2[dst], coordsA="data", coordsB="data",
                           axesA=ax1, axesB=ax2, color=link_color)
    ax2.add_artist(link)

经典SM求解器
请参阅sm()以获取API参考。
# Solve with the classic SM solver and immediately discretize its output with
# the Hungarian algorithm.
X = pygm.hungarian(pygm.sm(K, n1, n2))
SM匹配结果的可视化:
# Accuracy: share of ground-truth correspondences (ones in X_gt) also present in X.
acc = (X * X_gt).sum() / X_gt.sum()
plt.figure(figsize=(8, 4))
plt.suptitle(f'SM Matching Result (acc={acc:.2f})')
ax1 = plt.subplot(121)
plt.title('Subgraph 1')
ax1.margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
ax2 = plt.subplot(122)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)
# One connector per subgraph node, pointing to the graph-2 node it was matched
# to: green when the match agrees with X_gt, red otherwise.
for src in range(num_nodes1):
    dst = int(X[src].argmax())
    link_color = "green" if X_gt[src, dst] == 1 else "red"
    link = ConnectionPatch(xyA=pos1[src], xyB=pos2[dst], coordsA="data", coordsB="data",
                           axesA=ax1, axesB=ax2, color=link_color)
    ax2.add_artist(link)

NGM神经网络求解器
请参阅ngm()的API参考。
注意
NGM 求解器是在不同的问题设置上预训练的,因此其性能可能看起来较差。为了提高性能,您可以改变构建亲和矩阵的方式,或者尝试在新问题上微调 NGM。
# The pretrained NGM is only used for inference here, so gradient tracking is
# disabled; the Hungarian algorithm then discretizes the network's output.
with torch.no_grad():
    X = pygm.hungarian(pygm.ngm(K, n1, n2, pretrain='voc'))
NGM匹配结果的可视化:
# Accuracy: share of ground-truth correspondences (ones in X_gt) also present in X.
acc = (X * X_gt).sum() / X_gt.sum()
plt.figure(figsize=(8, 4))
plt.suptitle(f'NGM Matching Result (acc={acc:.2f})')
ax1 = plt.subplot(121)
plt.title('Subgraph 1')
ax1.margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
ax2 = plt.subplot(122)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)
# One connector per subgraph node, pointing to the graph-2 node it was matched
# to: green when the match agrees with X_gt, red otherwise.
for src in range(num_nodes1):
    dst = int(X[src].argmax())
    link_color = "green" if X_gt[src, dst] == 1 else "red"
    link = ConnectionPatch(xyA=pos1[src], xyB=pos2[dst], coordsA="data", coordsB="data",
                           axesA=ax1, axesB=ax2, color=link_color)
    ax2.add_artist(link)

脚本总运行时间: (0 分钟 0.545 秒)