Paddle 后端示例:通过 QAP 求解器匹配图像关键点

此示例展示了如何使用pygmtools提供的图匹配求解器来匹配图像关键点。这些求解器基于二次分配问题(QAP)的数学形式,通常可以即插即用。匹配后的图像可以进一步处理,以用于其他下游任务。

# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#         Wenzheng Pan <pwz1121@sjtu.edu.cn>
#
# License: Mulan PSL v2 License

注意

以下求解器支持QAP公式,并包含在此示例中:rrwm()、ipfp()、sm() 以及 ngm()。

import paddle # pypaddle backend
from paddle.vision.models import vgg16
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import scipy.io as sio # for loading .mat file
import scipy.spatial as spa # for Delaunay triangulation
from sklearn.decomposition import PCA as PCAdimReduc
import itertools
import numpy as np
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
pygm.set_backend('paddle') # set default backend for pygmtools

加载图片

图像来自Willow对象类数据集(该数据集也可在pygmtools的基准测试中使用,参见WillowObject)。

图像被调整为256x256。

# Target size that both images are resized to.
obj_resize = (256, 256)
img1 = Image.open('../data/willow_duck_0001.png')
img2 = Image.open('../data/willow_duck_0002.png')
# Keypoint annotations come from the 'pts_coord' entry of each .mat file.
# Row 0 is treated as the x coordinate and row 1 as the y coordinate below.
kpts1 = paddle.to_tensor(sio.loadmat('../data/willow_duck_0001.mat')['pts_coord'])
kpts2 = paddle.to_tensor(sio.loadmat('../data/willow_duck_0002.mat')['pts_coord'])
# Rescale the keypoints from the original image size to the resized image.
# This must happen before the images are resized, because img.size is still
# the original size at this point.
kpts1[0] = kpts1[0] * obj_resize[0] / img1.size[0]
kpts1[1] = kpts1[1] * obj_resize[1] / img1.size[1]
kpts2[0] = kpts2[0] * obj_resize[0] / img2.size[0]
kpts2[1] = kpts2[1] * obj_resize[1] / img2.size[1]
# Resize the images themselves with bilinear resampling.
img1 = img1.resize(obj_resize, resample=Image.BILINEAR)
img2 = img2.resize(obj_resize, resample=Image.BILINEAR)

可视化图像和关键点

def plot_image_with_graph(img, kpt, A=None):
    """Draw an image with its keypoints and, optionally, its graph edges.

    The drawing goes to the current matplotlib axes.

    Args:
        img: Image to show with ``plt.imshow``.
        kpt: Keypoint coordinates; ``kpt[0]`` holds the x values and
            ``kpt[1]`` the y values.
        A: Optional matrix over the keypoints. Every nonzero entry
            ``A[i, j]`` is drawn as a black line from keypoint ``i`` to
            keypoint ``j``. When ``None``, only the keypoints are drawn.
    """
    plt.imshow(img)
    plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
    if A is None:
        return
    for edge in paddle.nonzero(A, as_tuple=False):
        src, dst = edge[0], edge[1]
        xs = (kpt[0, src], kpt[0, dst])
        ys = (kpt[1, src], kpt[1, dst])
        plt.plot(xs, ys, 'k-')

# Show both images side by side with their keypoints (no graph edges yet).
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1')
plot_image_with_graph(img1, kpts1)
plt.subplot(1, 2, 2)
plt.title('Image 2')
plot_image_with_graph(img2, kpts2)
Image 1, Image 2

构建图

图结构是基于关键点集的几何结构构建的。在这个例子中,我们使用Delaunay三角剖分来构建图。

def delaunay_triangulation(kpt):
    """Build an adjacency matrix from the Delaunay triangulation of keypoints.

    Args:
        kpt: Paddle tensor of keypoint coordinates with one column per
            keypoint. It is transposed to one row per point before being
            passed to ``scipy.spatial.Delaunay``.

    Returns:
        A square paddle tensor with one row and column per keypoint. The
        entry ``[i, j]`` is 1 when keypoints ``i`` and ``j`` are two
        different vertices of a common simplex, and 0 otherwise.
    """
    num_nodes = len(kpt[0])
    adjacency = paddle.zeros((num_nodes, num_nodes))
    triangulation = spa.Delaunay(kpt.numpy().transpose())
    for simplex in triangulation.simplices:
        # Ordered pairs are generated in both directions, so the matrix is symmetric.
        for src, dst in itertools.permutations(simplex, 2):
            adjacency[src, dst] = 1
    return adjacency

