Jittor Backend Example: Model Fusion by Graph Matching

This example shows how to fuse different models into a single model with pygmtools. Model fusion aims to fuse multiple models into one so that the fused model achieves higher performance. A neural network can be regarded as a graph (channels as nodes, update functions between channels as edges; node features are the biases, edge features are the weights), so fusing models is equivalent to solving a graph matching problem. In this example, the given models are trained on MNIST data from different distributions, and the fused model combines the knowledge of the two input models and reaches higher accuracy at test time.

# Author: Chang Liu <only-changer@sjtu.edu.cn>
#         Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#         Wenzheng Pan <pwz1121@sjtu.edu.cn>
#
# License: Mulan PSL v2 License

Note

This is a simplified implementation of the ideas in Liu et al. "Deep Neural Network Fusion via Graph Matching with Applications to Model Ensemble and Federated Learning." ICML 2022. For more details, please refer to the paper and the official code repository.

Note

The following solvers are included in this example:

* pygm.sm (classic solver)
* pygm.hungarian (linear solver)

import jittor as jt
import jittor.nn as nn
from jittor import Module, transform, dataset
import time
from PIL import Image
import matplotlib.pyplot as plt
import pygmtools as pygm

pygm.set_backend('jittor')

Define a simple CNN classifier network

class SimpleNet(Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # self.pad = nn.ReplicationPad2d(1)
        self.conv1 = nn.Conv2d(1, 32, 5, bias=False)
        self.max_pool = nn.MaxPool2d(2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 5, bias=False)
        self.fc1 = nn.Linear(3136, 32, bias=False)
        self.fc2 = nn.Linear(32, 10, bias=False)

    def execute(self, x):
        output = nn.relu(self.conv1(nn.ReplicationPad2d(1)(x)))
        output = self.max_pool(output)
        output = nn.relu(self.conv2(nn.ReplicationPad2d(1)(output)))
        output = self.max_pool(output)
        output = output.view(output.shape[0], -1)
        output = self.fc1(output)
        output = self.fc2(output)
        return output

Load the trained models to be fused

model1 = SimpleNet()
model2 = SimpleNet()
model1.load_state_dict(jt.load('../data/example_model_fusion_1_jittor.dat'))
model2.load_state_dict(jt.load('../data/example_model_fusion_2_jittor.dat'))
test_transform = transform.Compose([
    transform.Gray(),
    transform.ToTensor()
])
test_dataset = dataset.MNIST(
    data_root='../data/mnist_data_jittor/',  # the directory to store the dataset
    train=False,  # the dataset is used to test
    transform=test_transform,  # the dataset is in the form of tensors
    download=True)
test_loader = test_dataset.set_attrs(batch_size=32, shuffle=False)

Print the layers of the simple CNN model:

print(model1)
SimpleNet(
    conv1: Conv(1, 32, (5, 5), (1, 1), (0, 0), (1, 1), 1, None, None, Kw=None, fan=None, i=None, bound=None)
    max_pool: MaxPool2d(
        _layer: Pool((2, 2), (2, 2), padding=(1, 1), dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op=maximum)
    )
    conv2: Conv(32, 64, (5, 5), (1, 1), (0, 0), (1, 1), 1, None, None, Kw=None, fan=None, i=None, bound=None)
    fc1: Linear(3136, 32, None, None)
    fc2: Linear(32, 10, None, None)
)

Test the input models

with jt.no_grad():
    n_correct1 = 0
    n_correct2 = 0
    n_samples = 0
    for images, labels in test_loader:
        outputs1 = model1(images)
        outputs2 = model2(images)
        predictions1, _ = jt.argmax(outputs1, 1)
        predictions2, _ = jt.argmax(outputs2, 1)
        n_samples += labels.shape[0]
        n_correct1 += (predictions1 == labels).sum().item()
        n_correct2 += (predictions2 == labels).sum().item()
    acc1 = 100 * n_correct1 / n_samples
    acc2 = 100 * n_correct2 / n_samples

Testing results (two separate models):

print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
model1 accuracy = 84.18%, model2 accuracy = 83.81%

Build the affinity matrix for graph matching

As shown in the figure below, a neural network can be regarded as a graph. The weights correspond to the edge features, and the biases correspond to the node features. In this example, the neural network has no bias terms, so there are only edge features.

plt.figure(figsize=(8, 4))
img = Image.open('../data/model_fusion.png')
plt.imshow(img)
plt.axis('off')
st_time = time.perf_counter()
[figure: model fusion illustration]
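
To make this graph view concrete, here is a minimal sketch (with toy sizes, independent of SimpleNet above) that reads a fully-connected layer's weight matrix as the edge features of a bipartite graph between input and output channels:

import jittor as jt

# A toy fully-connected layer: 3 input channels -> 2 output channels.
# Graph view: 3 + 2 = 5 nodes (one per channel); each weight w[j, i] is the
# feature of the edge from input channel i to output channel j.
w = jt.randn(2, 3)
for j in range(w.shape[0]):
    for i in range(w.shape[1]):
        print(f'edge in_{i} -> out_{j}: feature = {w[j, i].item():.4f}')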

Define the graph matching affinity metric function

class Ground_Metric_GM:
    def __init__(self,
                 model_1_param: jt.Var = None,
                 model_2_param: jt.Var = None,
                 conv_param: bool = False,
                 bias_param: bool = False,
                 pre_conv_param: bool = False,
                 pre_conv_image_size_squared: int = None):
        self.model_1_param = model_1_param
        self.model_2_param = model_2_param
        self.conv_param = conv_param
        self.bias_param = bias_param
        # bias, or fully-connected from linear
        if bias_param is True or (conv_param is False and pre_conv_param is False):
            self.model_1_param = self.model_1_param.reshape(1, -1, 1)
            self.model_2_param = self.model_2_param.reshape(1, -1, 1)
        # fully-connected from conv
        elif conv_param is False and pre_conv_param is True:
            self.model_1_param = self.model_1_param.reshape(1, -1, pre_conv_image_size_squared)
            self.model_2_param = self.model_2_param.reshape(1, -1, pre_conv_image_size_squared)
        # conv
        else:
            self.model_1_param = self.model_1_param.reshape(1, -1, model_1_param.shape[-1])
            self.model_2_param = self.model_2_param.reshape(1, -1, model_2_param.shape[-1])

    def _cdist(self, x1, x2, p=2):
        # naive pairwise Euclidean distance, computed entry by entry with Python loops
        assert len(x1.shape) == len(x2.shape) == 3 and x1.shape[-1] == x2.shape[-1]  # actual shape here: [1, 32, 25]
        dist = jt.zeros((x1.shape[0], x1.shape[1], x2.shape[1]))
        assert p == 2
        for i in range(x1.shape[0]):
            for j in range(x1.shape[1]):
                for k in range(x2.shape[1]):
                    dist[i, j, k] = (x1[i][j] - x2[i][k]).pow(2).sum().sqrt().item()
        return dist

    def process_distance(self, p: int = 2):
        return self._cdist(
            self.model_1_param.float(),
            self.model_2_param.float(),
            p=p)[0]

    def process_soft_affinity(self, p: int = 2):
        return jt.exp(0 - self.process_distance(p=p))
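
The triple loop in _cdist above is easy to read but slow, because it synchronizes on .item() for every pair of rows. A broadcast-based equivalent is sketched below (a sketch only, not part of the original example; _cdist_vectorized is a hypothetical helper assuming the same [batch, n, d] inputs):

def _cdist_vectorized(x1, x2):
    # pairwise Euclidean distances (p=2) via broadcasting:
    # x1: [b, n, d], x2: [b, m, d] -> dist: [b, n, m]
    diff = x1.unsqueeze(2) - x2.unsqueeze(1)  # [b, n, m, d]
    return diff.pow(2).sum(3).sqrt()

# the soft affinity then mirrors process_soft_affinity:
# jt.exp(-_cdist_vectorized(x1, x2))

Swapping this into process_distance should produce the same distances while avoiding Python-level loops.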

Define the affinity function between two neural networks. This function takes a list of neural network modules and builds the corresponding affinity matrix, which is then processed by the graph matching solver.

def graph_matching_fusion(networks: list):
    def total_node_num(network: jt.nn.Module):
        # count the total number of nodes in the network [network]
        num_nodes = 0
        for idx, (name, parameters) in enumerate(network.named_parameters()):
            if 'bias' in name:
                continue
            if idx == 0:
                num_nodes += parameters.shape[1]
            num_nodes += parameters.shape[0]
        return num_nodes

    n1 = total_node_num(network=networks[0])
    n2 = total_node_num(network=networks[1])
    assert (n1 == n2)
    affinity = jt.zeros([n1 * n2, n1 * n2])
    num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
    num_nodes_before = 0
    num_nodes_incremental = []
    num_nodes_layers = []
    pre_conv_list = []
    cur_conv_list = []
    conv_kernel_size_list = []
    num_nodes_pre = 0
    is_conv = False
    pre_conv = False
    pre_conv_out_channel = 1
    is_final_bias = False
    perm_is_complete = True
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    for idx, ((_, fc_layer0_weight), (_, fc_layer1_weight)) in \
            enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):
        assert fc_layer0_weight.shape == fc_layer1_weight.shape
        layer_shape = fc_layer0_weight.shape
        num_nodes_cur = fc_layer0_weight.shape[0]
        if len(layer_shape) > 1:
            if is_conv is True and len(layer_shape) == 2:
                num_nodes_pre = pre_conv_out_channel
            else:
                num_nodes_pre = fc_layer0_weight.shape[1]
        if idx >= 1 and len(named_weight_list_0[idx - 1][1].shape) == 1:
            pre_bias = True
        else:
            pre_bias = False
        if len(layer_shape) > 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = True
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.data.reshape(fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1)
            fc_layer1_weight_data = fc_layer1_weight.data.reshape(fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1)
        elif len(layer_shape) == 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.data
            fc_layer1_weight_data = fc_layer1_weight.data
        else:
            is_bias = True
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.data
            fc_layer1_weight_data = fc_layer1_weight.data
        if is_conv:
            pre_conv_out_channel = num_nodes_cur
        if is_bias is True and idx == num_layers - 1:
            is_final_bias = True
        if idx == 0:
            for a in range(num_nodes_pre):
                affinity[(num_nodes_before + a) * n2 + num_nodes_before + a, \
                    (num_nodes_before + a) * n2 + num_nodes_before + a] \
                    = 1
        if (idx == num_layers - 2 and 'bias' in named_weight_list_0[idx + 1][0]) or \
                (idx == num_layers - 1 and 'bias' not in named_weight_list_0[idx][0]):
            for a in range(num_nodes_cur):
                affinity[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a, \
                    (num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
                    = 1
        if is_bias is False:
            ground_metric = Ground_Metric_GM(
                jt.Var(fc_layer0_weight_data), jt.Var(fc_layer1_weight_data), is_conv, is_bias,
                pre_conv, int(fc_layer0_weight_data.shape[1] / pre_conv_out_channel))
        else:
            ground_metric = Ground_Metric_GM(
                jt.Var(fc_layer0_weight_data), jt.Var(fc_layer1_weight_data), is_conv, is_bias,
                pre_conv, 1)

        layer_affinity = ground_metric.process_soft_affinity(p=2)

        if is_bias is False:
            pre_conv_kernel_size = fc_layer0_weight.shape[3] if is_conv else None
            conv_kernel_size_list.append(pre_conv_kernel_size)
        if is_bias is True and is_final_bias is False:
            for a in range(num_nodes_cur):
                for c in range(num_nodes_cur):
                    affinity[(num_nodes_before + a) * n2 + num_nodes_before + c, \
                        (num_nodes_before + a) * n2 + num_nodes_before + c] \
                        = layer_affinity[a][c]
        elif is_final_bias is False:
            for a in range(num_nodes_pre):
                for b in range(num_nodes_cur):
                    affinity[
                    (num_nodes_before + a) * n2 + num_nodes_before:
                    (num_nodes_before + a) * n2 + num_nodes_before + num_nodes_pre,
                    (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre:
                    (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre + num_nodes_cur] \
                        = layer_affinity[a + b * num_nodes_pre].view(num_nodes_cur, num_nodes_pre).transpose(0, 1)
        if is_bias is False:
            num_nodes_before += num_nodes_pre
            num_nodes_incremental.append(num_nodes_before)
            num_nodes_layers.append(num_nodes_cur)
    # affinity = (affinity + affinity.t()) / 2
    return affinity, [n1, n2, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list]

Get the affinity (similarity) matrix between model1 and model2.

K, params = graph_matching_fusion([model1, model2])

Align the models by graph matching

Align the channels of model1 and model2 by maximizing the affinity (similarity) via the graph matching algorithm.
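
pygm.sm is the spectral matching solver (Leordeanu & Hebert). Writing x = vec(X), it approximately solves

    max_x  x^T K x   subject to  ||x||_2 = 1,

whose solution is the leading eigenvector of K (entrywise nonnegative because K is), reshaped into the soft matching X.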

n1 = params[0]
n2 = params[1]
X = pygm.sm(K, n1, n2)

Project X onto the neural network matching result. The matching matrix is built by applying the Hungarian algorithm to the layer-wise blocks of X, because only channels from the same layer of the two networks can be matched.

Note

In this example, we assume the last FC layer is aligned and does not need to be matched.

new_X = jt.zeros_like(X)
new_X[:params[2][0], :params[2][0]] = jt.init.eye(params[2][0])
for start_idx, length in zip(params[2][:-1], params[3][:-1]):  # params[2]: start index of each layer's node block; params[3]: number of channels per layer
    slicing = slice(start_idx, start_idx + length)
    new_X[slicing, slicing] = pygm.hungarian(X[slicing, slicing])
# assume the last FC layer is aligned
slicing = slice(params[2][-1], params[2][-1] + params[3][-1])
new_X[slicing, slicing] = jt.init.eye(params[3][-1])
X = new_X

Visualization of the matching result. The black lines separate the channels of different layers.

plt.figure(figsize=(4, 4))
plt.imshow(X.numpy(), cmap='Blues')
for idx in params[2]:
    plt.axvline(x=idx, color='k')
    plt.axhline(y=idx, color='k')
[figure: channel matching matrix]

Define the alignment function: fuse the models based on the matching result

def align(solution, fusion_proportion, networks: list, params: list):
    [_, _, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list] = params
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    aligned_wt_0 = [parameter.data for name, parameter in named_weight_list_0]
    idx = 0
    num_layers = len(aligned_wt_0)
    for num_before, num_cur, cur_conv, cur_kernel_size in \
            zip(num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list):
        perm = jt.Var(solution[num_before:num_before + num_cur, num_before:num_before + num_cur])
        assert 'bias' not in named_weight_list_0[idx][0]
        if len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (perm.transpose(0, 1).float32() @
                                 jt.Var(aligned_wt_0[idx]).float32().permute(2, 3, 0, 1)) \
                .permute(2, 3, 0, 1)
        else:
            aligned_wt_0[idx] = perm.transpose(0, 1).float32() @ jt.Var(aligned_wt_0[idx]).float32()
        idx += 1
        if idx >= num_layers:
            continue
        if 'bias' in named_weight_list_0[idx][0]:
            aligned_wt_0[idx] = jt.Var(aligned_wt_0[idx]).float32() @ perm.float32()
            idx += 1
        if idx >= num_layers:
            continue
        if cur_conv and len(named_weight_list_0[idx][1].shape) == 2:
            aligned_wt_0[idx] = (jt.Var(aligned_wt_0[idx]).float32()
                                 .reshape(aligned_wt_0[idx].shape[0], 64, -1)
                                 .permute(0, 2, 1)
                                 @ perm.float32()) \
                .permute(0, 2, 1) \
                .reshape(aligned_wt_0[idx].shape[0], -1)
        elif len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (jt.Var(aligned_wt_0[idx]).float32()
                                 .permute(2, 3, 0, 1)
                                 @ perm.float32()) \
                .permute(2, 3, 0, 1)
        else:
            aligned_wt_0[idx] = jt.Var(aligned_wt_0[idx]).float32() @ perm.float32()
    assert idx == num_layers

    averaged_weights = []
    for idx, parameter in enumerate(networks[1].parameters()):
        averaged_weights.append((1 - fusion_proportion) * aligned_wt_0[idx] + fusion_proportion * parameter)
    return averaged_weights

Test the fused model

The fusion_proportion variable denotes model2's contribution to the fused model. For example, if fusion_proportion=0.2, the fused model = 80% model1 + 20% model2.
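
As a quick sanity check of this combination rule (toy numbers, not real model weights):

import jittor as jt

fusion_proportion = 0.2
aligned_w1 = jt.array([1.0, 2.0])  # a model1 weight after channel alignment
w2 = jt.array([3.0, 4.0])          # the matching model2 weight
fused = (1 - fusion_proportion) * aligned_w1 + fusion_proportion * w2
print(fused.numpy())  # ~[1.4 2.4], i.e. 80% model1 + 20% model2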

def align_model_and_test(X):
    acc_list = []
    for fusion_proportion in jt.arange(0, 1.1, 0.1):
        fused_weights = align(X, fusion_proportion, [model1, model2], params)
        fused_model = SimpleNet()
        state_dict = fused_model.state_dict()
        for idx, (key, _) in enumerate(state_dict.items()):
            state_dict[key] = fused_weights[idx]
        fused_model.load_state_dict(state_dict)
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            output = fused_model(data)
            test_loss += nn.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(1, keepdims=True)[0]
            correct += pred.equal(target.view_as(pred)).sum()
        test_loss /= test_dataset.total_len
        acc = (100. * correct / test_dataset.total_len).item()
        print(
            f"{1 - fusion_proportion:.2f} model1 + {fusion_proportion:.2f} model2 -> fused model accuracy: {acc:.2f}%")
        acc_list.append(acc)
    return jt.Var(acc_list)


print('Graph Matching Fusion')
gm_acc_list = align_model_and_test(X)
Graph Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 85.12%
0.80 model1 + 0.20 model2 -> fused model accuracy: 85.21%
0.70 model1 + 0.30 model2 -> fused model accuracy: 82.52%
0.60 model1 + 0.40 model2 -> fused model accuracy: 71.11%
0.50 model1 + 0.50 model2 -> fused model accuracy: 53.74%
0.40 model1 + 0.60 model2 -> fused model accuracy: 63.26%
0.30 model1 + 0.70 model2 -> fused model accuracy: 78.51%
0.20 model1 + 0.80 model2 -> fused model accuracy: 82.81%
0.10 model1 + 0.90 model2 -> fused model accuracy: 83.97%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%

Compared with vanilla model fusion (no matching), the graph matching approach stabilizes the fusion step:

print('No Matching Fusion')
vanilla_acc_list = align_model_and_test(jt.init.eye(n1))

plt.figure(figsize=(4, 4))
plt.title('Fused Model Accuracy')
plt.plot(jt.arange(0, 1.1, 0.1).numpy(), gm_acc_list.numpy(), 'r*-', label='Graph Matching Fusion')
plt.plot(jt.arange(0, 1.1, 0.1).numpy(), vanilla_acc_list.numpy(), 'b*-', label='No Matching Fusion')
plt.plot(jt.arange(0, 1.1, 0.1).numpy(), [acc1] * 11, '--', color="gray", label='Model1 Accuracy')
plt.plot(jt.arange(0, 1.1, 0.1).numpy(), [acc2] * 11, '--', color="brown", label='Model2 Accuracy')
plt.gca().set_xlabel('Fusion Proportion')
plt.gca().set_ylabel('Accuracy (%)')
plt.ylim((70, 87))
plt.legend(loc=3)
plt.show()
[figure: Fused Model Accuracy vs. fusion proportion]
No Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 84.01%
0.80 model1 + 0.20 model2 -> fused model accuracy: 81.91%
0.70 model1 + 0.30 model2 -> fused model accuracy: 74.67%
0.60 model1 + 0.40 model2 -> fused model accuracy: 60.39%
0.50 model1 + 0.50 model2 -> fused model accuracy: 47.16%
0.40 model1 + 0.60 model2 -> fused model accuracy: 55.34%
0.30 model1 + 0.70 model2 -> fused model accuracy: 72.86%
0.20 model1 + 0.80 model2 -> fused model accuracy: 79.64%
0.10 model1 + 0.90 model2 -> fused model accuracy: 82.56%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%