热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch提取权重_PyTorch使用预训练模型

PyTorch模型加载的时候,有预训练模型,通过使用预训练模型可以给模型使用带来很多的便捷,对于模型的使用以下给出了一些总结,

PyTorch模型加载的时候,有预训练模型,通过使用预训练模型可以给模型使用带来很多的便捷,对于模型的使用以下给出了一些总结,如有错误恳请指正。

一、直接加载预训练模型进行训练

1、加载保存的整个模型

torch.save(model,'model.pkl')
...
model = torch.load('model.pkl')

2、加载保存的模型参数

torch.save(model.state_dict(),'model_state_dict.pkl')
...
model.load_state_dict(torch.load('model_state_dict.pkl'))

关于模型的保存和加载,可以详细参照我的这篇文章:

HUST小菜鸡:Pytorch搭建简单神经网络(三)——快速搭建、保存与提取​zhuanlan.zhihu.com
daa869c16870e7c2243d0b373f66bb1c.png

通过对模型参数的保存的解析,我们可以深入的了解

load_dict = torch.load('models/cifar10_statedict.pkl')
print(load_dict.keys())
print(type(load_dict))

输出的结果如下所示:

odict_keys(['conv1.0.weight', 'conv1.0.bias', 'conv2.0.weight', 'conv2.0.bias', 'conv3.0.weight', 'conv3.0.bias', 'conv4.0.weight', 'conv4.0.bias', 'conv5.0.weight', 'conv5.0.bias', 'conv6.0.weight', 'conv6.0.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.5.weight', 'classifier.5.bias'])

可以看出保存的state_dict其实是一个collections.OrderedDict的Object,和普通的dict不同的是,该类别是有着严格的顺序,而dict中的元素是没有严格的顺序。

但是有一个问题值得深入考量——两个网络的结构是一样的,但是结构的命名是不一样的,那么对于这种模型的加载,如果不一样的话会出现报错,该如何解决

参照以上结果的输出,state_dict中key就是网络结构的名称,所以当网络结构一样的时候,只需要修改索引key,就可以解决以上的问题,至于如何修改可以参照如下方式:

https://stackoverflow.com/questions/12150872/change-key-in-ordereddict-without-losing-order​stackoverflow.com

二、加载部分预训练模型

我们经常对现有的经典网络进行如下操作,我们不修改网络的主体部分,我们只修改网络的输出,或者在最后加上一些网络层来达到我们想要的输出结果,虽然很难保证网络模型和某些公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

model = cifar10_cnn.CIFAR10_Nettest()
pretrained_dict = torch.load('models/cifar10_statedict.pkl')
model_dict = model.state_dict()print('随机初始化权重第一层:',model_dict['conv1.0.weight'])# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}print('预训练权重第一层:',pretrained_dict['conv1.0.weight'])# 更新现有的model_dict
model_dict.update(pretrained_dict) #利用预训练模型的参数,更新模型
model.load_state_dict(model_dict)print('更新后权重第一层:',model_dict['conv1.0.weight'])

输出的部分结果如下所示,为了直观显示我只截取了中间的某一部分

随机初始化权重第一层: tensor([[[[ 0.0142, 0.1039, 0.1260],[ 0.1805, -0.0533, 0.0007],[-0.1032, -0.1039, -0.0633]],[[ 0.0714, -0.0053, 0.0059],[-0.0528, 0.0438, -0.1108],[ 0.0544, 0.0157, 0.1265]],预训练权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],[-2.3942e-01, -1.5474e-01, 1.3142e-01],[-9.4602e-02, 6.4120e-02, -9.4336e-02]],[[ 9.7318e-02, 1.0526e-01, 2.3400e-03],[-5.8471e-02, -8.8146e-02, -1.6053e-01],[-1.0788e-01, -5.9083e-02, -9.0651e-02]],更新后权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],[-2.3942e-01, -1.5474e-01, 1.3142e-01],[-9.4602e-02, 6.4120e-02, -9.4336e-02]],[[ 9.7318e-02, 1.0526e-01, 2.3400e-03],[-5.8471e-02, -8.8146e-02, -1.6053e-01],[-1.0788e-01, -5.9083e-02, -9.0651e-02]],

可以看出该方法可以实现对模型中相同的部分进行修改

