PyTorch Framework Learning 11: Network Layer Weight Initialization

Contents:
1. Uniform Distribution Initialization
2. Normal Distribution Initialization
3. Constant Initialization
4. Xavier Uniform Distribution Initialization
5. Xavier Normal Distribution Initialization
6. Kaiming Uniform Distribution Initialization
The previous notes covered how to build a network model; this note covers initializing the weights of network layers. A suitable initialization method helps avoid vanishing or exploding gradients, and can also speed up the training iterations to some extent.
Below are ten commonly used weight-initialization methods in PyTorch:
1. Uniform Distribution Initialization

torch.nn.init.uniform_(tensor: torch.Tensor, a: float = 0.0, b: float = 1.0) → torch.Tensor

Purpose: fills the input tensor with values sampled from the uniform distribution U(a, b).

The parameters are as follows:

tensor: the tensor to initialize.
a: the lower bound of the uniform distribution.
b: the upper bound of the uniform distribution.

For example:
>>> w = torch.empty(3, 5)
>>> nn.init.uniform_(w)
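The same call also works in place on a layer's weights. A minimal sketch (the layer shape and the bounds -0.1/0.1 are illustrative assumptions, not values from the docs):

import torch
import torch.nn as nn

# Initialize a Linear layer's weights from U(-0.1, 0.1) in place.
linear = nn.Linear(256, 256, bias=False)
nn.init.uniform_(linear.weight.data, a=-0.1, b=0.1)

# All values fall within [-0.1, 0.1).
print(linear.weight.min().item(), linear.weight.max().item())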
2. Normal Distribution Initialization

torch.nn.init.normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) → torch.Tensor

Purpose: fills the input tensor with values sampled from the normal distribution N(mean, std^2).

The parameters are as follows:
tensor: the tensor to initialize.
mean: the mean of the normal distribution.
std: the standard deviation of the normal distribution.
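A minimal sketch of normal_ in use (the tensor shape and std=0.01, a common small-std choice, are illustrative assumptions):

import torch
import torch.nn as nn

# Fill a tensor from N(0, 0.01^2) and check the sample statistics.
w = torch.empty(256, 256)
nn.init.normal_(w, mean=0.0, std=0.01)

# The sample mean is roughly 0 and the sample std roughly 0.01.
print(w.mean().item(), w.std().item())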
3. Constant Initialization

torch.nn.init.constant_(tensor: torch.Tensor, val: float) → torch.Tensor

Purpose: fills the tensor with a fixed value.

The parameters are as follows:

tensor: the tensor to fill.
val: the value to fill it with.
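A minimal sketch of constant_ in use; initializing biases to a small constant is a typical application (the layer shape and the value 0.1 are illustrative assumptions):

import torch
import torch.nn as nn

# Fill every element of the bias with 0.1.
linear = nn.Linear(256, 10)
nn.init.constant_(linear.bias.data, 0.1)

print(linear.bias)  # a length-10 tensor whose entries are all 0.1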
4. Xavier Uniform Distribution Initialization

torch.nn.init.xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0) → torch.Tensor

Purpose: initializes by sampling from the uniform distribution U(-a, a), where a = gain × sqrt(6 / (fan_in + fan_out)) (see the earlier notes on Xavier initialization for the derivation).

The parameters are as follows:

tensor: the tensor to initialize.
gain: a scaling factor chosen according to the activation function, so that the variance of the weights does not differ much from layer to layer.

For example:
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        # `layers` identical Linear layers with `neural_num` neurons each, no bias
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            # track how the std of the activations evolves layer by layer
            print("layer:{}, std:{}".format(i, x.std()))
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break
        return x

    def initialize(self):
        # apply Xavier uniform initialization to every Linear layer
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)

flag = 1
if flag:
    layer_nums = 100
    neural_nums = 256
    batch_size = 16

    net = MLP(neural_nums, layer_nums)
    net.initialize()

    inputs = torch.randn((batch_size, neural_nums))
    output = net(inputs)
    print(output)
This builds a fully connected network with 100 layers and 256 neurons per layer, and prints the standard deviation of the activations at every layer:
layer:0, std:0.9939432144165039
layer:1, std:0.988370954990387
layer:2, std:0.9993033409118652
layer:3, std:0.9946814179420471
layer:4, std:1.0136058330535889
layer:5, std:0.9804127812385559
layer:6, std:0.9861023426055908
layer:7, std:0.9943155646324158
layer:8, std:0.9847374558448792
layer:9, std:0.9681516885757446
layer:10, std:0.9731113910675049
layer:11, std:0.9867657423019409
layer:12, std:0.998853862285614
layer:13, std:0.9768239259719849
layer:14, std:0.980059027671814
layer:15, std:0.9851741790771484
layer:16, std:1.0022122859954834
layer:17, std:0.9788040518760681
layer:18, std:1.0017856359481812
layer:19, std:1.0342336893081665
layer:20, std:1.0184755325317383
layer:21, std:1.016075849533081
layer:22, std:0.9980445504188538
layer:23, std:1.0043185949325562
layer:24, std:0.9859704375267029
layer:25, std:0.9940337538719177
layer:26, std:1.0047379732131958
layer:27, std:1.0038164854049683
layer:28, std:1.0144379138946533
layer:29, std:1.0297335386276245
layer:30, std:1.0231270790100098
layer:31, std:0.9947567582130432
layer:32, std:1.0121735334396362
layer:33, std:1.0102561712265015
layer:34, std:1.0205620527267456
layer:35, std:1.0590678453445435
layer:36, std:1.0277358293533325
layer:37, std:1.0321041345596313
layer:38, std:1.0334043502807617
layer:39, std:1.0470187664031982
layer:40, std:1.0888501405715942
layer:41, std:1.063532829284668
layer:42, std:1.0635225772857666
layer:43, std:1.0936106443405151
layer:44, std:1.0897372961044312
layer:45, std:1.0780189037322998
layer:46, std:1.1132346391677856
layer:47, std:1.1005138158798218
layer:48, std:1.0610020160675049
layer:49, std:1.114995002746582
layer:50, std:1.107061743736267
layer:51, std:1.1147115230560303
layer:52, std:1.1051268577575684
layer:53, std:1.0692596435546875
layer:54, std:1.059423565864563
layer:55, std:1.0318952798843384
layer:56, std:1.0445512533187866
layer:57, std:1.038772463798523
layer:58, std:1.0729072093963623
layer:59, std:1.0931061506271362
layer:60, std:1.102836012840271
layer:61, std:1.0710251331329346
layer:62, std:1.0685100555419922
layer:63, std:1.0235627889633179
layer:64, std:1.0192655324935913
layer:65, std:1.0483664274215698
layer:66, std:1.033905267715454
layer:67, std:1.0418909788131714
layer:68, std:1.0399161577224731
layer:69, std:1.0536786317825317
layer:70, std:1.041662573814392
layer:71, std:1.0555484294891357
layer:72, std:1.0822663307189941
layer:73, std:1.0788710117340088
layer:74, std:1.1118624210357666
layer:75, std:1.0804673433303833
layer:76, std:1.0754098892211914
layer:77, std:1.0847842693328857
layer:78, std:1.0808136463165283
layer:79, std:1.0306202173233032
layer:80, std:1.0064393281936646
layer:81, std:1.0131638050079346
layer:82, std:1.023984670639038
layer:83, std:1.005560040473938
layer:84, std:0.9921131134033203
layer:85, std:0.9612709879875183
layer:86, std:0.957591712474823
layer:87, std:0.952028751373291
layer:88, std:0.9482743144035339
layer:89, std:0.9498487114906311
layer:90, std:0.9595613479614258
layer:91, std:0.9428602457046509
layer:92, std:0.9281052350997925
layer:93, std:0.8957657814025879
layer:94, std:0.9068138003349304
layer:95, std:0.8488100171089172
layer:96, std:0.8666995763778687
layer:97, std:0.8959987759590149
layer:98, std:0.8925248980522156
layer:99, std:0.8857517242431641
The standard deviations stay close to 1 all the way through the network, so the signal neither vanishes nor explodes as it propagates.
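Note that the experiment above has no activation functions between the layers. If each layer were followed by a saturating activation such as tanh, the gain argument should match it. A minimal sketch using torch.nn.init.calculate_gain (the choice of tanh is an assumption for illustration, and `net` is the MLP built above):

import torch.nn as nn

# Pick the gain recommended for tanh (5/3) and re-initialize
# every Linear layer of the MLP built above.
gain = nn.init.calculate_gain('tanh')
for m in net.modules():
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data, gain=gain)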
5. Xavier Normal Distribution Initialization

torch.nn.init.xavier_normal_(tensor: torch.Tensor, gain: float = 1.0) → torch.Tensor

Purpose: initializes by sampling from the normal distribution N(0, std^2) (see the earlier notes on Xavier initialization), where:

std = gain × sqrt(2 / (fan_in + fan_out))
The parameters are the same as for xavier_uniform_:

tensor: the tensor to initialize.
gain: a scaling factor chosen according to the activation function.
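A minimal sketch of xavier_normal_ (the tensor shape is an illustrative assumption):

import torch
import torch.nn as nn

# For a (3, 5) tensor, fan_in = 5 and fan_out = 3, so with gain = 1
# the target std is sqrt(2 / (5 + 3)) = 0.5.
w = torch.empty(3, 5)
nn.init.xavier_normal_(w, gain=1.0)

print(w.std())  # scatters around 0.5 (only 15 samples, so it is noisy)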
6. Kaiming Uniform Distribution Initialization

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

Purpose: initializes by sampling from the uniform distribution U(-bound, bound) (for details see the Kaiming paper: Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification, He, K. et al., 2015), where:

bound = gain × sqrt(3 / fan_mode)
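A minimal sketch of kaiming_uniform_ (the layer shape and the choice of nonlinearity='relu' are illustrative assumptions):

import torch
import torch.nn as nn

# With nonlinearity='relu', gain = sqrt(2), so
# bound = sqrt(2) * sqrt(3 / fan_in), and the resulting std is
# bound / sqrt(3) = sqrt(2 / fan_in) = sqrt(2 / 256) ≈ 0.088.
linear = nn.Linear(256, 256, bias=False)
nn.init.kaiming_uniform_(linear.weight.data, a=0, mode='fan_in', nonlinearity='relu')

print(linear.weight.std())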