# Build one adjacency matrix per image from its keypoints.
A1 = delaunay_triangulation(kpts1)
A2 = delaunay_triangulation(kpts2)

我们将边的长度编码为边特征

# unsqueeze(1) and unsqueeze(2) broadcast the keypoints against each other, so
# the difference holds all pairwise coordinate offsets. Squaring and summing
# over axis 0 (the coordinate axis) gives pairwise squared distances.
# Multiplying by the adjacency matrix keeps only the entries that are edges.
# Note that the stored value is the squared edge length, not the length itself.
A1 = ((kpts1.unsqueeze(1) - kpts1.unsqueeze(2)) ** 2).sum(axis=0) * A1
# Normalize by the largest entry so the maximum becomes 1, then cast to float32.
A1 = (A1 / A1.max()).cast(dtype=paddle.float32)
# Same construction for the second graph.
A2 = ((kpts2.unsqueeze(1) - kpts2.unsqueeze(2)) ** 2).sum(axis=0) * A2
A2 = (A2 / A2.max()).cast(dtype=paddle.float32)

可视化图

# Show both images again, this time with the graph edges drawn from the
# nonzero entries of A1 / A2.
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with Graphs')
plot_image_with_graph(img1, kpts1, A1)
plt.subplot(1, 2, 2)
plt.title('Image 2 with Graphs')
plot_image_with_graph(img2, kpts2, A2)
Image 1 with Graphs, Image 2 with Graphs

提取节点特征

让我们采用VGG16 CNN模型来提取节点特征。

# Build VGG16 with batch normalization but without weights; the weights are
# downloaded separately below and loaded with set_dict.
vgg16_cnn = vgg16(pretrained=False, batch_norm=True) # no official pretrained paddle weight for vgg16_bn provided yet
# Download the weight file; the md5 value is passed along with the URL.
path = pygm.utils.download(filename='vgg16_bn.pdparams', \
                           url='https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=11AGmtBrIZJLXJMk4Um9xQPai2EH7KjRY', \
                           md5='cf6079f3c8d16f42a93fc8f8b62e20d1')
vgg16_cnn.set_dict(paddle.load(path))
# NOTE(review): the model is never switched to evaluation mode
# (vgg16_cnn.eval()) before inference. With batch_norm=True this presumably
# means batch statistics are used in the forward pass - confirm this is intended.
# Convert each HxWxC image to a float tensor scaled by 1/256, move channels
# first, and add a batch dimension.
paddle_img1 = paddle.to_tensor(np.array(img1, dtype=np.float32) / 256).transpose((2, 0, 1)).unsqueeze(0) # shape: BxCxHxW
paddle_img2 = paddle.to_tensor(np.array(img2, dtype=np.float32) / 256).transpose((2, 0, 1)).unsqueeze(0) # shape: BxCxHxW
# Gradients are not needed for feature extraction.
with paddle.set_grad_enabled(False):
    # Only the `features` sub-module is run, not the whole model.
    feat1 = vgg16_cnn.features(paddle_img1)
    feat2 = vgg16_cnn.features(paddle_img2)

标准化特征

# Number of feature channels produced by the CNN (axis 1 of BxCxHxW).
num_features = feat1.shape[1]
def l2norm(node_feat):
    """Normalize a feature map across its channel axis.

    Uses local response normalization with a window of twice the channel
    count, ``alpha`` equal to the window size, ``beta=0.5`` and ``k=0``.
    Under Paddle's documented LRN formula this should amount to dividing
    each position's feature vector by its L2 norm over the channels.

    Args:
        node_feat: Feature tensor whose axis 1 is the channel axis.

    Returns:
        The normalized tensor returned by ``local_response_norm``.
    """
    window = node_feat.shape[1] * 2
    return paddle.nn.functional.local_response_norm(
        node_feat, window, alpha=window, beta=0.5, k=0)

# Normalize the features of both images.
feat1 = l2norm(feat1)
feat2 = l2norm(feat2)

将特征上采样到原始图像大小

