PyTorch 后端示例:多图匹配

此示例展示了如何匹配多个图。多图匹配意味着同时匹配两个以上的图。

# Author: Zetian Jiang <maple_jzt@sjtu.edu.cn>
#         Ziao Guo <ziao.guo@sjtu.edu.cn>
#
# License: Mulan PSL v2 License

注意

以下求解器包含在此示例中:

import os
import time
import math
import copy
import torch  # pytorch backend
import itertools
import numpy as np
import pygmtools as pygm
import matplotlib.pyplot as plt  # for plotting
import scipy.io as sio  # for loading .mat file
import scipy.spatial as spa  # for Delaunay triangulation

from PIL import Image
from matplotlib.patches import ConnectionPatch # for plotting matching result

pygm.set_backend('pytorch')  # set default backend for pygmtools

加载图片

图像来自Willow对象类数据集(该数据集也可在pygmtools的基准测试中使用,参见WillowObject)。

图像被调整为256x256。

obj_resize = (256, 256)
n_images = 30
n_outlier = 0
img_list = []
kpts_list = []
n_kpts_list = []
perm_list = []

bm = pygm.benchmark.Benchmark(name='WillowObject',
                              sets='train',
                              obj_resize=obj_resize)

while len(img_list) < n_images:
    data_list, gt_dict, _ = bm.rand_get_data(cls='Car')
    for data in data_list:
        img = Image.fromarray(data['img'])
        coords = sorted(data['kpts'], key=lambda x: x['labels'])
        kpts = torch.tensor([[kpt['x'] for kpt in coords],
                        [kpt['y'] for kpt in coords]])
        perm = np.eye(kpts.shape[1])
        img_list.append(img)
        kpts_list.append(kpts)
        n_kpts_list.append(kpts.shape[1])
        perm_list.append(perm)

可视化图像和关键点

def plot_image_with_graph(img, kpt, A=None):
    plt.imshow(img)
    plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
    if A is not None:
        for idx in torch.nonzero(A, as_tuple=False):
            plt.plot((kpt[0, idx[0]], kpt[0, idx[1]]), (kpt[1, idx[0]], kpt[1, idx[1]]), 'k-')