在调试过程中我遇到了一个错误

4501720779c45ba2a914c982910f49b4.png

在模型剔除操作中,只比较了该state_dict的key值,而不是比较网络层的形状,两个网络我修改了网络的最后的用于预测的全连接层,通过报错内容可以看出来是两个权重的大小是不匹配的,所以我们的新模型改变了的层需要和原模型对应层的名字不一样,才可以保证该方法的可行。这里我加了一个小小的1将其区分,解决了这个不匹配的问题。

35fe863b90789905cb7cc2281804029d.png
668e7d9762a3aa361869713d2cbf5836.png

三、冻结部分参数,训练另外的一部分参数

当输入模型的数据集在形式上是基本相似的,或者是同种类型的,例如用于计算机视觉任务的特征提取网络,对于某一特定任务已经有该特征提取网络的预训练权重,那么在对后面的部分进行训练的过程我在训练这前面的部分就显得有点浪费资源了,而且训练时间也会变长。如果给这前面部分的参数冻结,只训练后面部分的网络参数,可以实现资源配置和时间效率上的双赢。

通过修改requires_grad为False来冻结网络参数

在网络模型中需要冻结之前的位置添加如下:

for p in self.parameters():p.requires_grad = False

如下所示:

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)for p in self.parameters():p.requires_grad=Falseself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)

同时在优化器中添加:filter(lambda p: p.requires_grad, model.parameters())过滤掉requires_grad=false的层

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

如果网络参数保存在OrderedDict中,可以通过查找参数对应的序号进而实现对其冻结

model = cifar10_cnn.CIFAR10_Nettest()
pretrained_dict = torch.load('models/cifar10_statedict.pkl')
model_dict = model.state_dict()#输出模型的最后三层的参数
print(list(model.parameters())[-3:])dict_name = list(model_dict)
for i, p in enumerate(dict_name):print(i, p)#对模型的前16层参数进行冻结
for i,p in enumerate(model.parameters()):if i <16:p.requires_grad &#61; False#输出冻结后模型的最后三层的参数
print(list(model.parameters())[-3:])

输出的结果如下所示&#xff1a;

