热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

BatchNorm推理阶段和Conv合并

一、BN层作用批量归一化(BatchNormalization,BN)在深度学习中常放在卷积层之后,BN层有以下优点:减少了人为选择参数。在某些情况下可以取消dropout和L2正

一、BN层作用

批量归一化(Batch Normalization,BN)在深度学习中常放在卷积层之后,BN层有以下优点:



  • 减少了人为选择参数。在某些情况下可以取消 dropout 和 L2 正则项参数,或者采取更小的 L2 正则项约束参数;

  • 减少了对学习率的要求。现在我们可以使用初始很大的学习率或者选择了较小的学习率,算法也能够快速训练收敛;

  • 可以不再使用局部响应归一化。BN 本身就是归一化网络(局部响应归一化在 AlexNet 网络中存在);

  • 破坏原来的数据分布,一定程度上缓解过拟合(防止每批训练中某一个样本经常被挑选到,文献说这个可以提高 1% 的精度);

  • 减少梯度消失,加快收敛速度,提高训练精度


二、 BN层算法流程

下面给出的是 BN 算法在训练时的过程

输入:上一层输出结果 $ X = {x_1, x_2, ..., x_m} $,学习参数 $ \gamma, \beta $。

算法流程



  1. 计算上一层输出数据的均值

\[\mu_{\beta} = \frac{1}{m} \sum_{i=1}^m(x_i)

\]

其中,$ m $ 是此次训练样本 batch 的大小。



  1. 计算上一层输出数据的标准差

\[\sigma_{\beta}^2 = \frac{1}{m} \sum_{i=1}^m (x_i - \mu_{\beta})^2

\]



  1. 归一化处理,得到

\[\hat x_i = \frac{x_i - \mu_{\beta}}{\sqrt{\sigma_{\beta}^2} + \epsilon}

\]

其中 $ \epsilon $ 是为了避免分母为 0 而加进去的接近于 0 的很小值



  1. 重构,对经过上面归一化处理得到的数据进行重构,得到

\[y_i = \gamma \hat x_i + \beta

\]

其中,$ \gamma, \beta $ 为可学习参数。

注:上述是 BN 训练时的过程,但是当在测试阶段时,往往只是输入一个样本,没有所谓的均值 $ \mu_{\beta} $ 和标准差 $ \sigma_{\beta}^2 $。此时,均值 $ \mu_{\beta} $ 是计算所有 batch $ \mu_{\beta} $ 值的平均值得到,标准差 $ \sigma_{\beta}^2 $ 采用每个batch $ \sigma_{\beta}^2 $ 的无偏估计得到


三、推理阶段合并BN和conv的原理

如果BN层在卷积层Conv之后,那卷积和BN层可以合并成如下式子。

卷积层


\[Z = W X + B

\]

BN层


\[Y = \frac{Z - \mu_{\beta}}{\sqrt{\sigma_{\beta}^2} + \epsilon} \gamma + \beta

\]

合并上面两个式子可得:


\[Y = \frac{W\gamma}{\sqrt{\sigma_{\beta}^2} + \epsilon} X + (\frac{B - \mu_{\beta} }{\sqrt{\sigma_{\beta}^2} + \epsilon} \gamma + \beta)

\]


\[W^{'} = \frac{W\gamma}{\sqrt{\sigma_{\beta}^2} + \epsilon}

\]


\[B^{'} = \frac{B - \mu_{\beta} }{\sqrt{\sigma_{\beta}^2} + \epsilon} \gamma + \beta

\]

可得


\[Y = W^{'} X + B^{'}

\]

因此只需要更新卷积层的权值和偏置就可以达到合并卷积和BN层的效果。


三、code

import torch
import torch.nn as nn
import torchvision as tv
class DummyModule(nn.Module):
def __init__(self):
super(DummyModule, self).__init__()
def forward(self, x):
# print("Dummy, Dummy.")
return x
def fuse(conv, bn):
w = conv.weight
mean = bn.running_mean
var_sqrt = torch.sqrt(bn.running_var + bn.eps)
beta = bn.weight
gamma = bn.bias
if conv.bias is not None:
b = conv.bias
else:
b = mean.new_zeros(mean.shape)
w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
b = (b - mean)/var_sqrt * beta + gamma
fused_cOnv= nn.Conv2d(conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.padding,
bias=True)
fused_conv.weight = nn.Parameter(w)
fused_conv.bias = nn.Parameter(b)
return fused_conv
def fuse_conv_and_bn(conv, bn):
# init
fused_cOnv= torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True
)
# # prepare filters
w_cOnv= conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
fused_conv.weight = nn.Parameter(torch.mm(w_bn, w_conv).view(fused_conv.weight.size()))
# # prepare spatial bias
if conv.bias is not None:
b_cOnv= conv.bias
else:
b_cOnv= torch.zeros(conv.weight.size(0))
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fused_conv.bias = nn.Parameter(torch.matmul(w_bn, b_conv) + b_bn)
return fused_conv
def fuse_module(m):
children = list(m.named_children())
print("***********")
print(children)
print("***********")
c = None
cn = None
for name, child in children:
if isinstance(child, nn.BatchNorm2d):
# bc = fuse(c, child)
bc = fuse_conv_and_bn(c, child)
m._modules[cn] = bc
m._modules[name] = DummyModule()
print("==> name: ", name)
c = None
elif isinstance(child, nn.Conv2d):
c = child
cn = name
else:
fuse_module(child)
def test_net(m):
p = torch.randn([1, 3, 224, 224])
import time
s = time.time()
o_output = m(p)
print("Original time: ", time.time() - s)
fuse_module(m)
s = time.time()
f_output = m(p)
print("Fused time: ", time.time() - s)
print("Max abs diff: ", (o_output - f_output).abs().max().item())
assert(o_output.argmax() == f_output.argmax())
# print(o_output[0][0].item(), f_output[0][0].item())
print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())
def test_layer():
p = torch.randn([1, 3, 112, 112])
conv1 = m.conv1
bn1 = m.bn1
o_output = bn1(conv1(p))
fusion = fuse(conv1, bn1)
f_output = fusion(p)
print(o_output[0][0][0][0].item())
print(f_output[0][0][0][0].item())
print("Max abs diff: ", (o_output - f_output).abs().max().item())
print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())
if __name__ == "__main__":
m = tv.models.resnet18(True)
m.eval()
print("Layer level test: ")
test_layer()
print("============================")
print("Module level test: ")
test_net(m)

参考链接

https://blog.csdn.net/wfei101/article/details/78635557

https://zhuanlan.zhihu.com/p/49329030

https://pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html?highlight=batchnorm

https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html?highlight=batchnorm



推荐阅读
author-avatar
Fxnananana
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有