# Bilinearly upsample the feature maps to the resized image resolution.
# The target size is given as (obj_resize[1], obj_resize[0]), i.e. in the
# reverse order of obj_resize.
feat1_upsample = paddle.nn.functional.interpolate(feat1, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat2_upsample = paddle.nn.functional.interpolate(feat2, (obj_resize[1], obj_resize[0]), mode='bilinear')

可视化提取的CNN特征(通过主成分分析进行降维)

# Reduce the features to 3 components so they can be displayed as a 3-channel image.
pca_dim_reduc = PCAdimReduc(n_components=3, whiten=True)
# Fit a single PCA on the pixels of both images stacked together. Each
# feature map is moved to channels-last and flattened to one row per pixel.
feat_dim_reduc = pca_dim_reduc.fit_transform(
    np.concatenate((
        feat1_upsample.transpose((0, 2, 3, 1)).reshape((-1, num_features)).numpy(),
        feat2_upsample.transpose((0, 2, 3, 1)).reshape((-1, num_features)).numpy()
    ), axis=0)
)
# Divide each component by its largest absolute value (range [-1, 1]), then
# halve and shift by 0.5 so the values fall in [0, 1].
feat_dim_reduc = feat_dim_reduc / np.max(np.abs(feat_dim_reduc), axis=0, keepdims=True) / 2 + 0.5
# Split the stacked rows back: the first obj_resize[0] * obj_resize[1] rows
# belong to image 1, the rest to image 2.
feat1_dim_reduc = feat_dim_reduc[:obj_resize[0] * obj_resize[1], :]
feat2_dim_reduc = feat_dim_reduc[obj_resize[0] * obj_resize[1]:, :]

# Overlay the PCA-reduced features, half transparent, on top of each image
# and its graph.
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with CNN features')
plot_image_with_graph(img1, kpts1, A1)
# Reshape the per-pixel rows back into a 3-channel image.
plt.imshow(feat1_dim_reduc.reshape((obj_resize[1], obj_resize[0], 3)), alpha=0.5)
plt.subplot(1, 2, 2)
plt.title('Image 2 with CNN features')
plot_image_with_graph(img2, kpts2, A2)
plt.imshow(feat2_dim_reduc.reshape((obj_resize[1], obj_resize[0], 3)), alpha=0.5)
Image 1 with CNN features, Image 2 with CNN features
<matplotlib.image.AxesImage object at 0x7feb905a6ec0>

通过最近邻插值提取节点特征

# Round the keypoint coordinates to integer pixel indices.
rounded_kpts1 = paddle.cast(paddle.round(kpts1), dtype='int64')
rounded_kpts2 = paddle.cast(paddle.round(kpts2), dtype='int64')

# Move the two spatial axes to the front (BxCxHxW -> HxWxBxC), then index
# with (row 1, row 0) of the keypoints, i.e. (y, x), to pick one feature
# vector per keypoint. The trailing [:, 0] selects the single batch entry.
node1 = feat1_upsample.transpose((2, 3, 0, 1))[rounded_kpts1[1], rounded_kpts1[0]][:, 0]
node2 = feat2_upsample.transpose((2, 3, 0, 1))[rounded_kpts2[1], rounded_kpts2[0]][:, 0]

构建亲和矩阵

我们遵循二次分配问题(QAP)的公式:

\[\begin{split}&\max_{\mathbf{X}} \ \texttt{vec}(\mathbf{X})^\top \mathbf{K} \texttt{vec}(\mathbf{X})\\ s.t. \quad &\mathbf{X} \in \{0, 1\}^{n_1\times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}\end{split}\]

其中第一步是构建亲和矩阵(\(\mathbf{K}\))。

# Convert the dense weighted adjacency matrices to the sparse form that
# build_aff_mat takes (presumably edge connectivity plus per-edge features -
# confirm against the pygmtools documentation).
conn1, edge1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2 = pygm.utils.dense_to_sparse(A2)
import functools
# Fix sigma=1 for the Gaussian affinity function used to compare edge features.
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1) # set affinity function
# Only the edge affinity function is overridden; the node affinity uses
# build_aff_mat's default.
K = pygm.utils.build_aff_mat(node1, edge1, conn1, node2, edge2, conn2, edge_aff_fn=gaussian_aff)

亲和矩阵的可视化。对于具有\(N\)个节点的图匹配问题,亲和矩阵有\(N^2\times N^2\)个元素,因为每个图中有\(N^2\)条边。