[Parameter containing:
tensor([-0.0016, 0.0069, 0.0108, 0.0107, -0.0128, -0.0033, 0.0045, -0.0061,0.0194, -0.0015, -0.0055, -0.0005, -0.0218, -0.0059, 0.0047, 0.0190,0.0046, -0.0166, -0.0004, 0.0004, 0.0058, -0.0016, -0.0053, -0.0188,0.0032, -0.0100, -0.0156, -0.0149, 0.0119, -0.0011, 0.0116, -0.0167,0.0047, 0.0037, -0.0020, 0.0080, 0.0027, 0.0003, 0.0073, 0.0070,-0.0128, -0.0208, 0.0117, 0.0220, -0.0133, -0.0021, 0.0214, 0.0109,, ...,0.0016, 0.0144, 0.0097, 0.0133, -0.0026, 0.0196, 0.0056, 0.0069,-0.0113, 0.0184, 0.0202, 0.0016, -0.0200, 0.0198, -0.0017, -0.0141,0.0219, 0.0120, 0.0124, -0.0168, -0.0105, -0.0165, 0.0162, 0.0146,0.0098, -0.0133, -0.0192, -0.0135, 0.0196, 0.0095, -0.0193, -0.0068],requires_grad&#61;True), Parameter containing:
tensor([[ 1.9555e-02, -1.2806e-02, 2.3172e-02, ..., 3.8184e-02,2.5622e-02, 1.9850e-02],[-2.9452e-02, -3.3035e-02, -3.2527e-02, ..., -3.0232e-02,-5.3696e-05, -3.3424e-02],[ 3.9043e-02, 1.3163e-02, -3.7559e-02, ..., 9.0075e-03,3.5016e-02, 1.0584e-03],...,[ 4.0777e-02, 3.1920e-02, -3.3931e-02, ..., 2.9741e-02,-3.8361e-02, -3.7472e-02],[-2.5555e-03, -1.2358e-02, -7.5636e-03, ..., 2.1639e-02,-1.6167e-02, -1.5543e-02],[ 3.8254e-02, 1.5340e-02, 1.9038e-02, ..., 2.3954e-02,-7.7485e-03, -3.5717e-02]], requires_grad&#61;True), Parameter containing:
tensor([-0.0284, 0.0027, -0.0009, 0.0059, -0.0242, -0.0352, 0.0366, -0.0413,-0.0103, -0.0325], requires_grad&#61;True)]
0 conv1.0.weight
1 conv1.0.bias
2 conv2.0.weight
3 conv2.0.bias
4 conv3.0.weight
5 conv3.0.bias
6 conv4.0.weight
7 conv4.0.bias
8 conv5.0.weight
9 conv5.0.bias
10 conv6.0.weight
11 conv6.0.bias
12 classifier1.1.weight
13 classifier1.1.bias
14 classifier1.3.weight
15 classifier1.3.bias
16 classifier1.5.weight
17 classifier1.5.bias
[Parameter containing:
tensor([-0.0016, 0.0069, 0.0108, 0.0107, -0.0128, -0.0033, 0.0045, -0.0061,0.0194, -0.0015, -0.0055, -0.0005, -0.0218, -0.0059, 0.0047, 0.0190,0.0046, -0.0166, -0.0004, 0.0004, 0.0058, -0.0016, -0.0053, -0.0188,0.0032, -0.0100, -0.0156, -0.0149, 0.0119, -0.0011, 0.0116, -0.0167,0.0047, 0.0037, -0.0020, 0.0080, 0.0027, 0.0003, 0.0073, 0.0070,, ...,0.0027, -0.0195, -0.0137, -0.0025, -0.0087, -0.0100, -0.0130, -0.0030,0.0013, -0.0040, -0.0150, 0.0023, 0.0158, -0.0037, -0.0151, 0.0105,0.0016, 0.0144, 0.0097, 0.0133, -0.0026, 0.0196, 0.0056, 0.0069,-0.0113, 0.0184, 0.0202, 0.0016, -0.0200, 0.0198, -0.0017, -0.0141,0.0219, 0.0120, 0.0124, -0.0168, -0.0105, -0.0165, 0.0162, 0.0146,0.0098, -0.0133, -0.0192, -0.0135, 0.0196, 0.0095, -0.0193, -0.0068]), Parameter containing:
tensor([[ 1.9555e-02, -1.2806e-02, 2.3172e-02, ..., 3.8184e-02,2.5622e-02, 1.9850e-02],[-2.9452e-02, -3.3035e-02, -3.2527e-02, ..., -3.0232e-02,-5.3696e-05, -3.3424e-02],[ 3.9043e-02, 1.3163e-02, -3.7559e-02, ..., 9.0075e-03,3.5016e-02, 1.0584e-03],, ...,[ 4.0777e-02, 3.1920e-02, -3.3931e-02, ..., 2.9741e-02,-3.8361e-02, -3.7472e-02],[-2.5555e-03, -1.2358e-02, -7.5636e-03, ..., 2.1639e-02,-1.6167e-02, -1.5543e-02],[ 3.8254e-02, 1.5340e-02, 1.9038e-02, ..., 2.3954e-02,-7.7485e-03, -3.5717e-02]], requires_grad&#61;True), Parameter containing:
tensor([-0.0284, 0.0027, -0.0009, 0.0059, -0.0242, -0.0352, 0.0366, -0.0413,-0.0103, -0.0325], requires_grad&#61;True)]Process finished with exit code 0

对比输出结果可以发现&#xff0c;通过该方法实现了对网络参数的冻结&#xff0c;即requires_grad参数被冻结的由True变成了False

e9adbbc7b39367fcafc98aa95d57b458.png
5dd7bea758f4ca54c01325df88222396.png

三、改动网络模型

现有的一些经典网络如AlexNet&#xff0c;ResNet等&#xff0c;对于ImageNet这些尺寸输入为224*224&#xff0c;但是如果是自定义的一些数据集或者是使用其他的现有数据集进行测试的时候&#xff0c;不可避免的会需要调节网络中的网络参数&#xff0c;改动网络模型来适应输入形式。

以AlexNet为例&#xff1a;

import torchvision.models as modelsmodel &#61; models.AlexNet()
print(model)
#修改网络的第一个卷积层的输入为4通道&#xff0c;输出的结果预测为10个类别
model.features[0]&#61;nn.Conv2d(4, 64, kernel_size&#61;(11, 11), stride&#61;(4, 4), padding&#61;(2, 2))
model.classifier[6] &#61; nn.Linear(4096,10)print(model)

结果如下所示&#xff08;有变化的部分被加粗标出&#xff09;:

AlexNet((features): Sequential((0): Conv2d(3, 64, kernel_size&#61;(11, 11), stride&#61;(4, 4), padding&#61;(2, 2))(1): ReLU(inplace&#61;True)(2): MaxPool2d(kernel_size&#61;3, stride&#61;2, padding&#61;0, dilation&#61;1, ceil_mode&#61;False)(3): Conv2d(64, 192, kernel_size&#61;(5, 5), stride&#61;(1, 1), padding&#61;(2, 2))(4): ReLU(inplace&#61;True)(5): MaxPool2d(kernel_size&#61;3, stride&#61;2, padding&#61;0, dilation&#61;1, ceil_mode&#61;False)(6): Conv2d(192, 384, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(7): ReLU(inplace&#61;True)(8): Conv2d(384, 256, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(9): ReLU(inplace&#61;True)(10): Conv2d(256, 256, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(11): ReLU(inplace&#61;True)(12): MaxPool2d(kernel_size&#61;3, stride&#61;2, padding&#61;0, dilation&#61;1, ceil_mode&#61;False))(avgpool): AdaptiveAvgPool2d(output_size&#61;(6, 6))(classifier): Sequential((0): Dropout(p&#61;0.5, inplace&#61;False)(1): Linear(in_features&#61;9216, out_features&#61;4096, bias&#61;True)(2): ReLU(inplace&#61;True)(3): Dropout(p&#61;0.5, inplace&#61;False)(4): Linear(in_features&#61;4096, out_features&#61;4096, bias&#61;True)(5): ReLU(inplace&#61;True)(6): Linear(in_features&#61;4096, out_features&#61;1000, bias&#61;True))
)
AlexNet((features): Sequential((0): Conv2d(4, 64, kernel_size&#61;(11, 11), stride&#61;(4, 4), padding&#61;(2, 2))(1): ReLU(inplace&#61;True)(2): MaxPool2d(kernel_size&#61;3, stride&#61;2, padding&#61;0, dilation&#61;1, ceil_mode&#61;False)(3): Conv2d(64, 192, kernel_size&#61;(5, 5), stride&#61;(1, 1), padding&#61;(2, 2))(4): ReLU(inplace&#61;True)(5): MaxPool2d(kernel_size&#61;3, stride&#61;2, padding&#61;0, dilation&#61;1, ceil_mode&#61;False)(6): Conv2d(192, 384, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(7): ReLU(inplace&#61;True)(8): Conv2d(384, 256, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(9): ReLU(inplace&#61;True)(10): Conv2d(256, 256, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(11): ReLU(inplace&#61;True)(12): MaxPool2d(kernel_size&#61;3, stride&#61;2, padding&#61;0, dilation&#61;1, ceil_mode&#61;False))(avgpool): AdaptiveAvgPool2d(output_size&#61;(6, 6))(classifier): Sequential((0): Dropout(p&#61;0.5, inplace&#61;False)(1): Linear(in_features&#61;9216, out_features&#61;4096, bias&#61;True)(2): ReLU(inplace&#61;True)(3): Dropout(p&#61;0.5, inplace&#61;False)(4): Linear(in_features&#61;4096, out_features&#61;4096, bias&#61;True)(5): ReLU(inplace&#61;True)(6): Linear(in_features&#61;4096, out_features&#61;10, bias&#61;True))
)Process finished with exit code 0

这些开源的框架已经包含了预训练模型&#xff0c;我们只需要修改网络模型&#xff0c;其加载预训练模型的方法和之前的提到的方式是一样的&#xff0c;当网络中模型被修改的时候&#xff0c;模型参数会保留原有未改变网络的参数不变&#xff0c;而对于改变了模型的网络层参数进行随机初始化

03b5fc0118408be11bd3a43727f8755c.png

model &#61; models.alexnet(pretrained&#61;True)
#通过该方式修改网络的最后的输出
model.classifier[6] &#61; nn.Linear(4096,10)

727110c9d0d8bffaca3fcccddd221b56.png
下载预训练模型

通过实验验证了修改模型中的某一层的时候&#xff0c;网络其他层的参数加载还是和原有的参数是一样的&#xff0c;但是对于修改的层的网络参数会进行随机初始化进行后续的训练。

model &#61; cifar10_cnn.CIFAR10_Nettest()
pretrained_dict &#61; torch.load(&#39;models/cifar10_statedict.pkl&#39;)
model_dict &#61; model.state_dict()
pretrained_dict &#61; {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)print(model)
new_model_dict &#61; model.state_dict()
dict_name &#61; list(new_model_dict)
for i, p in enumerate(dict_name):print(i, p)print(&#39;before change:n&#39;,new_model_dict[&#39;classifier.5.bias&#39;])
model.classifier[5]&#61;nn.Linear(1024,17)change_model_dict &#61; model.state_dict()
new_dict_name &#61; list(change_model_dict)
print(&#39;after change:n&#39;,change_model_dict[&#39;classifier.5.bias&#39;])

实验结果如下&#xff08;重点部分被我加粗标出&#xff09;

D:Pycharmcifar10_classifiedvenvScriptspython.exe D:/Pycharm/cifar10_classified/pretrain_lock.py
CIFAR10_Nettest((conv1): Sequential((0): Conv2d(3, 16, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(1): ReLU())(conv2): Sequential((0): Conv2d(16, 32, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(1): MaxPool2d(kernel_size&#61;2, stride&#61;2, padding&#61;0, dilation&#61;1, ceil_mode&#61;False)(2): ReLU())(conv3): Sequential((0): Conv2d(32, 64, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(1): ReLU())(conv4): Sequential((0): Conv2d(64, 128, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(1): MaxPool2d(kernel_size&#61;2, stride&#61;2, padding&#61;0, dilation&#61;1, ceil_mode&#61;False)(2): ReLU())(conv5): Sequential((0): Conv2d(128, 256, kernel_size&#61;(3, 3), stride&#61;(1, 1), padding&#61;(1, 1))(1): ReLU())(conv6): Sequential((0): Conv2d(256, 256, kernel_size&#61;(3, 3), stride&#61;(2, 2), padding&#61;(2, 2))(1): ReLU())(classifier): Sequential((0): Dropout(p&#61;0.5, inplace&#61;False)(1): Linear(in_features&#61;6400, out_features&#61;2048, bias&#61;True)(2): ReLU()(3): Linear(in_features&#61;2048, out_features&#61;1024, bias&#61;True)(4): ReLU()(5): Linear(in_features&#61;1024, out_features&#61;10, bias&#61;True))
)
0 conv1.0.weight
1 conv1.0.bias
2 conv2.0.weight
3 conv2.0.bias
4 conv3.0.weight
5 conv3.0.bias
6 conv4.0.weight
7 conv4.0.bias
8 conv5.0.weight
9 conv5.0.bias
10 conv6.0.weight
11 conv6.0.bias
12 classifier.1.weight
13 classifier.1.bias
14 classifier.3.weight
15 classifier.3.bias
16 classifier.5.weight
17 classifier.5.biasbefore change:tensor([ 0.1432, -0.3336, -0.1030, 0.1301, 0.1653, -0.0449, -0.0391, -0.0788,0.0337, -0.0665])
after change:tensor([ 0.0105, -0.0262, 0.0223, -0.0275, 0.0025, -0.0059, 0.0214, -0.0082,0.0023, -0.0023, 0.0252, -0.0054, -0.0039, 0.0251, 0.0066, 0.0187,-0.0063])#第二次实验结果&#xff08;冗余结果没有显示&#xff09;
before change:tensor([ 0.1432, -0.3336, -0.1030, 0.1301, 0.1653, -0.0449, -0.0391, -0.0788,0.0337, -0.0665])
after change:tensor([ 0.0277, 0.0027, 0.0234, -0.0041, 0.0221, -0.0219, 0.0222, 0.0001,-0.0195, 0.0178, 0.0011, 0.0146, 0.0026, -0.0046, -0.0154, -0.0231,-0.0281])




推荐阅读
author-avatar
丫头片子ZXH
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有