PyTorch Backend Example: Model Fusion by Graph Matching

This example shows how to fuse different models into a single model with pygmtools. The goal of model fusion is to merge multiple models into one so that the fused model achieves higher performance. A neural network can be viewed as a graph (channels as nodes, update functions between channels as edges; node features are the biases, edge features are the weights), so fusing models is equivalent to solving a graph matching problem. In this example, the given models are trained on MNIST data drawn from different distributions, and the fused model combines the knowledge of both input models, reaching higher accuracy at test time.

# Author: Chang Liu <only-changer@sjtu.edu.cn>
#         Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License

Note

This is a simplified implementation of the idea in Liu et al. "Deep Neural Network Fusion via Graph Matching with Applications to Model Ensemble and Federated Learning." ICML 2022. For more details, please refer to the paper and the official code repository.

Note

The following solvers are included in this example:

pygm.sm (classic spectral graph matching solver)

pygm.hungarian (the Hungarian algorithm for linear assignment)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import time
from PIL import Image
import matplotlib.pyplot as plt
import pygmtools as pygm

pygm.set_backend('pytorch')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Define a simple CNN classifier network

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding=1, padding_mode='replicate', bias=False)
        self.max_pool = nn.MaxPool2d(2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 5, padding=1, padding_mode='replicate', bias=False)
        self.fc1 = nn.Linear(3136, 32, bias=False)
        self.fc2 = nn.Linear(32, 10, bias=False)

    def forward(self, x):
        output = F.relu(self.conv1(x))
        output = self.max_pool(output)
        output = F.relu(self.conv2(output))
        output = self.max_pool(output)
        output = output.view(output.shape[0], -1)
        output = self.fc1(output)
        output = self.fc2(output)
        return output
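fc1 expects 3136 input features because a 28x28 MNIST image becomes 26x26 after conv1 (kernel 5, padding 1), 14x14 after pooling (pool 2, padding 1), 12x12 after conv2, and 7x7 after the second pooling, so the flattened size is 64 * 7 * 7 = 3136. A minimal shape check (illustrative, not part of the original example):

dummy_input = torch.zeros(1, 1, 28, 28)  # one MNIST-sized image
print(SimpleNet()(dummy_input).shape)    # torch.Size([1, 10])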

Load the trained models to be fused

model1 = SimpleNet()
model2 = SimpleNet()
model1.load_state_dict(torch.load('../data/example_model_fusion_1.dat', map_location=device))
model2.load_state_dict(torch.load('../data/example_model_fusion_2.dat', map_location=device))
model1.to(device)
model2.to(device)
test_dataset = torchvision.datasets.MNIST(
    root='../data/mnist_data',  # the directory to store the dataset
    train=False,  # the dataset is used to test
    transform=transforms.ToTensor(),  # the dataset is in the form of tensors
    download=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False)

Print the layers of the simple CNN model:

print(model1)
SimpleNet(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=replicate)
  (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=replicate)
  (fc1): Linear(in_features=3136, out_features=32, bias=False)
  (fc2): Linear(in_features=32, out_features=10, bias=False)
)

Test the input models

with torch.no_grad():
    n_correct1 = 0
    n_correct2 = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs1 = model1(images)
        outputs2 = model2(images)
        _, predictions1 = torch.max(outputs1, 1)
        _, predictions2 = torch.max(outputs2, 1)
        n_samples += labels.shape[0]
        n_correct1 += (predictions1 == labels).sum().item()
        n_correct2 += (predictions2 == labels).sum().item()
    acc1 = 100 * n_correct1 / n_samples
    acc2 = 100 * n_correct2 / n_samples

Testing results (two independent models):

print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
model1 accuracy = 84.18%, model2 accuracy = 83.8%

Build the affinity matrix for graph matching

As shown in the figure below, a neural network can be viewed as a graph. The weights correspond to edge features, and the biases correspond to node features. In this example, the neural networks have no bias terms, so there are only edge features.

plt.figure(figsize=(8, 4))
img = Image.open('../data/model_fusion.png')
plt.imshow(img)
plt.axis('off')
st_time = time.perf_counter()
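To make the graph view concrete, here is a minimal sketch (illustrative only, not part of the fusion pipeline): in conv1 of model1, each (output channel, input channel) pair forms an edge, and the flattened 5x5 kernel serves as that edge's feature.

w = model1.conv1.weight.data                        # shape: (32, 1, 5, 5)
edge_features = w.view(w.shape[0], w.shape[1], -1)  # edge (i, j) -> 25-dim feature
print(edge_features.shape)                          # torch.Size([32, 1, 25])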

Define the affinity metric function for graph matching

class Ground_Metric_GM:
    def __init__(self,
                 model_1_param: torch.tensor = None,
                 model_2_param: torch.tensor = None,
                 conv_param: bool = False,
                 bias_param: bool = False,
                 pre_conv_param: bool = False,
                 pre_conv_image_size_squared: int = None):
        self.model_1_param = model_1_param
        self.model_2_param = model_2_param
        self.conv_param = conv_param
        self.bias_param = bias_param
        # bias, or fully-connected from linear
        if bias_param is True or (conv_param is False and pre_conv_param is False):
            self.model_1_param = self.model_1_param.reshape(1, -1, 1)
            self.model_2_param = self.model_2_param.reshape(1, -1, 1)
        # fully-connected from conv
        elif conv_param is False and pre_conv_param is True:
            self.model_1_param = self.model_1_param.reshape(1, -1, pre_conv_image_size_squared)
            self.model_2_param = self.model_2_param.reshape(1, -1, pre_conv_image_size_squared)
        # conv
        else:
            self.model_1_param = self.model_1_param.reshape(1, -1, model_1_param.shape[-1])
            self.model_2_param = self.model_2_param.reshape(1, -1, model_2_param.shape[-1])

    def process_distance(self, p: int = 2):
        return torch.cdist(
            self.model_1_param.to(torch.float),
            self.model_2_param.to(torch.float),
            p=p)[0]

    def process_soft_affinity(self, p: int = 2):
        return torch.exp(0 - self.process_distance(p=p))
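A hypothetical usage with random stand-in weights: for two fully-connected layers of shape 4x8 that follow another fully-connected layer, each scalar weight acts as a single edge feature, so process_soft_affinity returns the exp(-distance) affinity between all 32 x 32 pairs of edges.

w1 = torch.randn(4, 8)  # stand-in weights, not real model parameters
w2 = torch.randn(4, 8)
metric = Ground_Metric_GM(w1, w2, conv_param=False, bias_param=False, pre_conv_param=False)
print(metric.process_soft_affinity(p=2).shape)  # torch.Size([32, 32])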

Define the affinity function between two neural networks. This function takes multiple neural network modules and builds the corresponding affinity matrix, which is then processed by the graph matching solver.

def graph_matching_fusion(networks: list):
    def total_node_num(network: torch.nn.Module):
        # count the total number of nodes in the network [network]
        num_nodes = 0
        for idx, (name, parameters) in enumerate(network.named_parameters()):
            if 'bias' in name:
                continue
            if idx == 0:
                num_nodes += parameters.shape[1]
            num_nodes += parameters.shape[0]
        return num_nodes

    n1 = total_node_num(network=networks[0])
    n2 = total_node_num(network=networks[1])
    assert (n1 == n2)
    affinity = torch.zeros([n1 * n2, n1 * n2], device=device)
    num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
    num_nodes_before = 0
    num_nodes_incremental = []
    num_nodes_layers = []
    pre_conv_list = []
    cur_conv_list = []
    conv_kernel_size_list = []
    num_nodes_pre = 0
    is_conv = False
    pre_conv = False
    pre_conv_out_channel = 1
    is_final_bias = False
    perm_is_complete = True
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    for idx, ((_, fc_layer0_weight), (_, fc_layer1_weight)) in \
            enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):
        assert fc_layer0_weight.shape == fc_layer1_weight.shape
        layer_shape = fc_layer0_weight.shape
        num_nodes_cur = fc_layer0_weight.shape[0]
        if len(layer_shape) > 1:
            if is_conv is True and len(layer_shape) == 2:
                num_nodes_pre = pre_conv_out_channel
            else:
                num_nodes_pre = fc_layer0_weight.shape[1]
        if idx >= 1 and len(named_weight_list_0[idx - 1][1].shape) == 1:
            pre_bias = True
        else:
            pre_bias = False
        if len(layer_shape) > 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = True
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.data.view(
                fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1)
            fc_layer1_weight_data = fc_layer1_weight.data.view(
                fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1)
        elif len(layer_shape) == 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.data
            fc_layer1_weight_data = fc_layer1_weight.data
        else:
            is_bias = True
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.data
            fc_layer1_weight_data = fc_layer1_weight.data
        if is_conv:
            pre_conv_out_channel = num_nodes_cur
        if is_bias is True and idx == num_layers - 1:
            is_final_bias = True
        if idx == 0:
            for a in range(num_nodes_pre):
                affinity[(num_nodes_before + a) * n2 + num_nodes_before + a] \
                    [(num_nodes_before + a) * n2 + num_nodes_before + a] \
                    = 1
        if idx == num_layers - 2 and 'bias' in named_weight_list_0[idx + 1][0] or \
                idx == num_layers - 1 and 'bias' not in named_weight_list_0[idx][0]:
            for a in range(num_nodes_cur):
                affinity[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
                    [(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
                    = 1
        if is_bias is False:
            ground_metric = Ground_Metric_GM(
                fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                pre_conv, int(fc_layer0_weight_data.shape[1] / pre_conv_out_channel))
        else:
            ground_metric = Ground_Metric_GM(
                fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                pre_conv, 1)

        layer_affinity = ground_metric.process_soft_affinity(p=2)

        if is_bias is False:
            pre_conv_kernel_size = fc_layer0_weight.shape[3] if is_conv else None
            conv_kernel_size_list.append(pre_conv_kernel_size)
        if is_bias is True and is_final_bias is False:
            for a in range(num_nodes_cur):
                for c in range(num_nodes_cur):
                    affinity[(num_nodes_before + a) * n2 + num_nodes_before + c] \
                        [(num_nodes_before + a) * n2 + num_nodes_before + c] \
                        = layer_affinity[a][c]
        elif is_final_bias is False:
            for a in range(num_nodes_pre):
                for b in range(num_nodes_cur):
                    affinity[
                    (num_nodes_before + a) * n2 + num_nodes_before:
                    (num_nodes_before + a) * n2 + num_nodes_before + num_nodes_pre,
                    (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre:
                    (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre + num_nodes_cur] \
                        = layer_affinity[a + b * num_nodes_pre].view(num_nodes_cur, num_nodes_pre).transpose(0, 1)
        if is_bias is False:
            num_nodes_before += num_nodes_pre
            num_nodes_incremental.append(num_nodes_before)
            num_nodes_layers.append(num_nodes_cur)
    # affinity = (affinity + affinity.t()) / 2
    return affinity, [n1, n2, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list]

Get the affinity (similarity) matrix between model1 and model2.

K, params = graph_matching_fusion([model1, model2])
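Counting SimpleNet's channels (1 input + 32 from conv1 + 64 from conv2 + 32 from fc1 + 10 from fc2) gives n1 = n2 = 139 nodes, so K should be of size (139 * 139) x (139 * 139). A quick check, assuming this node count:

print(K.shape)  # expected: torch.Size([19321, 19321]), where 19321 = 139 * 139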

Align the models by graph matching

Align the channels of model1 and model2 by maximizing the affinity (similarity) via graph matching algorithms.

n1 = params[0]
n2 = params[1]
X = pygm.sm(K, n1, n2)
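pygm.sm returns a relaxed (soft) matching between the two graphs; under the same node-count assumption as above, it should be a 139 x 139 matrix:

print(X.shape)  # expected: torch.Size([139, 139])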

Project X onto the neural network matching result. The matching matrix is built by applying the Hungarian algorithm to small blocks of X, because only channels from the same neural network layer can be matched.

Note

In this example, we assume the last FC layer is already aligned and does not need to be matched.

new_X = torch.zeros_like(X)
new_X[:params[2][0], :params[2][0]] = torch.eye(params[2][0], device=device)
for start_idx, length in zip(params[2][:-1], params[3][:-1]):  # params[2]: start index of each layer; params[3]: number of channels per layer
    slicing = slice(start_idx, start_idx + length)
    new_X[slicing, slicing] = pygm.hungarian(X[slicing, slicing])
# assume the last FC layer is aligned
slicing = slice(params[2][-1], params[2][-1] + params[3][-1])
new_X[slicing, slicing] = torch.eye(params[3][-1], device=device)
X = new_X
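As a sanity check (illustrative, not part of the original example), the projected X should now be a permutation matrix, i.e. every row and every column sums to exactly one:

assert torch.all(X.sum(dim=0) == 1) and torch.all(X.sum(dim=1) == 1)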

Visualization of the matching result. The black lines separate the channels of different layers.

plt.figure(figsize=(4, 4))
plt.imshow(X.cpu().numpy(), cmap='Blues')
for idx in params[2]:
    plt.axvline(x=idx, color='k')
    plt.axhline(y=idx, color='k')

Define the alignment function: fuse the models based on the matching result

def align(solution, fusion_proportion, networks: list, params: list):
    [_, _, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list] = params
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    aligned_wt_0 = [parameter.data for name, parameter in named_weight_list_0]
    idx = 0
    num_layers = len(aligned_wt_0)
    for num_before, num_cur, cur_conv, cur_kernel_size in \
            zip(num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list):
        perm = solution[num_before:num_before + num_cur, num_before:num_before + num_cur]
        assert 'bias' not in named_weight_list_0[idx][0]
        if len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (perm.transpose(0, 1).to(torch.float64) @
                                 aligned_wt_0[idx].to(torch.float64).permute(2, 3, 0, 1)) \
                .permute(2, 3, 0, 1)
        else:
            aligned_wt_0[idx] = perm.transpose(0, 1).to(torch.float64) @ aligned_wt_0[idx].to(torch.float64)
        idx += 1
        if idx >= num_layers:
            continue
        if 'bias' in named_weight_list_0[idx][0]:
            aligned_wt_0[idx] = aligned_wt_0[idx].to(torch.float64) @ perm.to(torch.float64)
            idx += 1
        if idx >= num_layers:
            continue
        if cur_conv and len(named_weight_list_0[idx][1].shape) == 2:
            aligned_wt_0[idx] = (aligned_wt_0[idx].to(torch.float64)
                                 .reshape(aligned_wt_0[idx].shape[0], 64, -1)  # 64: channels of the preceding conv layer
                                 .permute(0, 2, 1)
                                 @ perm.to(torch.float64)) \
                .permute(0, 2, 1) \
                .reshape(aligned_wt_0[idx].shape[0], -1)
        elif len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (aligned_wt_0[idx].to(torch.float64)
                                 .permute(2, 3, 0, 1)
                                 @ perm.to(torch.float64)) \
                .permute(2, 3, 0, 1)
        else:
            aligned_wt_0[idx] = aligned_wt_0[idx].to(torch.float64) @ perm.to(torch.float64)
    assert idx == num_layers

    averaged_weights = []
    for idx, parameter in enumerate(networks[1].parameters()):
        averaged_weights.append((1 - fusion_proportion) * aligned_wt_0[idx] + fusion_proportion * parameter)
    return averaged_weights

Test the fused model

The fusion_proportion variable denotes the contribution to the new model. For example, if fusion_proportion=0.2, the fused model = 80% model1 + 20% model2.

def align_model_and_test(X):
    acc_list = []
    for fusion_proportion in torch.arange(0, 1.1, 0.1):
        fused_weights = align(X, fusion_proportion, [model1, model2], params)

        fused_model = SimpleNet()
        state_dict = fused_model.state_dict()
        for idx, (key, _) in enumerate(state_dict.items()):
            state_dict[key] = fused_weights[idx]
        fused_model.load_state_dict(state_dict)
        fused_model.to(device)
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = fused_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
        test_loss /= len(test_loader.dataset)
        acc = 100. * correct / len(test_loader.dataset)
        print(
            f"{1 - fusion_proportion:.2f} model1 + {fusion_proportion:.2f} model2 -> fused model accuracy: {acc:.2f}%")
        acc_list.append(acc)
    return torch.tensor(acc_list)


print('Graph Matching Fusion')
gm_acc_list = align_model_and_test(X)
Graph Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 85.46%
0.80 model1 + 0.20 model2 -> fused model accuracy: 86.92%
0.70 model1 + 0.30 model2 -> fused model accuracy: 88.38%
0.60 model1 + 0.40 model2 -> fused model accuracy: 86.43%
0.50 model1 + 0.50 model2 -> fused model accuracy: 74.11%
0.40 model1 + 0.60 model2 -> fused model accuracy: 72.45%
0.30 model1 + 0.70 model2 -> fused model accuracy: 78.12%
0.20 model1 + 0.80 model2 -> fused model accuracy: 81.65%
0.10 model1 + 0.90 model2 -> fused model accuracy: 83.29%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.80%

Compared to vanilla model fusion (no matching), the graph matching approach stabilizes the fusion step:

print('No Matching Fusion')
vanilla_acc_list = align_model_and_test(torch.eye(n1, device=device))

plt.figure(figsize=(4, 4))
plt.title('Fused Model Accuracy')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), gm_acc_list.cpu().numpy(), 'r*-', label='Graph Matching Fusion')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), vanilla_acc_list.cpu().numpy(), 'b*-', label='No Matching Fusion')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), [acc1] * 11, '--', color="gray", label='Model1 Accuracy')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), [acc2] * 11, '--', color="brown", label='Model2 Accuracy')
plt.gca().set_xlabel('Fusion Proportion')
plt.gca().set_ylabel('Accuracy (%)')
plt.ylim((70, 87))
plt.legend(loc=3)
plt.show()
No Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 84.01%
0.80 model1 + 0.20 model2 -> fused model accuracy: 81.91%
0.70 model1 + 0.30 model2 -> fused model accuracy: 74.67%
0.60 model1 + 0.40 model2 -> fused model accuracy: 60.39%
0.50 model1 + 0.50 model2 -> fused model accuracy: 47.15%
0.40 model1 + 0.60 model2 -> fused model accuracy: 55.36%
0.30 model1 + 0.70 model2 -> fused model accuracy: 72.87%
0.20 model1 + 0.80 model2 -> fused model accuracy: 79.64%
0.10 model1 + 0.90 model2 -> fused model accuracy: 82.55%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.80%