热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Pytorch自由载入部分模型参数并冻结

Pytorch的load方法和load_state_dict方法只能较为固定的读入参数文件,他们要求读入的state_dict的key和Model.state_dict()的key

Pytorch的load方法和load_state_dict方法只能较为固定的读入参数文件,他们要求读入的state_dict的key和Model.state_dict()的key对应相等。

而我们在进行迁移学习的过程中也许只需要使用某个预训练网络的一部分,把多个网络拼和成一个网络,或者为了得到中间层的输出而分离预训练模型中的Sequential 等等,这些情况下。传统的load方法就不是很有效了。

例如,我们想利用Mobilenet的前7个卷积并把这几层冻结,后面的部分接别的结构,或者改写成FCN结构,传统的方法就不奏效了。

最普适的方法是:构建一个字典,使得字典的keys和我们自己创建的网络相同,我们再从各种预训练网络把想要的参数对着新的keys填进去就可以有一个新的state_dict了,这样我们就可以load这个新的state_dict,目前只能想到这个方法应对较为复杂的网络变换。

网上查“载入部分模型”,“冻结部分模型”一般都是只改个FC,根本没有用,初学的时候自己写state_dict也踩了一些坑,发出来记录一下。

一.载入部分预训练参数

我们先看看Mobilenet的结构

( 来源github,附带预训练模型mobilenet_sgd_rmsprop_69.526.tar)

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
)
def conv_dw(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
)
self.model = nn.Sequential(
conv_bn( 3, 32, 2),
conv_dw( 32, 64, 1),
conv_dw( 64, 128, 2),
conv_dw(128, 128, 1),
conv_dw(128, 256, 2),
conv_dw(256, 256, 1),
conv_dw(256, 512, 2),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 1024, 2),
conv_dw(1024, 1024, 1),
nn.AvgPool2d(7),
)
self.fc = nn.Linear(1024, 1000)
def forward(self, x):
x = self.model(x)
x = x.view(-1, 1024)
x = self.fc(x)
return x

我们只需要前7层卷积,并且为了方便日后concate操作,我们把Sequential拆开,成为下面的样子

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
)
def conv_dw(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
)

self.conv1 = conv_bn( 3, 32, 2)
self.conv2 = conv_dw( 32, 64, 1)
self.conv3 = conv_dw( 64, 128, 2)
self.conv4 = conv_dw(128, 128, 1)
self.conv5 = conv_dw(128, 256, 2)
self.conv6 = conv_dw(256, 256, 1)
self.conv7 = conv_dw(256, 512, 2)

# 原来这些不要了
# 可以自己接后面的结构
''' self.features = nn.Sequential( conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 1024, 2), conv_dw(1024, 1024, 1), nn.AvgPool2d(7),) self.fc = nn.Linear(1024, 1000) '''

def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
x3 = self.conv3(x2)
x4 = self.conv4(x3)
x5 = self.conv5(x4)
x6 = self.conv6(x5)
x7 = self.conv7(x6)
#x8 = self.features(x7)
#out = self.fc
return (x1,x2,x3,x4,x4,x6,x7)

我们更具改过的结构创建一个net,看看他的state_dict和我们预训练文件的state_dict有啥区别

net = Net()
#我的电脑没有GPU,他的参数是GPU训练的cudatensor,于是要下面这样转换一下
dict_trained = torch.load("mobilenet_sgd_rmsprop_69.526.tar",map_location=lambda storage, loc: storage)["state_dict"]
dict_new = net.state_dict().copy()
new_list = list (net.state_dict().keys() )
trained_list = list (dict_trained.keys() )
print("new_state_dict size: {} trained state_dict size: {}".format(len(new_list),len(trained_list)) )
print("New state_dict first 10th parameters names")
print(new_list[:10])
print("trained state_dict first 10th parameters names")
print(trained_list[:10])
print(type(dict_new))
print(type(dict_trained))

得到输出如下:

我们截断一半之后,参数由137变成65了,前十个参数看出,名字变了但是顺序其实没变。state_dict的数据类型是Odict,可以按照dict的操作方法操作。

new_state_dict size: 65 trained state_dict size: 137

New state_dict first 10th parameters names

[‘conv1.0.weight’, ‘conv1.1.weight’, ‘conv1.1.bias’, ‘conv1.1.running_mean’, ‘conv1.1.running_var’, ‘conv2.0.weight’, ‘conv2.1.weight’, ‘conv2.1.bias’, ‘conv2.1.running_mean’, ‘conv2.1.running_var’]

trained state_dict first 10th parameters names

[‘module.model.0.0.weight’, ‘module.model.0.1.weight’, ‘module.model.0.1.bias’, ‘module.model.0.1.running_mean’, ‘module.model.0.1.running_var’, ‘module.model.1.0.weight’, ‘module.model.1.1.weight’, ‘module.model.1.1.bias’, ‘module.model.1.1.running_mean’, ‘module.model.1.1.running_var’]