注意

对角线元素是节点亲和度,非对角线元素是边亲和度。

# Show the affinity matrix as an image; the title reports its two dimensions.
plt.figure(figsize=(4, 4))
plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
plt.imshow(K.numpy(), cmap='Blues')
Affinity Matrix (size: 100$\times$100)
<matplotlib.image.AxesImage object at 0x7feb906ce0b0>

使用RRWM求解器解决图匹配问题

请参阅rrwm()以获取API参考。

# Solve with RRWM; the two integer arguments are the keypoint counts of the
# two graphs (number of columns of kpts1 / kpts2).
X = pygm.rrwm(K, kpts1.shape[1], kpts2.shape[1])

RRWM的输出是一个软匹配矩阵,随后采用匈牙利算法将其转换为离散的匹配矩阵。

# Apply the Hungarian algorithm to the soft RRWM output.
X = pygm.hungarian(X)

绘制匹配

正确的匹配用绿色标记,错误的匹配用红色标记。在这个例子中,节点按照它们的真实类别排序(即真实匹配矩阵是一个对角矩阵)。

# Draw both images with their graphs and connect each keypoint of image 1 to
# its matched keypoint in image 2.
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by RRWM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    # Column with the largest value in row i is the match of keypoint i.
    j = paddle.argmax(X[i]).item()
    # Green when i == j, red otherwise.
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
Image Matching Result by RRWM

通过其他求解器解决

我们也可以在这个特定问题上对其他求解器进行快速基准测试。

IPFP求解器

请参阅ipfp()以获取API参考。

# Solve the same problem with IPFP. Unlike the other solvers in this example,
# no Hungarian step is applied to its output here.
X = pygm.ipfp(K, kpts1.shape[1], kpts2.shape[1])

# Draw both images with their graphs and connect each keypoint of image 1 to
# its matched keypoint in image 2.
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by IPFP')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    # Column with the largest value in row i is the match of keypoint i.
    j = paddle.argmax(X[i]).item()
    # Green when i == j, red otherwise.
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
Image Matching Result by IPFP

SM求解器

请参阅sm()以获取API参考。

# Solve the same problem with SM, then apply the Hungarian algorithm to its output.
X = pygm.sm(K, kpts1.shape[1], kpts2.shape[1])
X = pygm.hungarian(X)

# Draw both images with their graphs and connect each keypoint of image 1 to
# its matched keypoint in image 2.
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by SM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    # Column with the largest value in row i is the match of keypoint i.
    j = paddle.argmax(X[i]).item()
    # Green when i == j, red otherwise.
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
Image Matching Result by SM

NGM 求解器

请参阅ngm()的API参考。

注意

NGM求解器是在不同的问题设置上预训练的,因此它们的性能可能看起来较差。 为了提高它们的性能,您可以改变构建亲和矩阵的方式,或者尝试在新的问题上微调NGM。

NGM求解器在Willow数据集上预训练:

# Solve with NGM using the 'willow' pretrained weights, then apply the
# Hungarian algorithm to its output.
X = pygm.ngm(K, kpts1.shape[1], kpts2.shape[1], pretrain='willow')
X = pygm.hungarian(X)

# Draw both images with their graphs and connect each keypoint of image 1 to
# its matched keypoint in image 2.
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by NGM (willow pretrain)')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    # Column with the largest value in row i is the match of keypoint i.
    j = paddle.argmax(X[i]).item()
    # Green when i == j, red otherwise.
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
Image Matching Result by NGM (willow pretrain)

NGM 求解器在 VOC 数据集上预训练:

# Solve with NGM using the 'voc' pretrained weights, then apply the
# Hungarian algorithm to its output.
X = pygm.ngm(K, kpts1.shape[1], kpts2.shape[1], pretrain='voc')
X = pygm.hungarian(X)

# Draw both images with their graphs and connect each keypoint of image 1 to
# its matched keypoint in image 2.
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by NGM (voc pretrain)')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    # Column with the largest value in row i is the match of keypoint i.
    j = paddle.argmax(X[i]).item()
    # Green when i == j, red otherwise.
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
Image Matching Result by NGM (voc pretrain)

脚本总运行时间: (0 分钟 10.361 秒)

Gallery generated by Sphinx-Gallery