plt.figure(figsize=(20, 18))
for i in range(n_images):
    plt.subplot(5, n_images // 5, i + 1)
    plt.title('Image {}'.format(i + 1))
    plot_image_with_graph(img_list[i], kpts_list[i])
# plt.savefig('image')
# plt.close()
Image 1, Image 2, Image 3, Image 4, Image 5, Image 6, Image 7, Image 8, Image 9, Image 10, Image 11, Image 12, Image 13, Image 14, Image 15, Image 16, Image 17, Image 18, Image 19, Image 20, Image 21, Image 22, Image 23, Image 24, Image 25, Image 26, Image 27, Image 28, Image 29, Image 30

构建图表

图结构是基于关键点集的几何结构构建的。在这个例子中,我们参考了Delaunay三角剖分

def delaunay_triangulation(kpt):
    d = spa.Delaunay(kpt.numpy().transpose())
    A = torch.zeros(len(kpt[0]), len(kpt[0]))
    for simplex in d.simplices:
        for pair in itertools.permutations(simplex, 2):
            A[pair] = 1
    return A


adj_list = []
for i in range(n_images):
    A = delaunay_triangulation(kpts_list[i])
    adj_list.append(A)

构建亲和矩阵

我们遵循二次分配问题(QAP)的公式:

\[\begin{split}&\max_{\mathbf{X}} \ \texttt{vec}(\mathbf{X})^\top \mathbf{K} \texttt{vec}(\mathbf{X})\\ s.t. \quad &\mathbf{X} \in \{0, 1\}^{n_1\times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}\end{split}\]

第一步是为每对图构建亲和矩阵(\(\mathbf{K}\)

def get_feature(n, points, adj):
    """
    :param n: points # of graph
    :param points: torch tensor, (n, 2)
    :param adj: torch tensor, (n, n)
    :return: edge feat, angle feat
    """
    points_1 = points.reshape(n, 1, 2).repeat(1, n, 1)
    points_2 = points.reshape(1, n, 2).repeat(n, 1, 1)
    edge_feat = torch.sqrt(torch.sum((points_1 - points_2) ** 2, dim=2))
    edge_feat = edge_feat / torch.max(edge_feat)
    angle_feat = torch.atan((points_1[:, :, 1] - points_2[:, :, 1]) / (points_1[:, :, 0] - points_2[:, :, 0] + 1e-8))
    angle_feat = 2 * angle_feat / math.pi

    return edge_feat, angle_feat


def get_pair_affinity(edge_feat_1, angle_feat_1, edge_feat_2, angle_feat_2, adj1, adj2):
    n1, n2 = edge_feat_1.shape[0], edge_feat_2.shape[0]
    assert n1 == angle_feat_1.shape[0] and n2 == angle_feat_2.shape[0]

    left_adj = adj1.reshape(n1, n1, 1, 1).repeat(1, 1, n2, n2)
    right_adj = adj2.reshape(1, 1, n2, n2).repeat(n1, n1, 1, 1)
    adj = left_adj * right_adj

    left_edge_feat = edge_feat_1.reshape(n1, n1, 1, 1, -1).repeat(1, 1, n2, n2, 1)
    right_edge_feat = edge_feat_2.reshape(1, 1, n2, n2, -1).repeat(n1, n1, 1, 1, 1)
    edge_weight = torch.sqrt(torch.sum((left_edge_feat - right_edge_feat) ** 2, dim=-1))

    left_angle_feat = angle_feat_1.reshape(n1, n1, 1, 1, -1).repeat(1, 1, n2, n2, 1)
    right_angle_feat = angle_feat_2.reshape(1, 1, n2, n2, -1).repeat(n1, n1, 1, 1, 1)
    angle_weight = torch.sqrt(torch.sum((left_angle_feat - right_angle_feat) ** 2, dim=-1))

    affinity = edge_weight * 0.9 + angle_weight * 0.1
    affinity = torch.exp(-affinity / 0.1) * adj
    affinity = affinity.transpose(1, 2)

    return affinity


def generate_affinity_matrix(n_points, points_list, adj_list):
    m = len(n_points)
    n_max = max(n_points)
    affinity = torch.zeros(m, m, n_max, n_max, n_max, n_max)

    edge_feat_list = []
    angle_feat_list = []
    for n, points, adj in zip(n_points, points_list, adj_list):
        edge_feat, angle_feat = get_feature(n, points, adj)
        edge_feat_list.append(edge_feat)
        angle_feat_list.append(angle_feat)

    for i, j in itertools.product(range(m), range(m)):
        pair_affinity = get_pair_affinity(edge_feat_list[i],
                                          angle_feat_list[i],
                                          edge_feat_list[j],
                                          angle_feat_list[j],
                                          adj_list[i],
                                          adj_list[j])
        affinity[i, j] = pair_affinity

    affinity = affinity.permute(0, 1, 3, 2, 5, 4).reshape(m, m, n_max * n_max, n_max * n_max)
    return affinity


affinity_mat = generate_affinity_matrix(n_kpts_list, kpts_list, adj_list)

m = len(kpts_list)
n = int(torch.max(torch.tensor(n_kpts_list)))
ns_src = torch.ones(m * m).int() * n
ns_tgt = torch.ones(m * m).int() * n

计算准确性、一致性和亲和力

def cal_accuracy(mat, gt_mat, n):
    m = mat.shape[0]
    acc = 0
    for i in range(m):
        for j in range(m):
            _mat, _gt_mat = mat[i, j], gt_mat[i, j]
            row_sum = torch.sum(_gt_mat, dim=0)
            col_sum = torch.sum(_gt_mat, dim=1)
            row_idx = [k for k in range(n) if row_sum[k] != 0]
            col_idx = [k for k in range(n) if col_sum[k] != 0]
            _mat = _mat[row_idx, :]
            _mat = _mat[:, col_idx]
            _gt_mat = _gt_mat[row_idx, :]
            _gt_mat = _gt_mat[:, col_idx]
            acc += 1 - torch.sum(torch.abs(_mat - _gt_mat)) / 2 / (n - n_outlier)
    return acc / (m * m)


def cal_consistency(mat, gt_mat, m, n):
    return torch.mean(get_batch_pc_opt(mat))


def cal_affinity(X, X_gt, K, m, n):
    X_batch = X.reshape(-1, n, n)
    X_gt_batch = X_gt.reshape(-1, n, n)
    K_batch = K.reshape(-1, n * n, n * n)
    affinity = get_batch_affinity(X_batch, K_batch)
    affinity_gt = get_batch_affinity(X_gt_batch, K_batch)
    return torch.mean(affinity / (affinity_gt + 1e-8))


def get_batch_affinity(X, K, norm=1):
    """
    calculate affinity score
    :param X: (b, n, n)
    :param K: (b, n*n, n*n)
    :param norm: normalization term
    :return: affinity_score (b, 1, 1)
    """
    b, n, _ = X.size()
    vx = X.transpose(1, 2).reshape(b, -1, 1)  # (b, n*n, 1)
    vxt = vx.transpose(1, 2)  # (b, 1, n*n)
    affinity = torch.bmm(torch.bmm(vxt, K), vx) / norm
    return affinity


def get_single_affinity(X, K, norm=1):
    """
    calculate affinity score
    :param X: (n, n)
    :param K: (n*n, n*n)
    :param norm: normalization term
    :return: affinity_score scale
    """
    n, _ = X.size()
    vx = X.transpose(0, 1).reshape(-1, 1)
    vxt = vx.transpose(0, 1)
    affinity = torch.matmul(torch.matmul(vxt, K), vx) / norm
    return affinity


def get_single_pc(X, i, j, Xij=None):
    """
    :param X: (m, m, n, n) all the matching results
    :param i: index
    :param j: index
    :param Xij: (n, n) matching
    :return: the consistency of X_ij
    """
    m, _, n, _ = X.size()
    if Xij is None:
        Xij = X[i, j]
    pair_con = 0
    for k in range(m):
        X_combo = torch.matmul(X[i, k], X[k, j])
        pair_con += torch.sum(torch.abs(Xij - X_combo)) / (2 * n)
    return 1 - pair_con / m


def get_single_pc_opt(X, i, j, Xij=None):
    """
    :param X: (m, m, n, n) all the matching results
    :param i: index
    :param j: index
    :return: the consistency of X_ij
    """
    m, _, n, _ = X.size()
    if Xij is None:
        Xij = X[i, j]
    X1 = X[i, :].reshape(-1, n, n)
    X2 = X[:, j].reshape(-1, n, n)
    X_combo = torch.bmm(X1, X2)
    pair_con = 1 - torch.sum(torch.abs(Xij - X_combo)) / (2 * n * m)
    return pair_con


def get_batch_pc(X):
    """
    :param X: (m, m, n, n) all the matching results
    :return: (m, m) the consistency of X
    """
    pair_con = torch.zeros(m, m).cuda()
    for i in range(m):
        for j in range(m):
            pair_con[i, j] = get_single_pc_opt(X, i, j)
    return pair_con


def get_batch_pc_opt(X):
    """
    :param X: (m, m, n, n) all the matching results
    :return: (m, m) the consistency of X
    """
    m, _, n, _ = X.size()
    X1 = X.reshape(m, 1, m, n, n).repeat(1, m, 1, 1, 1).reshape(-1, n, n)  # X1[i, j, k] = X[i, k]
    X2 = X.reshape(1, m, m, n, n).repeat(m, 1, 1, 1, 1).transpose(1, 2).reshape(-1, n, n)  # X2[i, j, k] = X[k, j]
    X_combo = torch.bmm(X1, X2).reshape(m, m, m, n, n)
    X_ori = X.reshape(m, m, 1, n, n).repeat(1, 1, m, 1, 1)
    pair_con = 1 - torch.sum(torch.abs(X_combo - X_ori), dim=(2, 3, 4)) / (2 * n * m)
    return pair_con


def eval(mat, gt_mat, affinity, m, n):
    acc = cal_accuracy(mat, gt_mat, n)
    src = cal_affinity(mat, gt_mat, affinity, m, n)
    con = cal_consistency(mat, gt_mat, m, n)
    return acc, src, con

生成gt矩阵

gt_mat = torch.zeros(m, m, n, n)
for i in range(m):
    for j in range(m):
        gt_mat[i, j] = torch.tensor(np.matmul(perm_list[i].transpose(0, 1), perm_list[j]))
# print(perm_list[0])
# print(perm_list[1])
# print(gt_mat[1, 2])
# print(gt_mat[0, 1] - gt_mat[1, 0].transpose(0, 1))

通过RRWM进行成对图匹配

请参阅rrwm()以获取API参考。

a = 0
b = 12
tic = time.time()
rrwm_mat = pygm.classic_solvers.rrwm(affinity_mat.reshape(-1, n * n, n * n), ns_src, ns_tgt)
rrwm_mat = pygm.linear_solvers.hungarian(rrwm_mat)
toc = time.time()
rrwm_mat = rrwm_mat.reshape(m, m, n, n)
rrwm_acc, rrwm_src, rrwm_con = eval(rrwm_mat, gt_mat, affinity_mat, m, n)
rrwm_tim = toc - tic

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by RRWM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = rrwm_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
# plt.savefig("RRWM.png")
# plt.close()
Multi-Graph Matching Result by RRWM

多图匹配通过多图求解器

多图匹配:CAO-M 参见 cao() 以获取API参考。

base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
cao_m_mat = pygm.multi_graph_solvers.cao(affinity_mat, base_mat, mode='memory')
cao_m_mat = pygm.linear_solvers.hungarian(cao_m_mat.reshape(-1, n, n)).reshape(m, m, n, n)
toc = time.time()
cao_m_acc, cao_m_src, cao_m_con = eval(cao_m_mat, gt_mat, affinity_mat, m, n)
cao_m_tim = toc - tic + rrwm_tim

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by CAO-M')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = cao_m_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
# plt.savefig("CAO-M.png")
# plt.close()
Multi-Graph Matching Result by CAO-M

多图匹配:CAO-T 请参阅cao()以获取API参考。

base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
cao_t_mat = pygm.multi_graph_solvers.cao(affinity_mat, base_mat, mode='time')
cao_t_mat = pygm.linear_solvers.hungarian(cao_t_mat.reshape(-1, n, n)).reshape(m, m, n, n)
toc = time.time()
cao_t_acc, cao_t_src, cao_t_con = eval(cao_t_mat, gt_mat, affinity_mat, m, n)
cao_t_tim = toc - tic + rrwm_tim

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by CAO-T')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = cao_t_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
# plt.savefig("CAO-T.png")
# plt.close()
Multi-Graph Matching Result by CAO-T

多图匹配:MGM-Floyd-M 参见mgm_floyd()以获取API参考。

base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
floyd_m_mat = pygm.multi_graph_solvers.mgm_floyd(affinity_mat, base_mat, param_lambda=0.4, mode='memory')
floyd_m_mat = pygm.linear_solvers.hungarian(floyd_m_mat.reshape(-1, n, n)).reshape(m, m, n, n)
toc = time.time()
floyd_m_acc, floyd_m_src, floyd_m_con = eval(floyd_m_mat, gt_mat, affinity_mat, m, n)
floyd_m_tim = toc - tic + rrwm_tim

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by Floyd-M')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = floyd_m_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
# plt.savefig("Floyd-M.png")
# plt.close()
Multi-Graph Matching Result by Floyd-M

多图匹配:MGM-Floyd-T 请参阅mgm_floyd()以获取API参考。

base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
floyd_t_mat = pygm.multi_graph_solvers.mgm_floyd(affinity_mat, base_mat, param_lambda=0.6, mode='time')
floyd_t_mat = pygm.linear_solvers.hungarian(floyd_t_mat.reshape(-1, n, n)).reshape(m, m, n, n)
toc = time.time()
floyd_t_acc, floyd_t_src, floyd_t_con = eval(floyd_t_mat, gt_mat, affinity_mat, m, n)
floyd_t_tim = toc - tic + rrwm_tim

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by Floyd-T')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = floyd_t_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
# plt.savefig("Floyd-T.png")
# plt.close()
Multi-Graph Matching Result by Floyd-T

脚本总运行时间: (0 分钟 22.959 秒)

Gallery generated by Sphinx-Gallery