热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch框架学习十一——网络层权值初始化

PyTorch框架学习十一——网络层权值初始化一、均匀分布初始化二、正态分布初始化三、常数初始化四、Xavier均匀分布初始化五、Xavier正态分布初始化六、kaiming均匀分


PyTorch框架学习十一——网络层权值初始化

  • 一、均匀分布初始化
  • 二、正态分布初始化
  • 三、常数初始化
  • 四、Xavier 均匀分布初始化
  • 五、Xavier正态分布初始化
  • 六、kaiming均匀分布初始化

前面的笔记介绍了网络模型的搭建,这次将介绍网络层权值的初始化,适当的初始化方法可以使得避免梯度消失或梯度爆炸等问题,还能一定程度上加快网络的训练迭代过程。

下面将介绍PyTorch中十种常用的权值初始化的方法:


一、均匀分布初始化

torch.nn.init.uniform_(tensor: torch.Tensor, a: float = 0.0, b: float = 1.0) → torch.Tensor

功能:将输入张量的值用均匀分布U(a,b)随机采样得到的值填充。

参数如下所示:
在这里插入图片描述


  1. tensor:要初始化的张量。
  2. a:均匀分布的下界。
  3. b:均匀分布的上界。

举个栗子:

>>> w = torch.empty(3, 5)
>>> nn.init.uniform_(w)

二、正态分布初始化

torch.nn.init.normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) → torch.Tensor

功能:将输入张量的值用正态分布 N( mean, std ^2 )随机采样得到的值填充。

参数如下所示:
在这里插入图片描述


  1. tensor:要初始化的张量。
  2. mean:正态分布的均值。
  3. std:正态分布的标准差。

三、常数初始化

torch.nn.init.constant_(tensor: torch.Tensor, val: float) → torch.Tensor

功能:用固定值去填充张量。

参数如下:
在这里插入图片描述


  1. tensor:要填充的张量。
  2. val:要填充的值。

四、Xavier 均匀分布初始化

torch.nn.init.xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0) → torch.Tensor

功能:从下面这个均匀分布中随机采样初始化(具体看介绍Xavier的内容)
在这里插入图片描述
参数如下所示:
在这里插入图片描述


  1. tensor:要初始化的张量。
  2. gain:根据激活函数来定,保证网络层各层权重的方差差距不大。

举个栗子:

class MLP(nn.Module):def __init__(self, neural_num, layers):super(MLP, self).__init__()self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])self.neural_num = neural_numdef forward(self, x):for (i, linear) in enumerate(self.linears):x = linear(x)print("layer:{}, std:{}".format(i, x.std()))if torch.isnan(x.std()):print("output is nan in {} layers".format(i))breakreturn xdef initialize(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight.data)# flag = 0
flag = 1if flag:layer_nums = 100neural_nums = 256batch_size = 16net = MLP(neural_nums, layer_nums)net.initialize()inputs = torch.randn((batch_size, neural_nums)) # normal: mean=0, std=1output = net(inputs)print(output)

构建了一个100层,每层有256个神经元的全连接神经网络,输出每一层网络的数据分布的标准差:

layer:0, std:0.9939432144165039
layer:1, std:0.988370954990387
layer:2, std:0.9993033409118652
layer:3, std:0.9946814179420471
layer:4, std:1.0136058330535889
layer:5, std:0.9804127812385559
layer:6, std:0.9861023426055908
layer:7, std:0.9943155646324158
layer:8, std:0.9847374558448792
layer:9, std:0.9681516885757446
layer:10, std:0.9731113910675049
layer:11, std:0.9867657423019409
layer:12, std:0.998853862285614
layer:13, std:0.9768239259719849
layer:14, std:0.980059027671814
layer:15, std:0.9851741790771484
layer:16, std:1.0022122859954834
layer:17, std:0.9788040518760681
layer:18, std:1.0017856359481812
layer:19, std:1.0342336893081665
layer:20, std:1.0184755325317383
layer:21, std:1.016075849533081
layer:22, std:0.9980445504188538
layer:23, std:1.0043185949325562
layer:24, std:0.9859704375267029
layer:25, std:0.9940337538719177
layer:26, std:1.0047379732131958
layer:27, std:1.0038164854049683
layer:28, std:1.0144379138946533
layer:29, std:1.0297335386276245
layer:30, std:1.0231270790100098
layer:31, std:0.9947567582130432
layer:32, std:1.0121735334396362
layer:33, std:1.0102561712265015
layer:34, std:1.0205620527267456
layer:35, std:1.0590678453445435
layer:36, std:1.0277358293533325
layer:37, std:1.0321041345596313
layer:38, std:1.0334043502807617
layer:39, std:1.0470187664031982
layer:40, std:1.0888501405715942
layer:41, std:1.063532829284668
layer:42, std:1.0635225772857666
layer:43, std:1.0936106443405151
layer:44, std:1.0897372961044312
layer:45, std:1.0780189037322998
layer:46, std:1.1132346391677856
layer:47, std:1.1005138158798218
layer:48, std:1.0610020160675049
layer:49, std:1.114995002746582
layer:50, std:1.107061743736267
layer:51, std:1.1147115230560303
layer:52, std:1.1051268577575684
layer:53, std:1.0692596435546875
layer:54, std:1.059423565864563
layer:55, std:1.0318952798843384
layer:56, std:1.0445512533187866
layer:57, std:1.038772463798523
layer:58, std:1.0729072093963623
layer:59, std:1.0931061506271362
layer:60, std:1.102836012840271
layer:61, std:1.0710251331329346
layer:62, std:1.0685100555419922
layer:63, std:1.0235627889633179
layer:64, std:1.0192655324935913
layer:65, std:1.0483664274215698
layer:66, std:1.033905267715454
layer:67, std:1.0418909788131714
layer:68, std:1.0399161577224731
layer:69, std:1.0536786317825317
layer:70, std:1.041662573814392
layer:71, std:1.0555484294891357
layer:72, std:1.0822663307189941
layer:73, std:1.0788710117340088
layer:74, std:1.1118624210357666
layer:75, std:1.0804673433303833
layer:76, std:1.0754098892211914
layer:77, std:1.0847842693328857
layer:78, std:1.0808136463165283
layer:79, std:1.0306202173233032
layer:80, std:1.0064393281936646
layer:81, std:1.0131638050079346
layer:82, std:1.023984670639038
layer:83, std:1.005560040473938
layer:84, std:0.9921131134033203
layer:85, std:0.9612709879875183
layer:86, std:0.957591712474823
layer:87, std:0.952028751373291
layer:88, std:0.9482743144035339
layer:89, std:0.9498487114906311
layer:90, std:0.9595613479614258
layer:91, std:0.9428602457046509
layer:92, std:0.9281052350997925
layer:93, std:0.8957657814025879
layer:94, std:0.9068138003349304
layer:95, std:0.8488100171089172
layer:96, std:0.8666995763778687
layer:97, std:0.8959987759590149
layer:98, std:0.8925248980522156
layer:99, std:0.8857517242431641

可以看出是基本在1附近的,这样既不会梯度消失也不会梯度爆炸。


五、Xavier正态分布初始化

torch.nn.init.xavier_normal_(tensor: torch.Tensor, gain: float = 1.0) → torch.Tensor

功能:从这个正态分布N(0, std^2)中随机采样初始化(具体看介绍Xavier的内容),其中:
在这里插入图片描述

参数如下所示:
在这里插入图片描述


六、kaiming均匀分布初始化

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

功能:从均匀分布 U(-bound, bound) 中随机采样初始化(具体看介绍kaiming的内容:Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015))
其中:
在这里插入图片描述


推荐阅读
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 本文介绍了在MFC下利用C++和MFC的特性动态创建窗口的方法,包括继承现有的MFC类并加以改造、插入工具栏和状态栏对象的声明等。同时还提到了窗口销毁的处理方法。本文详细介绍了实现方法并给出了相关注意事项。 ... [详细]
  • 上图是InnoDB存储引擎的结构。1、缓冲池InnoDB存储引擎是基于磁盘存储的,并将其中的记录按照页的方式进行管理。因此可以看作是基于磁盘的数据库系统。在数据库系统中,由于CPU速度 ... [详细]
  • 本文由编程笔记#小编为大家整理,主要介绍了logistic回归(线性和非线性)相关的知识,包括线性logistic回归的代码和数据集的分布情况。希望对你有一定的参考价值。 ... [详细]
  • 电话号码的字母组合解题思路和代码示例
    本文介绍了力扣题目《电话号码的字母组合》的解题思路和代码示例。通过使用哈希表和递归求解的方法,可以将给定的电话号码转换为对应的字母组合。详细的解题思路和代码示例可以帮助读者更好地理解和实现该题目。 ... [详细]
  • Java序列化对象传给PHP的方法及原理解析
    本文介绍了Java序列化对象传给PHP的方法及原理,包括Java对象传递的方式、序列化的方式、PHP中的序列化用法介绍、Java是否能反序列化PHP的数据、Java序列化的原理以及解决Java序列化中的问题。同时还解释了序列化的概念和作用,以及代码执行序列化所需要的权限。最后指出,序列化会将对象实例的所有字段都进行序列化,使得数据能够被表示为实例的序列化数据,但只有能够解释该格式的代码才能够确定数据的内容。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • IjustinheritedsomewebpageswhichusesMooTools.IneverusedMooTools.NowIneedtoaddsomef ... [详细]
  • JDK源码学习之HashTable(附带面试题)的学习笔记
    本文介绍了JDK源码学习之HashTable(附带面试题)的学习笔记,包括HashTable的定义、数据类型、与HashMap的关系和区别。文章提供了干货,并附带了其他相关主题的学习笔记。 ... [详细]
  • 本文介绍了在iOS开发中使用UITextField实现字符限制的方法,包括利用代理方法和使用BNTextField-Limit库的实现策略。通过这些方法,开发者可以方便地限制UITextField的字符个数和输入规则。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
author-avatar
U友50054453
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有