热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pythonbatchnormalization_BatchNormalization原理与python实现

为了保证深度神经网络训练过程的稳定性,经常需要细心的选择初始化方式,并且选择较小的学习率数值,这无疑增加了任务的复杂性。为此,

为了保证深度神经网络训练过程的稳定性,经常需要细心的选择初始化方式,并且选择较小的学习率数值,这无疑增加了任务的复杂性。为此,Google团队提出了Batch Normalization【1】方法(简称BN)用于帮助网络更好的训练。

1、理论分析

BN计算的第一步是对每一层进行独立的归一化:

其中k表示第k维特征,E表示求期望,Var表示求方差。这种归一化操作可能会改变这层的表示,所以作者提出了“identity transform”如下:

其中r(k)和B(k)是两个需要通过网络训练学习的参数,这两个参数随着迭代进行动态的更新,当下述情况时:

此时输出和输入就是完全相同的,都是x(k)。注释:identity transform,个人理解其目的是使得输出和输入相同,即不改变层的表示

综上,实现BN需要求的:均值、方差、参数beta、参数gamma。对应的算法流程如下,需要注意的是,Normalization的计算是对“每个特征”分别进行的:

2、python实现

2.1 数据和各个变量的含义

在使用类似pytorch的框架时发现,由于每个mini-batch中的数据是不同的,所以需要统计整个数据集中的均值和方差需要动态的追踪各个mini-batch(进行实现的时候参考了pytorch的BN文档【7】,其源码很多部分已经迁移到了C++所以分析起来比较困难),动态统计方式如下:

这里的“x”不是论文中的“x”,这里表示running_mean和runnning_var的更新权重。首先定义将要使用的数据:

data = np.array([[1, 2],

[1, 3],

[1, 4]]).astype(np.float32)

然后采用pytorch提供的BatchNorm1d模块进行测试,确保自己代码获得的结果能够和pytorch一致:

bn_torch = nn.BatchNorm1d(num_features=2)

data_torch = torch.from_numpy(data)

bn_output_torch = bn_torch(data_torch)

print(bn_output_torch)

得到的输出如下:

tensor([[ 0.0000, -1.1526],

[ 0.0000, 0.0000],

[ 0.0000, 1.1526]], grad_fn=)

2.2 前向传播实现

因为BN计算过程中需要保存running_mean和running_var,以及更新动量momentum和防止数值计算错误的eps,所以需要设计为类,并用实例属性来保存这些值,下面是初始化方法:

class MyBN:

def __init__(self, momentum, eps, num_features):

"""初始化参数值:param momentum: 追踪样本整体均值和方差的动量:param eps: 防止数值计算错误:param num_features: 特征数量"""

# 对每个batch的mean和var进行追踪统计

self._running_mean = 0

self._running_var = 1

# 更新self._running_xxx时的动量

self._momentum = momentum

# 防止分母计算为0

self._eps = eps

# 对应论文中需要更新的beta和gamma,采用pytorch文档中的初始化值

self._beta = np.zeros(shape=(num_features, ))

self._gamma = np.ones(shape=(num_features, ))

初始化中_beta和_gamma对应于BN中需要学习的参数,分别初始化为0和1,接下来就是前向传播的实现:

def batch_norm(self, x):

"""BN向传播:param x: 数据:return: BN输出"""

x_mean = x.mean(axis=0)

x_var = x.var(axis=0)

# 对应running_mean的更新公式

self._running_mean = (1-self._momentum)*x_mean + self._momentum*self._running_mean

self._running_var = (1-self._momentum)*x_var + self._momentum*self._running_var

# 对应论文中计算BN的公式

x_hat = (x-x_mean)/np.sqrt(x_var+self._eps)

y = self._gamma*x_hat + self._beta

return y

由于pytorch中的BatchNorm中beta和gamma初始化并不是0和1,为了保证初始化值一样,将自己定义的类的beta和gamm替换为torch初始化的值,进行如下测试:

my_bn = MyBN(momentum=0.01, eps=0.001, num_features=2)

my_bn._beta = bn_torch.bias.detach().numpy()

my_bn._gamma = bn_torch.weight.detach().numpy()

bn_output = my_bn.batch_norm(data, )

print(bn_output)

得到的结果和torch的结果一致:

[[ 0. -1.1517622]

[ 0. 0. ]

[ 0. 1.1517622]]

参考:



推荐阅读
  • 2019年斯坦福大学CS224n课程笔记:深度学习在自然语言处理中的应用——Word2Vec与GloVe模型解析
    本文详细解析了2019年斯坦福大学CS224n课程中关于深度学习在自然语言处理(NLP)领域的应用,重点探讨了Word2Vec和GloVe两种词嵌入模型的原理与实现方法。通过具体案例分析,深入阐述了这两种模型在提升NLP任务性能方面的优势与应用场景。 ... [详细]
  • 不用蘑菇,不拾金币,我通过强化学习成功通关29关马里奥,创造全新纪录
    《超级马里奥兄弟》由任天堂于1985年首次发布,是一款经典的横版过关游戏,至今已在多个平台上售出超过5亿套。该游戏不仅勾起了许多玩家的童年回忆,也成为强化学习领域的热门研究对象。近日,通过先进的强化学习技术,研究人员成功让AI通关了29关,创造了新的纪录。这一成就不仅展示了强化学习在游戏领域的潜力,也为未来的人工智能应用提供了宝贵的经验。 ... [详细]
  • 探索聚类分析中的K-Means与DBSCAN算法及其应用
    聚类分析是一种用于解决样本或特征分类问题的统计分析方法,也是数据挖掘领域的重要算法之一。本文主要探讨了K-Means和DBSCAN两种聚类算法的原理及其应用场景。K-Means算法通过迭代优化簇中心来实现数据点的划分,适用于球形分布的数据集;而DBSCAN算法则基于密度进行聚类,能够有效识别任意形状的簇,并且对噪声数据具有较好的鲁棒性。通过对这两种算法的对比分析,本文旨在为实际应用中选择合适的聚类方法提供参考。 ... [详细]
  • 从2019年AI顶级会议最佳论文,探索深度学习的理论根基与前沿进展 ... [详细]
  • 运用Isotonic回归算法解决鸢尾花数据集中的回归挑战
    本文探讨了利用Isotonic回归算法解决鸢尾花数据集中的回归问题。首先介绍了Isotonic回归的基本原理及其在保持单调性方面的优势,并通过具体示例说明其应用方法。随后详细描述了鸢尾花数据集的特征和获取途径,最后展示了如何将Isotonic回归应用于该数据集,以实现更准确的预测结果。 ... [详细]
  • [TensorFlow系列3]:初学者是选择Tensorflow2.x还是1.x? 2.x与1.x的主要区别?
    作者主页(文火冰糖的硅基工坊):https:blog.csdn.netHiWangWenBing本文网址:https:blog.csdn.netHiW ... [详细]
  • CBAM:卷积块注意模块
    CBAM:ConvolutionalBlockAttentionModule论文地址:https:arxiv.orgabs1807.06521简介:我们提出了 ... [详细]
  • 独家解析:深度学习泛化理论的破解之道与应用前景
    本文深入探讨了深度学习泛化理论的关键问题,通过分析现有研究和实践经验,揭示了泛化性能背后的核心机制。文章详细解析了泛化能力的影响因素,并提出了改进模型泛化性能的有效策略。此外,还展望了这些理论在实际应用中的广阔前景,为未来的研究和开发提供了宝贵的参考。 ... [详细]
  • 浅析卷积码的应用及其优势:探讨卷积编码在通信系统中的关键作用与特性
    本文详细介绍了卷积编码的基本原理,并深入分析了其在通信系统中的应用及其显著优势。卷积编码通过在编码过程中引入冗余信息,有效提高了数据传输的可靠性和抗干扰能力,成为现代通信系统中不可或缺的关键技术。文章还探讨了卷积编码在不同场景下的具体实现方法及其性能特点。 ... [详细]
  • Cosmos生态系统为何迅速崛起,波卡作为跨链巨头应如何应对挑战?
    Cosmos生态系统为何迅速崛起,波卡作为跨链巨头应如何应对挑战? ... [详细]
  • 创建一个水平滚动的表格视图
    本文介绍了如何创建一个水平滚动的表格视图,通过使用 `UITableView` 的变换属性 `transform` 和 `CGAffineTransformMakeRotation` 方法,实现视图的水平滚动效果。此外,还详细探讨了相关布局调整和性能优化技巧,确保在不同设备上都能获得流畅的用户体验。 ... [详细]
  • 本文探讨了BERT模型在自然语言处理领域的应用与实践。详细介绍了Transformers库(曾用名pytorch-transformers和pytorch-pretrained-bert)的使用方法,涵盖了从模型加载到微调的各个环节。此外,还分析了BERT在文本分类、情感分析和命名实体识别等任务中的性能表现,并讨论了其在实际项目中的优势和局限性。 ... [详细]
  • 基于OpenCV的图像拼接技术实践与示例代码解析
    图像拼接技术在全景摄影中具有广泛应用,如手机全景拍摄功能,通过将多张照片根据其关联信息合成为一张完整图像。本文详细探讨了使用Python和OpenCV库实现图像拼接的具体方法,并提供了示例代码解析,帮助读者深入理解该技术的实现过程。 ... [详细]
  • 本文将深入探讨生成对抗网络(GAN)在计算机视觉领域的应用。作为该领域的经典模型,GAN通过生成器和判别器的对抗训练,能够高效地生成高质量的图像。本文不仅回顾了GAN的基本原理,还将介绍一些最新的进展和技术优化方法,帮助读者全面掌握这一重要工具。 ... [详细]
  • 利用 Python 中的 Altair 库实现数据抖动的水平剥离分析 ... [详细]
author-avatar
phpxiaohui
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有