热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pythontrain函数_用Python从底层实现一个多层感知机

在上一篇文章中,我们从数学理论对多层感知机的反向传播进行了推导。南柯一梦宁沉沦:神经网络中反向传播算法数学推导​zhuanlan.zhihu.com这一
604d576de9a766adf35493ba9b0dc1dc.png

在上一篇文章中,我们从数学理论对多层感知机的反向传播进行了推导。

南柯一梦宁沉沦:神经网络中反向传播算法数学推导​zhuanlan.zhihu.com
40f556d258d3054d5754f45493e37d70.png

这一篇文章中我们将基于上一篇文章最后给出的算法使用Python语言来实现一个多层感知机。

完整代码以及代码的使用方法,可以光顾我的Github

ProfessorHuang/Python_LeNet_UnderlyingImplementation​github.com

MNIST数据准备

要进行训练,我们第一步需要先准备好训练数据,在这里我们使用经典的MNIST数据集。MNIST数据集的获取有多种方式,存储的格式也各不相同。在这里,我建议直接在MNIST的官网,即Yann Lecun的个人网站上获取。http://yann.lecun.com/exdb/mnist/index.html

我们可以在网站上下载得到四个压缩包,分别是训练图片,训练标签,测试图片和测试标签,解压之后获得四个文件的存储格式为idx。我们需要使用Python的struct库中unpack函数进行解析。然后使用numpy库将数据转换成numpy数组的形式便于我们之后处理。

图片数据解析为一个三维数组(也可称为张量),格式为图片数量×784×1.需要提前将图片数据从0-255的整数转换成0-1的浮点数。标签数据解析为一维数据,只存储了标签值的一个标量。我们需要将它也转换成一个三维数组,格式为标签数量×10×1,每个标签以one-hot格式存储,即一个10维列向量,正确标签为下标的数为1,其它的为0.

import

train_image维度为60000×784×1,train_label维度为60000×10×1

test_image维度为10000×784×1,test_label维度为10000×10×1


编写神经网络类

初始化神经网络

我们定义一个神经网络类NetWork,将涉及到的函数与神经网络各层参数封装在里面。

class

在初始化一个神经网络时,我们只需要以列表的形式提供神经网络各层的大小,神经网络各层的权重矩阵以及偏置项便会随机初始化并以二维numpy数组的进行存储,并以列表的形式按顺序存放在self.weights以及self.biases中,可以让该神经网络中其它方法使用与更新。

注意权重矩阵的初始化使用np.random.randn提供标准正态分布的随机数,它有正也有负。而我之前不小心使用了np.random.random提供的是0-1的浮点数,结果导致神经网络无法训练。

神经网络前向传播

神经网络前向传播很简单,取出权重矩阵和偏置项,通过矩阵乘法运算和矩阵相加运算,再经过激活函数即可根据前一层的输出得到当前层的输出,递归运算下去即可得到神经网络最终的输出。

def

矩阵的乘法用np.dot函数实现。我们使用的是sigmoid函数

作为激活函数。

def

神经网络反向传播

根据输入列向量x,前向传播出各层激活前的输出

和激活后的输出
为了之后计算delta误差以及损失函数对权重矩阵的导数。

将最后一层输出

与标签列向量y代入到损失函数对
的导数,求得最后一层的delta误差

利用公式

可以依次求出每层的delta误差,Hadmard积直接用*符号即可,表示逐元素相乘。

每求出一层的delta误差,便可以很快的带入公式,该层的偏置的导数与delta误差相等,该层权重矩阵的导数等于该层delta误差右乘上上一层激活后输出的转置。

def

backprop方法只是根据一副图像以及对应的标签求得神经网络参数的导数,而我们使用随机梯度下降法,需要使用一个batch的数据来更新数据。sigmoid_prime是对sigmoid函数的一阶导数:

def

我们使用方法update_mini _batch来调用backprop方法实现对一个batch的数据进行更新

def

update_mini_batch方法一次接收一个batch的训练图片和对应的训练标签,根据该batch数据求得的参数的平均导数,使用梯度下降法对神经网络中各层权重矩阵与偏置进行更新。

我们需要使用整个60000张训练数据来对神经网络进行训练,因此我们需要一个更高层的函数SGD,接收训练数据,并将训练数据分成一个个batch,再调用update_mini_batch方法对参数进行更新。

def

SGD接收的参数中,epochs代表训练轮数,将60000张数据全部训练一遍称为一个epoch。mini_batch_size表示batch大小,即一次使用多少张图片对参数进行更新。eta表示学习率。

验证神经网络准确率

我们在SGD方法中可以看见evaluate方法,它在每训练完一个epoch数据后,使用10000张测试数据来验证我们神经网络的准确率。

def

验证的方法很简单,依次从验证数据集中取出图片,经过神经网络前向传播,看最终预测值与图片的标签是否一致即可。evaluate方法返回10000张图片中预测正确的数量。


测试我们的神经网络

我们使用两行代码即可对我们的神经网络定义以及训练

# 训练神经网络

我们设置一个三层的神经网络,唯一的一个隐藏层只有30个神经元,可以加快我们的训练速度。我们调用SGD方法,训练30个epoch,batch大小为10,学习率为3.这些参数都是我们可以调整的,但相应地会取得不同的训练效果,不合适的参数有时候会导致训练无法正确进行。

239172ec9b1b3d765a6879b79f37afe1.png

完成30个epoch的训练,在我的电脑上大概只需要3分钟即可,而我们神经网络对MNIST验证集的预测正确率已经可以达到94.85%。bingo!


参考:

[1]刘建平Pinard:深度神经网络(DNN)反向传播算法(BP)