我们看出只要构建一个字典,使得字典的keys和我们自己创建的网络相同,我们在从各种预训练网络把想要的参数对着新的keys填进去就可以有一个新的state_dict了,这样我们就可以load这个新的state_dict,这是最普适的方法适用于所有的网络变化。

for i in range(65):
dict_new[ new_list[i] ] = dict_trained[ trained_list[i] ]
net.load_state_dict(dict_new)

还有别的情况,比如我们只是在后面加了一些层,没有改变原来网络层的名字和结构,可以用下面的简便方法:

loaded_dict = {k: loaded_dict[k] for k, _ in model.state_dict()}

二.冻结这几层参数

方法很多,这里用和上面方法对应的冻结方法

发现之前的冻结有问题,还是建议看一下
https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088
或者
https://discuss.pytorch.org/t/correct-way-to-freeze-layers/26714
或者

对应的,在训练时候,optimizer里面只能更新requires_grad = True的参数,于是

optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, net.parameters(),lr) )


推荐阅读
  • 本文介绍了一个使用Slideview组件实现循环轮播效果的例子,并将其作为ListView顶部的一项。此ListView包含了两种不同的模板设计,一种以Slideview为核心,另一种则是标准的单元格模板,包含按钮和标签。 ... [详细]
  • 本文将深入探讨 Unreal Engine 4 (UE4) 中的距离场技术,包括其原理、实现细节以及在渲染中的应用。距离场技术在现代游戏引擎中用于提高光照和阴影的效果,尤其是在处理复杂几何形状时。文章将结合具体代码示例,帮助读者更好地理解和应用这一技术。 ... [详细]
  • 本文探讨了在SQL Server中处理几何类型列时遇到的INTERSECT操作限制,并提供了解决方案,包括通过转换数据类型和使用额外表结构的方法。 ... [详细]
  • 本文将从基础概念入手,详细探讨SpringMVC框架中DispatcherServlet如何通过HandlerMapping进行请求分发,以及其背后的源码实现细节。 ... [详细]
  • 小编给大家分享一下Vue3中如何提高开发效率,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获, ... [详细]
  • Logging all MySQL queries into the Slow Log
    MySQLoptionallylogsslowqueriesintotheSlowQueryLog–orjustSlowLog,asfriendscallit.However,Thereareseveralreasonstologallqueries.Thislistisnotexhaustive:Belowyoucanfindthevariablestochange,astheyshouldbewritteninth ... [详细]
  • 本文详细介绍如何在华为鲲鹏平台上构建和使用适配ARM架构的Redis Docker镜像,解决常见错误并提供优化建议。 ... [详细]
  • 协程作为一种并发设计模式,能有效简化Android平台上的异步代码处理。自Kotlin 1.3版本引入协程以来,这一特性基于其他语言的成熟理念,为开发者提供了新的工具,以增强应用的响应性和效率。 ... [详细]
  • 在Ubuntu 18.04上使用Nginx搭建RTMP流媒体服务器
    本文详细介绍了如何在Ubuntu 18.04上使用Nginx和nginx-rtmp-module模块搭建RTMP流媒体服务器,包括环境搭建、配置文件修改和推流拉流操作。适用于需要搭建流媒体服务器的技术人员。 ... [详细]
  • 在Java开发中,保护代码安全是一个重要的课题。由于Java字节码容易被反编译,因此使用代码混淆工具如ProGuard变得尤为重要。本文将详细介绍如何使用ProGuard进行代码混淆,以及其基本原理和常见问题。 ... [详细]
  • RTThread线程间通信
    线程中通信在裸机编程中,经常会使用全局变量进行功能间的通信,如某些功能可能由于一些操作而改变全局变量的值,另一个功能对此全局变量进行读取& ... [详细]
  • pypy 真的能让 Python 比 C 还快么?
    作者:肖恩顿来源:游戏不存在最近“pypy为什么能让python比c还快”刷屏了,原文讲的内容偏理论,干货比较少。我们可以再深入一点点,了解pypy的真相。正式开始之前,多唠叨两句 ... [详细]
  • 本文详细介绍了如何在 CentOS 7 及其衍生发行版(如 Red Hat, Oracle, Scientific Linux 7)上安装和完全卸载 GitLab。包括安装必要的依赖关系、配置防火墙、安装 GitLab 软件包以及常见问题的解决方法。 ... [详细]
  • 本打算教一步步实现koa-router,因为要解释的太多了,所以先简化成mini版本,从实现部分功能到阅读源码,希望能让你好理解一些。希望你之前有读过koa源码,没有的话,给你链接 ... [详细]
  • ipsec 加密流程(二):ipsec初始化操作
    《openswan》专栏系列文章主要是记录openswan源码学习过程中的笔记。Author:叨陪鲤Email:vip_13031075266163.comDate:2020.1 ... [详细]
author-avatar
半暖半夏半流年
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有