Note
Go to the end to download the full example code.
PyTorch Backend Example: Model Fusion by Graph Matching
This example shows how to fuse different models into a single model by pygmtools.
Model fusion aims to fuse multiple models into one, such that the fused model performs better than each individual input model.
A neural network can be regarded as a graph (channels - nodes, update functions between channels - edges;
node features - biases, edge features - weights), so fusing the models is equivalent to solving a graph matching
problem. In this example, the given models are trained on MNIST data drawn from different distributions, and the fused model combines the knowledge of the two input models and reaches higher accuracy at test time.
# Author: Chang Liu <only-changer@sjtu.edu.cn>
# Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
Note
This is a simplified implementation of the ideas in Liu et al. "Deep Neural Network Fusion via Graph Matching with Applications to Model Ensemble and Federated Learning." ICML 2022. For more details, please refer to the paper and the official code repository.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import time
from PIL import Image
import matplotlib.pyplot as plt
import pygmtools as pygm
pygm.set_backend('pytorch')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
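Before building the models, here is a minimal preview, on hypothetical random data, of the graph matching workflow used later in this example: build an affinity matrix, solve the relaxed problem with pygm.sm, then project the soft solution to a permutation with pygm.hungarian.
# Preview sketch (toy data): the same solve-then-discretize pipeline used below.
n_toy = 4                                      # number of nodes in each toy graph
K_toy = torch.rand(n_toy * n_toy, n_toy * n_toy, device=device)
K_toy = (K_toy + K_toy.t()) / 2                # affinity matrices are symmetric
X_soft = pygm.sm(K_toy, n_toy, n_toy)          # spectral matching: soft matching matrix
X_hard = pygm.hungarian(X_soft)                # Hungarian projection to a 0/1 permutation
print(X_hard.sum(0), X_hard.sum(1))            # each row and column sums to 1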
Define a simple CNN classifier network
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 5, padding=1, padding_mode='replicate', bias=False)
self.max_pool = nn.MaxPool2d(2, padding=1)
self.conv2 = nn.Conv2d(32, 64, 5, padding=1, padding_mode='replicate', bias=False)
self.fc1 = nn.Linear(3136, 32, bias=False)
self.fc2 = nn.Linear(32, 10, bias=False)
def forward(self, x):
output = F.relu(self.conv1(x))
output = self.max_pool(output)
output = F.relu(self.conv2(output))
output = self.max_pool(output)
output = output.view(output.shape[0], -1)
output = self.fc1(output)
output = self.fc2(output)
return output
Load the trained models to be fused
model1 = SimpleNet()
model2 = SimpleNet()
model1.load_state_dict(torch.load('../data/example_model_fusion_1.dat', map_location=device))
model2.load_state_dict(torch.load('../data/example_model_fusion_2.dat', map_location=device))
model1.to(device)
model2.to(device)
test_dataset = torchvision.datasets.MNIST(
root='../data/mnist_data', # the directory to store the dataset
train=False, # the dataset is used to test
transform=transforms.ToTensor(), # the dataset is in the form of tensors
download=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset,
batch_size=32,
shuffle=False)
Print the layers of the simple CNN model:
print(model1)
SimpleNet(
(conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=replicate)
(max_pool): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
(conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=replicate)
(fc1): Linear(in_features=3136, out_features=32, bias=False)
(fc2): Linear(in_features=32, out_features=10, bias=False)
)
Test the input models
with torch.no_grad():
n_correct1 = 0
n_correct2 = 0
n_samples = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs1 = model1(images)
outputs2 = model2(images)
_, predictions1 = torch.max(outputs1, 1)
_, predictions2 = torch.max(outputs2, 1)
n_samples += labels.shape[0]
n_correct1 += (predictions1 == labels).sum().item()
n_correct2 += (predictions2 == labels).sum().item()
acc1 = 100 * n_correct1 / n_samples
acc2 = 100 * n_correct2 / n_samples
Testing results (two separate models):
print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
model1 accuracy = 84.18%, model2 accuracy = 83.8%
Build the affinity matrix for graph matching
As shown in the figure below, a neural network can be regarded as a graph. The weights correspond to edge features, and the biases correspond to node features. In this example, the neural network has no biases, so there are only edge features.
plt.figure(figsize=(8, 4))
img = Image.open('../data/model_fusion.png')
plt.imshow(img)
plt.axis('off')
st_time = time.perf_counter()
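To make the graph view concrete, the following sketch enumerates SimpleNet's parameters and labels the role each tensor would play in the graph (this network has no biases, so all parameters are edge features):
# Sketch: list the parameters of SimpleNet and their role in the graph view.
# Weights act as edge features; biases (absent here) would act as node features.
for name, param in SimpleNet().named_parameters():
    role = 'node feature (bias)' if 'bias' in name else 'edge feature (weight)'
    print(f'{name}: shape {tuple(param.shape)} -> {role}')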

Define the graph matching affinity metric function
class Ground_Metric_GM:
def __init__(self,
model_1_param: torch.tensor = None,
model_2_param: torch.tensor = None,
conv_param: bool = False,
bias_param: bool = False,
pre_conv_param: bool = False,
pre_conv_image_size_squared: int = None):
self.model_1_param = model_1_param
self.model_2_param = model_2_param
self.conv_param = conv_param
self.bias_param = bias_param
# bias, or fully-connected from linear
if bias_param is True or (conv_param is False and pre_conv_param is False):
self.model_1_param = self.model_1_param.reshape(1, -1, 1)
self.model_2_param = self.model_2_param.reshape(1, -1, 1)
# fully-connected from conv
elif conv_param is False and pre_conv_param is True:
self.model_1_param = self.model_1_param.reshape(1, -1, pre_conv_image_size_squared)
self.model_2_param = self.model_2_param.reshape(1, -1, pre_conv_image_size_squared)
# conv
else:
self.model_1_param = self.model_1_param.reshape(1, -1, model_1_param.shape[-1])
self.model_2_param = self.model_2_param.reshape(1, -1, model_2_param.shape[-1])
def process_distance(self, p: int = 2):
return torch.cdist(
self.model_1_param.to(torch.float),
self.model_2_param.to(torch.float),
p=p)[0]
def process_soft_affinity(self, p: int = 2):
return torch.exp(0 - self.process_distance(p=p))
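Since the soft affinity is exp(-distance), identical parameters receive affinity 1 and distant ones decay towards 0. A small sketch on hypothetical toy weights:
# Toy sketch (hypothetical weights): closer parameters yield affinities nearer to 1.
w1 = torch.tensor([[0.1, 0.2], [0.5, 0.6]])  # two "channels" of a toy model 1
w2 = torch.tensor([[0.1, 0.2], [0.9, 1.0]])  # two "channels" of a toy model 2
toy_metric = Ground_Metric_GM(w1, w2, conv_param=False, bias_param=False, pre_conv_param=False)
print(toy_metric.process_distance(p=2))       # pairwise L2 distances
print(toy_metric.process_soft_affinity(p=2))  # exp(-distance), in (0, 1]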
Define the affinity function between two neural networks. This function takes multiple neural network modules and builds the corresponding affinity matrix, which is further processed by the graph matching solver.
def graph_matching_fusion(networks: list):
def total_node_num(network: torch.nn.Module):
        # count the total number of nodes (channels) in the given network
num_nodes = 0
for idx, (name, parameters) in enumerate(network.named_parameters()):
if 'bias' in name:
continue
if idx == 0:
num_nodes += parameters.shape[1]
num_nodes += parameters.shape[0]
return num_nodes
n1 = total_node_num(network=networks[0])
n2 = total_node_num(network=networks[1])
assert (n1 == n2)
affinity = torch.zeros([n1 * n2, n1 * n2], device=device)
num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
num_nodes_before = 0
num_nodes_incremental = []
num_nodes_layers = []
pre_conv_list = []
cur_conv_list = []
conv_kernel_size_list = []
num_nodes_pre = 0
is_conv = False
pre_conv = False
pre_conv_out_channel = 1
is_final_bias = False
perm_is_complete = True
named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
for idx, ((_, fc_layer0_weight), (_, fc_layer1_weight)) in \
enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):
assert fc_layer0_weight.shape == fc_layer1_weight.shape
layer_shape = fc_layer0_weight.shape
num_nodes_cur = fc_layer0_weight.shape[0]
if len(layer_shape) > 1:
if is_conv is True and len(layer_shape) == 2:
num_nodes_pre = pre_conv_out_channel
else:
num_nodes_pre = fc_layer0_weight.shape[1]
if idx >= 1 and len(named_weight_list_0[idx - 1][1].shape) == 1:
pre_bias = True
else:
pre_bias = False
if len(layer_shape) > 2:
is_bias = False
if not pre_bias:
pre_conv = is_conv
pre_conv_list.append(pre_conv)
is_conv = True
cur_conv_list.append(is_conv)
fc_layer0_weight_data = fc_layer0_weight.data.view(
fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1)
fc_layer1_weight_data = fc_layer1_weight.data.view(
fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1)
elif len(layer_shape) == 2:
is_bias = False
if not pre_bias:
pre_conv = is_conv
pre_conv_list.append(pre_conv)
is_conv = False
cur_conv_list.append(is_conv)
fc_layer0_weight_data = fc_layer0_weight.data
fc_layer1_weight_data = fc_layer1_weight.data
else:
is_bias = True
if not pre_bias:
pre_conv = is_conv
pre_conv_list.append(pre_conv)
is_conv = False
cur_conv_list.append(is_conv)
fc_layer0_weight_data = fc_layer0_weight.data
fc_layer1_weight_data = fc_layer1_weight.data
if is_conv:
pre_conv_out_channel = num_nodes_cur
if is_bias is True and idx == num_layers - 1:
is_final_bias = True
if idx == 0:
for a in range(num_nodes_pre):
affinity[(num_nodes_before + a) * n2 + num_nodes_before + a] \
[(num_nodes_before + a) * n2 + num_nodes_before + a] \
= 1
if idx == num_layers - 2 and 'bias' in named_weight_list_0[idx + 1][0] or \
idx == num_layers - 1 and 'bias' not in named_weight_list_0[idx][0]:
for a in range(num_nodes_cur):
affinity[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
= 1
if is_bias is False:
ground_metric = Ground_Metric_GM(
fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
pre_conv, int(fc_layer0_weight_data.shape[1] / pre_conv_out_channel))
else:
ground_metric = Ground_Metric_GM(
fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
pre_conv, 1)
layer_affinity = ground_metric.process_soft_affinity(p=2)
if is_bias is False:
pre_conv_kernel_size = fc_layer0_weight.shape[3] if is_conv else None
conv_kernel_size_list.append(pre_conv_kernel_size)
if is_bias is True and is_final_bias is False:
for a in range(num_nodes_cur):
for c in range(num_nodes_cur):
affinity[(num_nodes_before + a) * n2 + num_nodes_before + c] \
[(num_nodes_before + a) * n2 + num_nodes_before + c] \
= layer_affinity[a][c]
elif is_final_bias is False:
for a in range(num_nodes_pre):
for b in range(num_nodes_cur):
affinity[
(num_nodes_before + a) * n2 + num_nodes_before:
(num_nodes_before + a) * n2 + num_nodes_before + num_nodes_pre,
(num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre:
(num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre + num_nodes_cur] \
= layer_affinity[a + b * num_nodes_pre].view(num_nodes_cur, num_nodes_pre).transpose(0, 1)
if is_bias is False:
num_nodes_before += num_nodes_pre
num_nodes_incremental.append(num_nodes_before)
num_nodes_layers.append(num_nodes_cur)
# affinity = (affinity + affinity.t()) / 2
return affinity, [n1, n2, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list]
Get the affinity (similarity) matrix between model1 and model2.
K, params = graph_matching_fusion([model1, model2])
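K has shape (n1*n2, n1*n2), covering every pair of candidate channel correspondences; params bundles the bookkeeping used later (node counts, per-layer offsets, conv flags, kernel sizes). A quick inspection sketch:
# Inspect the outputs (values follow from SimpleNet's architecture).
print(K.shape)               # (n1*n2, n1*n2) pairwise affinities
print(params[0], params[1])  # n1, n2: total node counts of the two networks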
Align the models by graph matching
Align the channels of model1 and model2 by maximizing the affinity (similarity) via graph matching algorithms.
n1 = params[0]
n2 = params[1]
X = pygm.sm(K, n1, n2)
Project X to the neural network matching result. The matching matrix is built by applying the Hungarian algorithm to the layer-wise blocks of X, because only channels from the same layer of the two networks can be matched.
Note
In this example, we assume the last FC layer is already aligned and need not be matched.
new_X = torch.zeros_like(X)
new_X[:params[2][0], :params[2][0]] = torch.eye(params[2][0], device=device)
for start_idx, length in zip(params[2][:-1], params[3][:-1]):  # params[2]: start offsets of each layer's channels; params[3]: channel counts per layer
slicing = slice(start_idx, start_idx + length)
new_X[slicing, slicing] = pygm.hungarian(X[slicing, slicing])
# assume the last FC layer is aligned
slicing = slice(params[2][-1], params[2][-1] + params[3][-1])
new_X[slicing, slicing] = torch.eye(params[3][-1], device=device)
X = new_X
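As an optional sanity check (a sketch, not part of the original pipeline), the projected X should now be a block-diagonal permutation matrix, so every row and every column sums to exactly 1:
# Optional sanity check: X is a valid permutation matrix after projection.
assert torch.all(X.sum(dim=0) == 1) and torch.all(X.sum(dim=1) == 1)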
Visualization of the matching result. The black lines separate the channels of different layers.
plt.figure(figsize=(4, 4))
plt.imshow(X.cpu().numpy(), cmap='Blues')
for idx in params[2]:
plt.axvline(x=idx, color='k')
plt.axhline(y=idx, color='k')

Define the alignment function: fuse the models based on the matching result
def align(solution, fusion_proportion, networks: list, params: list):
[_, _, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list] = params
named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
aligned_wt_0 = [parameter.data for name, parameter in named_weight_list_0]
idx = 0
num_layers = len(aligned_wt_0)
for num_before, num_cur, cur_conv, cur_kernel_size in \
zip(num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list):
perm = solution[num_before:num_before + num_cur, num_before:num_before + num_cur]
assert 'bias' not in named_weight_list_0[idx][0]
if len(named_weight_list_0[idx][1].shape) == 4:
aligned_wt_0[idx] = (perm.transpose(0, 1).to(torch.float64) @
aligned_wt_0[idx].to(torch.float64).permute(2, 3, 0, 1)) \
.permute(2, 3, 0, 1)
else:
aligned_wt_0[idx] = perm.transpose(0, 1).to(torch.float64) @ aligned_wt_0[idx].to(torch.float64)
idx += 1
if idx >= num_layers:
continue
if 'bias' in named_weight_list_0[idx][0]:
aligned_wt_0[idx] = aligned_wt_0[idx].to(torch.float64) @ perm.to(torch.float64)
idx += 1
if idx >= num_layers:
continue
if cur_conv and len(named_weight_list_0[idx][1].shape) == 2:
aligned_wt_0[idx] = (aligned_wt_0[idx].to(torch.float64)
.reshape(aligned_wt_0[idx].shape[0], 64, -1)
.permute(0, 2, 1)
@ perm.to(torch.float64)) \
.permute(0, 2, 1) \
.reshape(aligned_wt_0[idx].shape[0], -1)
elif len(named_weight_list_0[idx][1].shape) == 4:
aligned_wt_0[idx] = (aligned_wt_0[idx].to(torch.float64)
.permute(2, 3, 0, 1)
@ perm.to(torch.float64)) \
.permute(2, 3, 0, 1)
else:
aligned_wt_0[idx] = aligned_wt_0[idx].to(torch.float64) @ perm.to(torch.float64)
assert idx == num_layers
averaged_weights = []
for idx, parameter in enumerate(networks[1].parameters()):
averaged_weights.append((1 - fusion_proportion) * aligned_wt_0[idx] + fusion_proportion * parameter)
return averaged_weights
Test the fused model
The fusion_proportion variable denotes model2's contribution to the fused model. For example, if fusion_proportion=0.2, then
fused model = 80% model1 + 20% model2.
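As a concrete illustration of the interpolation performed inside align(), with hypothetical toy weights:
# Toy illustration (hypothetical values) of the per-parameter interpolation:
# fused = (1 - fusion_proportion) * aligned_model1_weight + fusion_proportion * model2_weight
w_aligned = torch.tensor([1.0, 2.0])     # a hypothetical aligned weight of model1
w_model2 = torch.tensor([3.0, 4.0])      # the corresponding weight of model2
print(0.8 * w_aligned + 0.2 * w_model2)  # fusion_proportion=0.2 -> tensor([1.4000, 2.4000])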
def align_model_and_test(X):
acc_list = []
for fusion_proportion in torch.arange(0, 1.1, 0.1):
fused_weights = align(X, fusion_proportion, [model1, model2], params)
fused_model = SimpleNet()
state_dict = fused_model.state_dict()
for idx, (key, _) in enumerate(state_dict.items()):
state_dict[key] = fused_weights[idx]
fused_model.load_state_dict(state_dict)
fused_model.to(device)
test_loss = 0
correct = 0
for data, target in test_loader:
data = data.to(device)
target = target.to(device)
output = fused_model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # the model outputs logits, so use cross-entropy (nll_loss expects log-probabilities)
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target.data.view_as(pred)).sum()
test_loss /= len(test_loader.dataset)
acc = 100. * correct / len(test_loader.dataset)
print(
f"{1 - fusion_proportion:.2f} model1 + {fusion_proportion:.2f} model2 -> fused model accuracy: {acc:.2f}%")
acc_list.append(acc)
return torch.tensor(acc_list)
print('Graph Matching Fusion')
gm_acc_list = align_model_and_test(X)
Graph Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 85.46%
0.80 model1 + 0.20 model2 -> fused model accuracy: 86.92%
0.70 model1 + 0.30 model2 -> fused model accuracy: 88.38%
0.60 model1 + 0.40 model2 -> fused model accuracy: 86.43%
0.50 model1 + 0.50 model2 -> fused model accuracy: 74.11%
0.40 model1 + 0.60 model2 -> fused model accuracy: 72.45%
0.30 model1 + 0.70 model2 -> fused model accuracy: 78.12%
0.20 model1 + 0.80 model2 -> fused model accuracy: 81.65%
0.10 model1 + 0.90 model2 -> fused model accuracy: 83.29%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.80%
Compared to vanilla model fusion (no matching), the graph matching method stabilizes the fusion step:
print('No Matching Fusion')
vanilla_acc_list = align_model_and_test(torch.eye(n1, device=device))
plt.figure(figsize=(4, 4))
plt.title('Fused Model Accuracy')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), gm_acc_list.cpu().numpy(), 'r*-', label='Graph Matching Fusion')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), vanilla_acc_list.cpu().numpy(), 'b*-', label='No Matching Fusion')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), [acc1] * 11, '--', color="gray", label='Model1 Accuracy')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), [acc2] * 11, '--', color="brown", label='Model2 Accuracy')
plt.gca().set_xlabel('Fusion Proportion')
plt.gca().set_ylabel('Accuracy (%)')
plt.ylim((70, 87))
plt.legend(loc=3)
plt.show()

No Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 84.01%
0.80 model1 + 0.20 model2 -> fused model accuracy: 81.91%
0.70 model1 + 0.30 model2 -> fused model accuracy: 74.67%
0.60 model1 + 0.40 model2 -> fused model accuracy: 60.39%
0.50 model1 + 0.50 model2 -> fused model accuracy: 47.15%
0.40 model1 + 0.60 model2 -> fused model accuracy: 55.36%
0.30 model1 + 0.70 model2 -> fused model accuracy: 72.87%
0.20 model1 + 0.80 model2 -> fused model accuracy: 79.64%
0.10 model1 + 0.90 model2 -> fused model accuracy: 82.55%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.80%
Print the summary of results
end_time = time.perf_counter()
print(f'time consumed for model fusion: {end_time - st_time:.2f} seconds')
print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
print(f"best fused model accuracy: {torch.max(gm_acc_list):.2f}%")
time consumed for model fusion: 16.75 seconds
model1 accuracy = 84.18%, model2 accuracy = 83.8%
best fused model accuracy: 88.38%
Note
This example supports both GPU and CPU; the online documentation is built on a CPU-only machine. Efficiency improves significantly if you run this code on a GPU.
Total running time of the script: (0 minutes 18.178 seconds)