Jittor Backend Example: Matching Image Keypoints by Graph Matching Neural Networks

This example shows how to match image keypoints with neural-network-based graph matching solvers. These graph matching solvers are designed to match two individual graphs. The matched images can be further processed for downstream tasks.

# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#         Wenzheng Pan <pwz1121@sjtu.edu.cn>
#
# License: Mulan PSL v2 License

Note

The following solvers, which are based on matching two individual graphs, are included in this example: pca_gm(), ipca_gm(), and cie().

import jittor as jt # jittor backend
from jittor import Var, models, nn
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import scipy.io as sio # for loading .mat file
import scipy.spatial as spa # for Delaunay triangulation
from sklearn.decomposition import PCA as PCAdimReduc
import itertools
import numpy as np
from PIL import Image
pygm.set_backend('jittor') # set default backend for pygmtools

Predicting Matching by Graph Matching Neural Networks

In this section, we show how to do predictions (inference) with graph matching neural networks. Let's take PCA-GM (pca_gm()) as an example.

Load the images

The images are from the Willow Object Class dataset (this dataset is also available in pygmtools' benchmark, see WillowObject).

The images are resized to 256x256.

obj_resize = (256, 256)
img1 = Image.open('../data/willow_duck_0001.png')
img2 = Image.open('../data/willow_duck_0002.png')
kpts1 = jt.Var(sio.loadmat('../data/willow_duck_0001.mat')['pts_coord'])
kpts2 = jt.Var(sio.loadmat('../data/willow_duck_0002.mat')['pts_coord'])
kpts1[0] = kpts1[0] * obj_resize[0] / img1.size[0]
kpts1[1] = kpts1[1] * obj_resize[1] / img1.size[1]
kpts2[0] = kpts2[0] * obj_resize[0] / img2.size[0]
kpts2[1] = kpts2[1] * obj_resize[1] / img2.size[1]
img1 = img1.resize(obj_resize, resample=Image.BILINEAR)
img2 = img2.resize(obj_resize, resample=Image.BILINEAR)
jittor_img1 = jt.Var(np.array(img1, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
jittor_img2 = jt.Var(np.array(img2, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW

Visualize the images and keypoints

def plot_image_with_graph(img, kpt, A=None):
    plt.imshow(img)
    plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
    if A is not None:
        for idx in jt.nonzero(A):
            plt.plot((kpt[0, idx[0]], kpt[0, idx[1]]), (kpt[1, idx[0]], kpt[1, idx[1]]), 'k-')

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1')
plot_image_with_graph(img1, kpts1)
plt.subplot(1, 2, 2)
plt.title('Image 2')
plot_image_with_graph(img2, kpts2)
(Figure: Image 1, Image 2)

Build the graphs

The graph structures are built based on the geometric structure of the keypoint sets. In this example, we refer to Delaunay triangulation.

def delaunay_triangulation(kpt):
    d = spa.Delaunay(kpt.numpy().transpose())
    A = jt.zeros((len(kpt[0]), len(kpt[0])))
    for simplex in d.simplices:
        for pair in itertools.permutations(simplex, 2):
            A[pair] = 1
    return A

A1 = delaunay_triangulation(kpts1)
A2 = delaunay_triangulation(kpts2)
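
A quick sanity check (a sketch, not part of the original pipeline): since itertools.permutations emits both directions of every simplex edge and never an (i, i) pair, the adjacency matrices built above are symmetric with an empty diagonal.

A1_np = A1.numpy()  # sketch: verify symmetry and zero diagonal of the Delaunay adjacency
assert (A1_np == A1_np.T).all() and (np.diag(A1_np) == 0).all()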

Visualize the graphs

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with Graphs')
plot_image_with_graph(img1, kpts1, A1)
plt.subplot(1, 2, 2)
plt.title('Image 2 with Graphs')
plot_image_with_graph(img2, kpts2, A2)
(Figure: Image 1 with Graphs, Image 2 with Graphs)

Extract node features via CNN

Deep graph matching solvers can be fused with CNN feature extractors to build an end-to-end learning pipeline.

In this example, we adopt deep graph matching solvers based on matching two individual graphs. The image features are based on two intermediate layers of the VGG16 CNN model, following existing deep graph matching papers (e.g. pca_gm()).

Firstly, let's fetch and download the VGG16 model:

vgg16_cnn = models.vgg16_bn(True)

The list of layers of VGG16:

print(vgg16_cnn.features)
Sequential(
    0: Conv(3, 64, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[64,], None, Kw=None, fan=None, i=None, bound=None)
    1: BatchNorm(64, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    2: relu()
    3: Conv(64, 64, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[64,], None, Kw=None, fan=None, i=None, bound=None)
    4: BatchNorm(64, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    5: relu()
    6: Pool((2, 2), (2, 2), padding=(0, 0), dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
    7: Conv(64, 128, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[128,], None, Kw=None, fan=None, i=None, bound=None)
    8: BatchNorm(128, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    9: relu()
    10: Conv(128, 128, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[128,], None, Kw=None, fan=None, i=None, bound=None)
    11: BatchNorm(128, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    12: relu()
    13: Pool((2, 2), (2, 2), padding=(0, 0), dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
    14: Conv(128, 256, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[256,], None, Kw=None, fan=None, i=None, bound=None)
    15: BatchNorm(256, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    16: relu()
    17: Conv(256, 256, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[256,], None, Kw=None, fan=None, i=None, bound=None)
    18: BatchNorm(256, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    19: relu()
    20: Conv(256, 256, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[256,], None, Kw=None, fan=None, i=None, bound=None)
    21: BatchNorm(256, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    22: relu()
    23: Pool((2, 2), (2, 2), padding=(0, 0), dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
    24: Conv(256, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    25: BatchNorm(512, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    26: relu()
    27: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    28: BatchNorm(512, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    29: relu()
    30: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    31: BatchNorm(512, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    32: relu()
    33: Pool((2, 2), (2, 2), padding=(0, 0), dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
    34: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    35: BatchNorm(512, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    36: relu()
    37: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    38: BatchNorm(512, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    39: relu()
    40: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    41: BatchNorm(512, 1e-05, momentum=0.1, affine=True, is_train=True, sync=True)
    42: relu()
    43: Pool((2, 2), (2, 2), padding=(0, 0), dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
)

Let's define the CNN feature extractor, which outputs the features of layer (30) and layer (37).

class CNNNet(jt.nn.Module):
    def __init__(self, vgg16_module):
        super(CNNNet, self).__init__()
        # The naming of the layers follow ThinkMatch convention to load pretrained models.
        self.node_layers = jt.nn.Sequential(*[_ for _ in list(vgg16_module.features)[:31]])
        self.edge_layers = jt.nn.Sequential(*[_ for _ in list(vgg16_module.features)[31:38]])

    def execute(self, inp_img):
        feat_local = self.node_layers(inp_img)
        feat_global = self.edge_layers(feat_local)
        return feat_local, feat_global

Download the pretrained CNN weights (from ThinkMatch), load the weights and then extract the CNN features.

cnn = CNNNet(vgg16_cnn)
path = pygm.utils.download('vgg16_pca_voc_jittor.pt', 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1qLxjcVq7X3brylxRJvELCbtCzfuXQ24J')
cnn.load_state_dict(jt.load(path))

with jt.no_grad():
    feat1_local, feat1_global = cnn(jittor_img1)
    feat2_local, feat2_global = cnn(jittor_img2)

Normalize the features

def local_response_norm(input: Var, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0) -> Var:
    """
    jittor implementation of local_response_norm
    """
    dim = input.ndim
    assert dim >= 3

    if input.numel() == 0:
        return input

    div = input.multiply(input).unsqueeze(1)
    if dim == 3:
        div = nn.pad(div, (0, 0, size // 2, (size - 1) // 2))
        div = nn.avg_pool2d(div, (size, 1), stride=1).squeeze(1)
    else:
        sizes = input.size()
        div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
        div = nn.pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
        div = nn.AvgPool3d((size, 1, 1), stride=1)(div).squeeze(1)
        div = div.view(sizes)
    div = div.multiply(alpha).add(k).pow(beta)
    return input / div


def l2norm(node_feat):
    return local_response_norm(
        node_feat, node_feat.shape[1] * 2, alpha=node_feat.shape[1] * 2, beta=0.5, k=0)

feat1_local = l2norm(feat1_local)
feat1_global = l2norm(feat1_global)
feat2_local = l2norm(feat2_local)
feat2_global = l2norm(feat2_global)
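
With size equal to twice the channel count, alpha equal to size, beta = 0.5 and k = 0, the local response normalization above reduces to a per-pixel L2 normalization across channels. A minimal sketch checking this equivalence on a small random feature map (the toy shape below is only for illustration and not part of the pipeline):

x = jt.random((1, 8, 4, 4))                      # toy feature map
ref = x / jt.sqrt((x * x).sum(1).unsqueeze(1))   # per-pixel channel-wise L2 normalization
print(jt.abs(l2norm(x) - ref).max())             # expected to be close to 0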

Up-sample the features to the original image size and concatenate

feat1_local_upsample = jt.nn.interpolate(feat1_local, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat1_global_upsample = jt.nn.interpolate(feat1_global, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat2_local_upsample = jt.nn.interpolate(feat2_local, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat2_global_upsample = jt.nn.interpolate(feat2_global, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat1_upsample = jt.concat((feat1_local_upsample, feat1_global_upsample), dim=1)
feat2_upsample = jt.concat((feat2_local_upsample, feat2_global_upsample), dim=1)
num_features = feat1_upsample.shape[1]
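
Given the VGG16 layers listed above, the local features (through layer 30) and the global features (through layer 37) both have 512 channels, so the concatenated per-pixel feature is expected to have 1024 dimensions. A quick check (sketch):

print(num_features)  # expected: 1024 = 512 (local) + 512 (global)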

Visualize the extracted CNN features (dimensionality reduction via principal component analysis)

pca_dim_reduc = PCAdimReduc(n_components=3, whiten=True)
feat_dim_reduc = pca_dim_reduc.fit_transform(
    np.concatenate((
        feat1_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy(),
        feat2_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy()
    ), axis=0)
)
feat_dim_reduc = feat_dim_reduc / np.max(np.abs(feat_dim_reduc), axis=0, keepdims=True) / 2 + 0.5
feat1_dim_reduc = feat_dim_reduc[:obj_resize[0] * obj_resize[1], :]
feat2_dim_reduc = feat_dim_reduc[obj_resize[0] * obj_resize[1]:, :]

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with CNN features')
plot_image_with_graph(img1, kpts1, A1)
plt.imshow(feat1_dim_reduc.reshape(obj_resize[1], obj_resize[0], 3), alpha=0.5)
plt.subplot(1, 2, 2)
plt.title('Image 2 with CNN features')
plot_image_with_graph(img2, kpts2, A2)
plt.imshow(feat2_dim_reduc.reshape(obj_resize[1], obj_resize[0], 3), alpha=0.5)
(Figure: Image 1 with CNN features, Image 2 with CNN features)

Extract node features by nearest interpolation

rounded_kpts1 = jt.round(kpts1).long()
rounded_kpts2 = jt.round(kpts2).long()
node1 = feat1_upsample[0, :, rounded_kpts1[1], rounded_kpts1[0]].t() # shape: NxC
node2 = feat2_upsample[0, :, rounded_kpts2[1], rounded_kpts2[0]].t() # shape: NxC
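
Each row of node1/node2 is the feature vector sampled at one keypoint location; a quick shape check (sketch):

print(node1.shape)  # expected: (number of keypoints in image 1, num_features)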

Call the PCA-GM matching model

See pca_gm() for the API reference.

X = pygm.pca_gm(node1, node2, A1, A2, pretrain='voc')
X = pygm.hungarian(X)
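
# Sanity check (a sketch, not part of the original example): after Hungarian
# discretization, X should be a 0/1 permutation matrix whose rows and columns
# each sum to one.
X_np = X.numpy()
assert (X_np.sum(axis=0) == 1).all() and (X_np.sum(axis=1) == 1).all()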

plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by PCA-GM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
idx, _ = jt.argmax(X, dim=1)
for i in range(X.shape[0]):
    j = idx[i].item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
(Figure: Image Matching Result by PCA-GM)

Matching images with other neural networks

The above pipeline also works for other deep graph matching networks. Here we give examples of ipca_gm() and cie().

Matching by the IPCA-GM model

See ipca_gm() for the API reference.

path = pygm.utils.download('vgg16_ipca_voc_jittor.pt', 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1f7KEl9ZFZwI26j6UId-fsdl8Y8QWPKZi')
cnn.load_state_dict(jt.load(path))

feat1_local, feat1_global = cnn(jittor_img1)
feat2_local, feat2_global = cnn(jittor_img2)

Normalize the features

def l2norm(node_feat):
    return local_response_norm(
        node_feat, node_feat.shape[1] * 2, alpha=node_feat.shape[1] * 2, beta=0.5, k=0)

feat1_local = l2norm(feat1_local)
feat1_global = l2norm(feat1_global)
feat2_local = l2norm(feat2_local)
feat2_global = l2norm(feat2_global)

Up-sample the features to the original image size and concatenate

feat1_local_upsample = jt.nn.interpolate(feat1_local, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat1_global_upsample = jt.nn.interpolate(feat1_global, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat2_local_upsample = jt.nn.interpolate(feat2_local, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat2_global_upsample = jt.nn.interpolate(feat2_global, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat1_upsample = jt.concat((feat1_local_upsample, feat1_global_upsample), dim=1)
feat2_upsample = jt.concat((feat2_local_upsample, feat2_global_upsample), dim=1)
num_features = feat1_upsample.shape[1]

Extract node features by nearest interpolation

rounded_kpts1 = jt.round(kpts1).long()
rounded_kpts2 = jt.round(kpts2).long()
node1 = feat1_upsample[0, :, rounded_kpts1[1], rounded_kpts1[0]].t() # shape: NxC
node2 = feat2_upsample[0, :, rounded_kpts2[1], rounded_kpts2[0]].t() # shape: NxC

Build edge features as edge lengths

kpts1_dis = (kpts1.unsqueeze(1) - kpts1.unsqueeze(2))  # pairwise keypoint distances, shape: NxN
kpts1_dis = jt.norm(kpts1_dis, p=2, dim=0).detach()
kpts2_dis = (kpts2.unsqueeze(1) - kpts2.unsqueeze(2))
kpts2_dis = jt.norm(kpts2_dis, p=2, dim=0).detach()

Q1 = jt.exp(-kpts1_dis / obj_resize[0])
Q2 = jt.exp(-kpts2_dis / obj_resize[0])

Call the IPCA-GM matching model

X = pygm.ipca_gm(node1, node2, A1, A2, pretrain='voc')
X = pygm.hungarian(X)

plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by IPCA-GM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
idx, _ = jt.argmax(X, dim=1)
for i in range(X.shape[0]):
    j = idx[i].item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
(Figure: Image Matching Result by IPCA-GM)

Matching by the CIE model

See cie() for the API reference.

path = pygm.utils.download('vgg16_cie_voc_jittor.pt', 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1wDbA-8sK4BNhA48z2c-Gtdd4AarRxfqT')
cnn.load_state_dict(jt.load(path))

feat1_local, feat1_global = cnn(jittor_img1)
feat2_local, feat2_global = cnn(jittor_img2)

Normalize the features

def l2norm(node_feat):
    return local_response_norm(
        node_feat, node_feat.shape[1] * 2, alpha=node_feat.shape[1] * 2, beta=0.5, k=0)

feat1_local = l2norm(feat1_local)
feat1_global = l2norm(feat1_global)
feat2_local = l2norm(feat2_local)
feat2_global = l2norm(feat2_global)

Up-sample the features to the original image size and concatenate

feat1_local_upsample = jt.nn.interpolate(feat1_local, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat1_global_upsample = jt.nn.interpolate(feat1_global, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat2_local_upsample = jt.nn.interpolate(feat2_local, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat2_global_upsample = jt.nn.interpolate(feat2_global, (obj_resize[1], obj_resize[0]), mode='bilinear')
feat1_upsample = jt.concat((feat1_local_upsample, feat1_global_upsample), dim=1)
feat2_upsample = jt.concat((feat2_local_upsample, feat2_global_upsample), dim=1)
num_features = feat1_upsample.shape[1]

Extract node features by nearest interpolation

rounded_kpts1 = jt.round(kpts1).long()
rounded_kpts2 = jt.round(kpts2).long()
node1 = feat1_upsample[0, :, rounded_kpts1[1], rounded_kpts1[0]].t() # shape: NxC
node2 = feat2_upsample[0, :, rounded_kpts2[1], rounded_kpts2[0]].t() # shape: NxC

Build edge features as edge lengths

kpts1_dis = (kpts1.unsqueeze(1) - kpts1.unsqueeze(2))
kpts1_dis = jt.norm(kpts1_dis, p=2, dim=0).detach()
kpts2_dis = (kpts2.unsqueeze(1) - kpts2.unsqueeze(2))
kpts2_dis = jt.norm(kpts2_dis, p=2, dim=0).detach()

Q1 = jt.exp(-kpts1_dis / obj_resize[0]).unsqueeze(-1).float32()
Q2 = jt.exp(-kpts2_dis / obj_resize[0]).unsqueeze(-1).float32()
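
Unlike the IPCA-GM case, cie() consumes edge features with a trailing feature dimension, hence the extra unsqueeze(-1) above. A quick shape check (sketch):

print(Q1.shape)  # expected: (N, N, 1), N being the number of keypoints in image 1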

Call the CIE matching model

X = pygm.cie(node1, node2, A1, A2, Q1, Q2, pretrain='voc')
X = pygm.hungarian(X)

plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by CIE')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
idx, _ = jt.argmax(X, dim=1)
for i in range(X.shape[0]):
    j = idx[i].item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
(Figure: Image Matching Result by CIE)

Training a deep graph matching model

In this section, we show how to build a deep graph matching model that supports end-to-end training. For the image matching problem considered here, the model consists of a CNN feature extractor and a learnable matching module. Take the PCA-GM model as an example.

Note

This simple example is intended to show the basic forward and backward passes when training an end-to-end deep graph matching neural network. A "more formal" deep learning pipeline should include an asynchronous data loader, batched operations, CUDA support, etc., which are omitted here for simplicity. You may refer to ThinkMatch, a research protocol that includes all these advanced features.

Let's firstly define the neural network model. Calling get_network() will simply return the network object.

class GMNet(jt.nn.Module):
    def __init__(self):
        super(GMNet, self).__init__()
        self.gm_net = pygm.utils.get_network(pygm.pca_gm, pretrain=False) # fetch the network object
        self.cnn = CNNNet(vgg16_cnn)

    def execute(self, img1, img2, kpts1, kpts2, A1, A2):
        # CNN feature extractor layers
        feat1_local, feat1_global = self.cnn(img1)
        feat2_local, feat2_global = self.cnn(img2)
        feat1_local = l2norm(feat1_local)
        feat1_global = l2norm(feat1_global)
        feat2_local = l2norm(feat2_local)
        feat2_global = l2norm(feat2_global)

        # upsample feature map
        feat1_local_upsample = jt.nn.interpolate(feat1_local, (obj_resize[1], obj_resize[0]), mode='bilinear')
        feat1_global_upsample = jt.nn.interpolate(feat1_global, (obj_resize[1], obj_resize[0]), mode='bilinear')
        feat2_local_upsample = jt.nn.interpolate(feat2_local, (obj_resize[1], obj_resize[0]), mode='bilinear')
        feat2_global_upsample = jt.nn.interpolate(feat2_global, (obj_resize[1], obj_resize[0]), mode='bilinear')
        feat1_upsample = jt.concat((feat1_local_upsample, feat1_global_upsample), dim=1)
        feat2_upsample = jt.concat((feat2_local_upsample, feat2_global_upsample), dim=1)

        # assign node features
        rounded_kpts1 = jt.round(kpts1).long()
        rounded_kpts2 = jt.round(kpts2).long()
        node1 = feat1_upsample[0, :, rounded_kpts1[1], rounded_kpts1[0]].t()  # shape: NxC
        node2 = feat2_upsample[0, :, rounded_kpts2[1], rounded_kpts2[0]].t()  # shape: NxC

        # PCA-GM matching layers
        X = pygm.pca_gm(node1, node2, A1, A2, network=self.gm_net) # the network object is reused
        return X

model = GMNet()

Define the optimizer

optim = jt.optim.Adam(model.parameters(), lr=1e-3)

Forward pass

X = model(jittor_img1, jittor_img2, kpts1, kpts2, A1, A2)

Compute the loss

In this example, the ground truth matching matrix is a diagonal matrix. We compute the loss function via permutation_loss().

X_gt = jt.init.eye(X.shape[0])
loss = pygm.utils.permutation_loss(X, X_gt)
print(f'loss={loss:.4f}')
loss=2.9790
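
For intuition, permutation_loss() is a cross-entropy style loss between the predicted soft matching matrix and the ground-truth permutation. Below is a rough, illustrative sketch of the idea only; see pygm.utils.permutation_loss() for the exact definition.

# Rough sketch (assumption, not the exact pygmtools implementation): a binary
# cross-entropy between the soft matching X and the 0/1 ground truth X_gt,
# normalized by the number of ground-truth matches.
def permutation_loss_sketch(X, X_gt, eps=1e-9):
    bce = -(X_gt * jt.log(X + eps) + (1 - X_gt) * jt.log(1 - X + eps))
    return bce.sum() / X_gt.sum()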

Backward pass

optim.backward(loss)

Visualize the gradients

plt.figure(figsize=(4, 4))
plt.title('Gradient Sizes of PCA-GM and VGG16 layers')
plt.gca().set_xlabel('Layer Index')
plt.gca().set_ylabel('Average Gradient Size')
grad_size = []
for param in model.parameters():
    grad_size.append(jt.abs(param.opt_grad(optim)).mean().item())
print(grad_size)
plt.stem(grad_size)
(Figure: Gradient Sizes of PCA-GM and VGG16 layers)
[0.00012032000813633204, 0.0029908756259828806, 0.0001890812418423593, 0.003432080615311861, 0.0002175412664655596, 0.005439706612378359, 8.84531982592307e-06, 4.3844393076142296e-05, 9.275232150685042e-05, 0.0038394334260374308, 0.00014068347809370607, 0.0032387105748057365, 0.0004730010114144534, 2.232744122920849e-08, 0.0010210294276475906, 0.0005979313282296062, 0.0, 0.0, 0.0001719535794109106, 1.1543310307615684e-08, 0.0021903959568589926, 0.0012790058972314, 0.0, 0.0, 0.00025605762493796647, 3.661839054203142e-09, 0.0013879203470423818, 0.001030159299261868, 0.0, 0.0, 0.0002153523819288239, 4.659028274289767e-09, 0.0020632531959563494, 0.0009495018748566508, 0.0, 0.0, 0.00020805593521799892, 1.3135081911030966e-09, 0.0017076103249564767, 0.0011711850529536605, 0.0, 0.0, 0.00017217609274666756, 2.7571318561570024e-09, 0.001730420277453959, 0.0011667078360915184, 0.0, 0.0, 0.00018079271831084043, 3.229687406403059e-09, 0.0021298921201378107, 0.001029890263453126, 0.0, 0.0, 0.0001582755212439224, 7.906685861591711e-10, 0.0017488845624029636, 0.0010780903976410627, 0.0, 0.0, 0.0001245360035682097, 1.4797246761233396e-09, 0.00206647627055645, 0.0012257269117981195, 0.0, 0.0, 0.00012214107846375555, 0.0005140944267623127, 0.001602952484972775, 0.0008278230670839548, 0.0, 0.0, 9.437291737413034e-05, 4.0557557312581594e-10, 0.0014863943215459585, 0.0010152073809877038, 0.0, 0.0, 8.532514766557142e-05, 0.000895310309715569]


Update the model parameters. A deep learning pipeline should iterate the forward and backward pass steps until convergence.

optim.step()
optim.zero_grad()
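
Putting the pieces together, a minimal training loop sketch is given below. `training_pairs` is a hypothetical iterable yielding (img1, img2, kpts1, kpts2, A1, A2, X_gt) samples; a real pipeline would add data loading, batching and CUDA support as noted above.

num_epochs = 10  # illustration only
for epoch in range(num_epochs):
    for img1_t, img2_t, kpts1_t, kpts2_t, A1_t, A2_t, X_gt_t in training_pairs:
        X_pred = model(img1_t, img2_t, kpts1_t, kpts2_t, A1_t, A2_t)
        loss = pygm.utils.permutation_loss(X_pred, X_gt_t)
        optim.backward(loss)   # accumulate gradients
        optim.step()           # update parameters
        optim.zero_grad()      # clear gradients for the next iteration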

Note

This example supports both GPU and CPU, and the online documentation is built on a CPU-only machine. Running this code on a GPU will be significantly more efficient.

Total running time of the script: (1 minute 12.772 seconds)

Gallery generated by Sphinx-Gallery