Note
Go to the end to download the full example code.
Numpy Backend Example: Matching Image Keypoints by QAP Solvers
This example shows how to match image keypoints with the graph matching solvers provided by pygmtools.
These solvers follow the formulation of the Quadratic Assignment Problem (QAP) and are usually plug-and-play.
The matched images can be further processed for other downstream tasks.
# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
# Wenzheng Pan <pwz1121@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
import numpy as np # numpy backend
import cv2 as cv
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import scipy.io as sio # for loading .mat file
import scipy.spatial as spa # for Delaunay triangulation
from sklearn.decomposition import PCA as PCAdimReduc
import itertools
from PIL import Image
pygm.set_backend('numpy') # set numpy as backend for pygmtools
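As stated above, these QAP solvers are plug-and-play: the whole pipeline boils down to building an affinity matrix and calling a solver. The following is a minimal sketch on random toy data (not part of the original example) that previews the API calls used in the rest of this example.
# Minimal sketch of the plug-and-play workflow on random toy data (illustrative only)
_rng = np.random.default_rng(0)
toy_node1, toy_node2 = _rng.random((4, 8)), _rng.random((4, 8))  # two toy graphs, 4 nodes with 8-dim features
toy_A1 = toy_A2 = np.ones((4, 4)) - np.eye(4)                    # fully-connected adjacency matrices
toy_conn1, toy_edge1 = pygm.utils.dense_to_sparse(toy_A1)
toy_conn2, toy_edge2 = pygm.utils.dense_to_sparse(toy_A2)
toy_K = pygm.utils.build_aff_mat(toy_node1, toy_edge1, toy_conn1, toy_node2, toy_edge2, toy_conn2)
toy_X = pygm.hungarian(pygm.rrwm(toy_K, 4, 4))                   # soft matching -> discrete matching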
Load the Images
The images are from the Willow Object Class dataset (this dataset is also available with the benchmark of pygmtools, see WillowObject).
The images are resized to 256x256.
obj_resize = (256, 256)
img1 = Image.open('../data/willow_duck_0001.png')
img2 = Image.open('../data/willow_duck_0002.png')
kpts1 = np.array(sio.loadmat('../data/willow_duck_0001.mat')['pts_coord'])
kpts2 = np.array(sio.loadmat('../data/willow_duck_0002.mat')['pts_coord'])
kpts1[0] = kpts1[0] * obj_resize[0] / img1.size[0]
kpts1[1] = kpts1[1] * obj_resize[1] / img1.size[1]
kpts2[0] = kpts2[0] * obj_resize[0] / img2.size[0]
kpts2[1] = kpts2[1] * obj_resize[1] / img2.size[1]
img1 = img1.resize(obj_resize, resample=Image.BILINEAR)
img2 = img2.resize(obj_resize, resample=Image.BILINEAR)
Visualize the Images and Keypoints
def plot_image_with_graph(img, kpt, A=None):
    plt.imshow(img)
    plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
    if A is not None:
        for x, y in zip(np.nonzero(A)[0], np.nonzero(A)[1]):
            plt.plot((kpt[0, x], kpt[0, y]), (kpt[1, x], kpt[1, y]), 'k-')
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1')
plot_image_with_graph(img1, kpts1)
plt.subplot(1, 2, 2)
plt.title('Image 2')
plot_image_with_graph(img2, kpts2)

Build the Graphs
The graph structures are built based on the geometric structure of the keypoint sets. In this example, we rely on Delaunay triangulation.
def delaunay_triangulation(kpt):
    d = spa.Delaunay(kpt.T)
    A = np.zeros((len(kpt[0]), len(kpt[0])))
    for simplex in d.simplices:
        for pair in itertools.permutations(simplex, 2):
            A[pair] = 1
    return A
A1 = delaunay_triangulation(kpts1)
A2 = delaunay_triangulation(kpts2)
We encode the lengths of the edges as edge features.
A1 = ((np.expand_dims(kpts1, 1) - np.expand_dims(kpts1, 2)) ** 2).sum(axis=0) * A1
A1 = (A1 / A1.max()).astype(np.float32)
A2 = ((np.expand_dims(kpts2, 1) - np.expand_dims(kpts2, 2)) ** 2).sum(axis=0) * A2
A2 = (A2 / A2.max()).astype(np.float32)
Visualize the Graphs
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with Graphs')
plot_image_with_graph(img1, kpts1, A1)
plt.subplot(1, 2, 2)
plt.title('Image 2 with Graphs')
plot_image_with_graph(img2, kpts2, A2)

Extract Node Features
Let's adopt SIFT to extract the node features.
np_img1 = np.array(img1, dtype=np.float32)
np_img2 = np.array(img2, dtype=np.float32)
def detect_sift(img):
    sift = cv.SIFT_create()
    gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
    img8bit = cv.normalize(gray, None, 0, 255, cv.NORM_MINMAX).astype('uint8')
    kpt = sift.detect(img8bit, None)
    kpt, feat = sift.compute(img8bit, kpt)
    return kpt, feat
sift_kpts1, feat1 = detect_sift(np_img1)
sift_kpts2, feat2 = detect_sift(np_img2)
sift_kpts1 = np.round(cv.KeyPoint_convert(sift_kpts1).T).astype(int)
sift_kpts2 = np.round(cv.KeyPoint_convert(sift_kpts2).T).astype(int)
Normalize the Features
num_features = feat1.shape[1]
feat1 = feat1 / np.expand_dims(np.linalg.norm(feat1, axis=1), 1).repeat(num_features, axis=1)
feat2 = feat2 / np.expand_dims(np.linalg.norm(feat2, axis=1), 1).repeat(num_features, axis=1)
Extract node features by nearest-neighbor interpolation.
rounded_kpts1 = np.round(kpts1).astype(int)
rounded_kpts2 = np.round(kpts2).astype(int)
idx_1, idx_2 = [], []
for i in range(rounded_kpts1.shape[1]):
    # for each annotated keypoint, find the detected SIFT keypoint closest along y,
    # then, among keypoints with that y coordinate, the one closest along x
    y1 = np.where(sift_kpts1[1] == sift_kpts1[1][np.abs(sift_kpts1[1] - rounded_kpts1[1][i]).argmin()])
    y2 = np.where(sift_kpts2[1] == sift_kpts2[1][np.abs(sift_kpts2[1] - rounded_kpts2[1][i]).argmin()])
    t1 = sift_kpts1[0][y1]
    t2 = sift_kpts2[0][y2]
    x1 = np.where(sift_kpts1[0] == t1[np.abs(t1 - rounded_kpts1[0][i]).argmin()])
    x2 = np.where(sift_kpts2[0] == t2[np.abs(t2 - rounded_kpts2[0][i]).argmin()])
    idx_1.append(np.intersect1d(x1, y1)[0])
    idx_2.append(np.intersect1d(x2, y2)[0])
node1 = feat1[idx_1, :]  # shape: NxC
node2 = feat2[idx_2, :]  # shape: NxC
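The lookup above can also be written with a KD-tree over the detected SIFT keypoints. This is an alternative sketch rather than the original code; it uses true Euclidean nearest neighbors and may therefore pick slightly different indices than the per-axis search above.
# Alternative: Euclidean nearest-neighbor lookup via a KD-tree (illustrative only)
_, nn_idx1 = spa.cKDTree(sift_kpts1.T).query(rounded_kpts1.T)  # nearest detected keypoint per annotated point
_, nn_idx2 = spa.cKDTree(sift_kpts2.T).query(rounded_kpts2.T)
node1_alt = feat1[nn_idx1, :]  # shape: NxC, analogous to node1 above
node2_alt = feat2[nn_idx2, :]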
Build Affinity Matrix
We follow the formulation of the Quadratic Assignment Problem (QAP):

\[\begin{aligned}
&\max_{\mathbf{X}} \ \texttt{vec}(\mathbf{X})^\top \mathbf{K} \texttt{vec}(\mathbf{X})\\
s.t. \quad &\mathbf{X} \in \{0, 1\}^{n_1\times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}
\end{aligned}\]

where the first step is to build the affinity matrix (\(\mathbf{K}\)).
conn1, edge1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2 = pygm.utils.dense_to_sparse(A2)
import functools
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1) # set affinity function
K = pygm.utils.build_aff_mat(node1, edge1, conn1, node2, edge2, conn2, edge_aff_fn=gaussian_aff)
Visualization of the affinity matrix. For a graph matching problem with \(N\) nodes, the affinity matrix has \(N^2\times N^2\) elements because there are \(N^2\) edges in each graph.
Note
The diagonal elements of the affinity matrix are the node affinities, and the off-diagonal elements are the edge affinities.
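As a quick sanity check (an added sketch, not in the original example), the affinity matrix built above should be square with side n1 * n2, i.e. one row and one column per candidate node correspondence:
n1, n2 = kpts1.shape[1], kpts2.shape[1]
assert K.shape == (n1 * n2, n1 * n2)  # here n1 == n2 == N, so K is N^2 x N^2
print(K.shape)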
plt.figure(figsize=(4, 4))
plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
plt.imshow(K, cmap='Blues')

Solve Graph Matching Problem by RRWM Solver
See rrwm() for the API reference.
X = pygm.rrwm(K, kpts1.shape[1], kpts2.shape[1])
The output of RRWM is a soft matching matrix. The Hungarian algorithm is then adopted to reach a discrete matching matrix.
X = pygm.hungarian(X)
Plot the Matching
The correct matchings are marked in green and the wrong matchings in red. In this example, the nodes are ordered by their ground-truth classes (i.e. the ground-truth matching matrix is a diagonal matrix).
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by RRWM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = np.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
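Since the nodes are ordered so that the ground-truth matching matrix is diagonal, the matching accuracy can also be reported numerically (a small added sketch, not part of the original example):
acc_rrwm = (X * np.eye(X.shape[0])).sum() / X.shape[0]  # fraction of keypoints matched to themselves
print(f'RRWM matching accuracy: {acc_rrwm:.2%}')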

Solve by Other Solvers
We could also do a quick benchmarking of other solvers on this specific problem.
IPFP Solver
See ipfp() for the API reference.
X = pygm.ipfp(K, kpts1.shape[1], kpts2.shape[1])
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by IPFP')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = np.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)

/home/wzever/pygmtools/pygmtools/numpy_backend.py:304: RuntimeWarning: invalid value encountered in divide
t0 = alpha / beta
SM Solver
See sm() for the API reference.
X = pygm.sm(K, kpts1.shape[1], kpts2.shape[1])
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by SM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = np.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)

NGM Solver
See ngm() for the API reference.
Note
The NGM solvers are pretrained on a different problem setting, so their performance may seem inferior. To improve their performance, you may change the way of building the affinity matrix, or try to fine-tune NGM on the new problem.
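For instance, one way to change the affinity construction (a hedged sketch, not part of the original example; the affinity functions come from pygm.utils as used earlier) is to pick a different Gaussian bandwidth for the edge affinities and an explicit inner-product affinity for the node features, and then feed the new matrix to NGM:
alt_gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=0.1)  # different bandwidth than sigma=1
K_alt = pygm.utils.build_aff_mat(node1, edge1, conn1, node2, edge2, conn2,
                                 node_aff_fn=pygm.utils.inner_prod_aff_fn,
                                 edge_aff_fn=alt_gaussian_aff)
# X_alt = pygm.hungarian(pygm.ngm(K_alt, kpts1.shape[1], kpts2.shape[1], pretrain='willow'))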
The NGM solver pretrained on the Willow dataset:
X = pygm.ngm(K, kpts1.shape[1], kpts2.shape[1], pretrain='willow')
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by NGM (willow pretrain)')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = np.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)

The NGM solver pretrained on the VOC dataset:
X = pygm.ngm(K, kpts1.shape[1], kpts2.shape[1], pretrain='voc')
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by NGM (voc pretrain)')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = np.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)

Total running time of the script: (0 minutes 2.166 seconds)