深度神经网络(DNN)反向传播算法(BP) - 刘建平Pinard - 博客园​www.cnblogs.com

[2] Neural Networks and Deep Learning by By Michael Nielsen

Neural networks and deep learning​neuralnetworksanddeeplearning.com
9a34b24ae60c47b24c4449f83e56e5cf.png

[3] 孤独暗星: MNIST手写数字数据集的读取,基于python3

https://blog.csdn.net/weixin_40522523/article/details/82823812​blog.csdn.net


推荐阅读
  • 本文介绍了UUID(通用唯一标识符)的概念及其在JavaScript中生成Java兼容UUID的代码实现与优化技巧。UUID是一个128位的唯一标识符,广泛应用于分布式系统中以确保唯一性。文章详细探讨了如何利用JavaScript生成符合Java标准的UUID,并提供了多种优化方法,以提高生成效率和兼容性。 ... [详细]
  • 通过使用 `pandas` 库中的 `scatter_matrix` 函数,可以有效地绘制出多个特征之间的两两关系。该函数不仅能够生成散点图矩阵,还能通过参数如 `frame`、`alpha`、`c`、`figsize` 和 `ax` 等进行自定义设置,以满足不同的可视化需求。此外,`diagonal` 参数允许用户选择对角线上的图表类型,例如直方图或密度图,从而提供更多的数据洞察。 ... [详细]
  • 2018 HDU 多校联合第五场 G题:Glad You Game(线段树优化解法)
    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6356在《Glad You Game》中,Steve 面临一个复杂的区间操作问题。该题可以通过线段树进行高效优化。具体来说,线段树能够快速处理区间更新和查询操作,从而大大提高了算法的效率。本文详细介绍了线段树的构建和维护方法,并给出了具体的代码实现,帮助读者更好地理解和应用这一数据结构。 ... [详细]
  • 为了优化用户体验,本文探讨了如何调整下拉菜单的宽度。通过合理设置宽度,可以提升界面的美观性和易用性。文章提供了具体的代码示例,帮助开发者实现这一目标。例如,可以通过 CSS 或 JavaScript 来动态调整下拉菜单的宽度,确保其在不同设备和屏幕尺寸上都能保持良好的显示效果。 ... [详细]
  • ButterKnife 是一款用于 Android 开发的注解库,主要用于简化视图和事件绑定。本文详细介绍了 ButterKnife 的基础用法,包括如何通过注解实现字段和方法的绑定,以及在实际项目中的应用示例。此外,文章还提到了截至 2016 年 4 月 29 日,ButterKnife 的最新版本为 8.0.1,为开发者提供了最新的功能和性能优化。 ... [详细]
  • OpenAI首席执行官Sam Altman展望:人工智能的未来发展方向与挑战
    OpenAI首席执行官Sam Altman展望:人工智能的未来发展方向与挑战 ... [详细]
  • 在List和Set集合中存储Object类型的数据元素 ... [详细]
  • 优化后的标题:深入探讨网关安全:将微服务升级为OAuth2资源服务器的最佳实践
    本文深入探讨了如何将微服务升级为OAuth2资源服务器,以订单服务为例,详细介绍了在POM文件中添加 `spring-cloud-starter-oauth2` 依赖,并配置Spring Security以实现对微服务的保护。通过这一过程,不仅增强了系统的安全性,还提高了资源访问的可控性和灵活性。文章还讨论了最佳实践,包括如何配置OAuth2客户端和资源服务器,以及如何处理常见的安全问题和错误。 ... [详细]
  • 本文详细介绍了批处理技术的基本概念及其在实际应用中的重要性。首先,对简单的批处理内部命令进行了概述,重点讲解了Echo命令的功能,包括如何打开或关闭回显功能以及显示消息。如果没有指定任何参数,Echo命令会显示当前的回显设置。此外,文章还探讨了批处理技术在自动化任务执行、系统管理等领域的广泛应用,为读者提供了丰富的实践案例和技术指导。 ... [详细]
  • 在 Vue 应用开发中,页面状态管理和跨页面数据传递是常见需求。本文将详细介绍 Vue Router 提供的两种有效方式,帮助开发者高效地实现页面间的数据交互与状态同步,同时分享一些最佳实践和注意事项。 ... [详细]
  • 使用 Vuex 管理表单状态:当输入框失去焦点时自动恢复初始值 ... [详细]
  • 在Django中提交表单时遇到值错误问题如何解决?
    在Django项目中,当用户提交包含多个选择目标的表单时,可能会遇到值错误问题。本文将探讨如何通过优化表单处理逻辑和验证机制来有效解决这一问题,确保表单数据的准确性和完整性。 ... [详细]
  • 每年,意甲、德甲、英超和西甲等各大足球联赛的赛程表都是球迷们关注的焦点。本文通过 Python 编程实现了一种生成赛程表的方法,该方法基于蛇形环算法。具体而言,将所有球队排列成两列的环形结构,左侧球队对阵右侧球队,首支队伍固定不动,其余队伍按顺时针方向循环移动,从而确保每场比赛不重复。此算法不仅高效,而且易于实现,为赛程安排提供了可靠的解决方案。 ... [详细]
  • `chkconfig` 命令主要用于管理和查询系统服务在不同运行级别中的启动状态。该命令不仅能够更新服务的启动配置,还能检查特定服务的当前状态。通过 `chkconfig`,管理员可以轻松地控制服务在系统启动时的行为,确保关键服务正常运行,同时禁用不必要的服务以提高系统性能和安全性。本文将详细介绍 `chkconfig` 的各项参数及其使用方法,帮助读者更好地理解和应用这一强大的系统管理工具。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
author-avatar
书友16941424_529
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有