热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

用TensorFlow构建基础的神经网络(一):MLP

源代码下载:http:pan.baidu.coms1kUAsk5L作者:XJTU_Ironboy时间:2017年8月开头语由于最近在学习

源代码下载:http://pan.baidu.com/s/1kUAsk5L
作者:XJTU_Ironboy
时间:2017年8月


开头语

  由于最近在学习Deep Learning方面的知识,并尝试着用Google近些年刚提出的TensorFlow框架来搭建各种经典的神经网络,如MLPLeNet-5AlexNetVGGGoogleNetResNetDenseNetSqueezeNet,所以在接下来的学习过程中我将逐个地介绍MLPLeNet-5AlexNet三个基础的神经网络的结构细节,因为其他的神经网络都可以说是在AlexNet上进行结构的修改而得到的,所以说如果学会了这三个基本网络的构建,那么其他结构的构建就是套路了。


一、MLP


1.单层感知机

  在我的理解中,Deep Learning的学习大多数都是从感知机(perceptron)开始的,因为它是神经网络的开山鼻祖,一种最最简单的神经网络结构。
  感知机是一种二分类模型,输入实例的特征向量,输出实例的±类别。由于这个模型过于简单,懒得废话,还是直接上图吧。
这里写图片描述

其中输入(Input)是一个m维的向量: [X1,X2,X3,...,Xi,...,Xm],权重(Weight)也是一个m维的向量: [W1,W2,W3,...,Wi,...,Wm],偏置项(Bias)是:[b],输出(Output)是一个标量输出:y,激活函数(Activation Function)是: 符号函数(signum:x大于0的时候为1;x等于0的时候为0,x小于0的时候为-1)。

2.多层感知机

  多层感知器(MLP,Multilayer Perceptron)是一种前馈人工神经网络模型,其将输入的多个数据集映射到单一的输出的数据集上。


这里写图片描述

  接下来讲一下如何用TensorFlow来实现这个简单的多层感知机(MLP):
1.数据库: MNIST数字手写体数据库
2.编程配置: Python3.5 + TensorFlow1.2.0
3.结构: 输入层+一个隐含层+输出层
输入图像大小: 28×28
     隐含层神经元个数: 300;
    输出神经元个数: 10;
     激活函数: ReLU
    优化器: Adagrad
    训练集上的batch size: 1000
    测试集上的batch size: 10000
  4.TensorFlow上实现
① 首先,导入tensorflow和MNIST数据集

import tensorflow as tf
# 导入MNIST数字手写体数据库
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot = True)

② 开始定义MLP的整体结构

# 定义输入层、隐含层、输出层的神经元个数
in_units = 784
h1_units = 300
out_units = 10

③ 输入层,with tf.name_scope是TensorFlow中的命名空间,便于在tensorboard可视化整体的结构

# 定义输入层,keep_prob是dropout的比例
with tf.name_scope("input"):x = tf.placeholder(tf.float32,[None,in_units])y_= tf.placeholder(tf.float32,[None,out_units])
keep_prob = tf.placeholder(tf.float32)

④ 隐含层:权重、偏置的初始值都是正态分布的随机数

# 定义隐含层的权重、偏置、激活函数
with tf.name_scope("hidden_layer1"):with tf.name_scope("w1"): w1 = tf.Variable(tf.random_normal([in_units,h1_units],stddev = 0.1))tf.summary.histogram('Weight1',w1)with tf.name_scope("b1"): b1 = tf.Variable(tf.zeros([h1_units])) + 0.01tf.summary.histogram('biases1',b1)with tf.name_scope("w1_b1"):hidden1 = tf.nn.relu(tf.matmul(x,w1) + b1)tf.summary.histogram('output1',hidden1)

⑤ 输出层:权重、偏置的初始值都是正态分布的随机数

# 定义输出层的权重、偏置、激活函数
with tf.name_scope("output_layer"):with tf.name_scope("w2"): w2 = tf.Variable(tf.random_normal([h1_units,out_units],stddev = 0.1))tf.summary.histogram('Weight2',w2)with tf.name_scope("b2"): b2 = tf.Variable(tf.zeros([out_units]))tf.summary.histogram('biases2',b2)with tf.name_scope("w2_b2"):hidden1_drop = tf.nn.dropout(hidden1,keep_prob)
with tf.name_scope("output"):y = tf.nn.softmax(tf.matmul(hidden1_drop,w2)+b2)tf.summary.histogram('output',y)

⑥ 损失函数、准确率、优化器

# 定义损失函数———交叉熵
with tf.name_scope("loss"):cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices = [1]))tf.summary.scalar('cross_entropy', cross_entropy)
# 计算准确率
with tf.name_scope("accuracy"):correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
# 定义优化器——Adagrad,和学习率:0.3
with tf.name_scope("train"): train_step = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy)

⑦ 框架搭好了,正式开始计算

# 初始化所有的变量
init = tf.global_variables_initializer()
# 开始导入数据,正式计算,迭代3000步,训练时batch size=100
with tf.Session() as sess:sess.run(init)merge = tf.summary.merge_all()writer = tf.summary.FileWriter("log",sess.graph)for i in range(3000):batch_xs,batch_ys = mnist.train.next_batch(1000)sess.run(train_step,feed_dict = {x:batch_xs,y_:batch_ys,keep_prob:0.75})loss_run = sess.run(cross_entropy,feed_dict = {x:batch_xs,y_:batch_ys,keep_prob:0.75})accuracy_run = sess.run(accuracy,feed_dict = {x:batch_xs,y_:batch_ys,keep_prob:0.75})print('after %d steps training steps,the loss is %g and the accuracy is %g'%(i,loss_run,accuracy_run))result = sess.run(merge,feed_dict = {x:batch_xs,y_:batch_ys,keep_prob:1})writer.add_summary(result,i)# 训练完后直接加载测试集数据,进行测试if i == 2999:loss_run = sess.run(cross_entropy,feed_dict = {x:mnist.test.images,y_:mnist.test.labels,keep_prob:1})accuracy_run = sess.run(accuracy,feed_dict = {x:mnist.test.images,y_:mnist.test.labels,keep_prob:1})print('the loss in test dataset is %g and the accuracy in test dataset is %g'%(loss_run,accuracy_run))

  5 .运行结果:



这里写图片描述

这里写图片描述

这里写图片描述


  
测试集上的准确度:98.00%

  6.
TensorBoard可视化:

① 在MLP.py程序的所在文件夹下打开cmd窗口(针对Windows)


方法一:打开cmd,然后用“cd + 路径”的方式找到该位置


方法二:定位到MLP.py所在文件的位置,点击左上角的“文件”,然后点击“打开命令提示符”

② 输入:
tensorboard - -logdir=log ,回车


这里写图片描述

③ 复制上面的地址到浏览器,如我这上面的地址是:http://Ironboy:6006
④ 可视化结果:



这里写图片描述
这里写图片描述
这里写图片描述


推荐阅读
  • 在Hive中合理配置Map和Reduce任务的数量对于优化不同场景下的性能至关重要。本文探讨了如何控制Hive任务中的Map数量,分析了当输入数据超过128MB时是否会自动拆分,以及Map数量是否越多越好的问题。通过实际案例和实验数据,本文提供了具体的配置建议,帮助用户在不同场景下实现最佳性能。 ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • Ihavetwomethodsofgeneratingmdistinctrandomnumbersintherange[0..n-1]我有两种方法在范围[0.n-1]中生 ... [详细]
  • 单片微机原理P3:80C51外部拓展系统
      外部拓展其实是个相对来说很好玩的章节,可以真正开始用单片机写程序了,比较重要的是外部存储器拓展,81C55拓展,矩阵键盘,动态显示,DAC和ADC。0.IO接口电路概念与存 ... [详细]
  • [转]doc,ppt,xls文件格式转PDF格式http:blog.csdn.netlee353086articledetails7920355确实好用。需要注意的是#import ... [详细]
  • 本文对比了杜甫《喜晴》的两种英文翻译版本:a. Pleased with Sunny Weather 和 b. Rejoicing in Clearing Weather。a 版由 alexcwlin 翻译并经 Adam Lam 编辑,b 版则由哈佛大学的宇文所安教授 (Prof. Stephen Owen) 翻译。 ... [详细]
  • 深入解析 Lifecycle 的实现原理
    本文将详细介绍 Android Jetpack 中 Lifecycle 组件的实现原理,帮助开发者更好地理解和使用 Lifecycle,避免常见的内存泄漏问题。 ... [详细]
  • 解决Bootstrap DataTable Ajax请求重复问题
    在最近的一个项目中,我们使用了JQuery DataTable进行数据展示,虽然使用起来非常方便,但在测试过程中发现了一个问题:当查询条件改变时,有时查询结果的数据不正确。通过FireBug调试发现,点击搜索按钮时,会发送两次Ajax请求,一次是原条件的请求,一次是新条件的请求。 ... [详细]
  • poj 3352 Road Construction ... [详细]
  • 本报告对2018年湘潭大学程序设计竞赛在牛客网上的时间数据进行了详细分析。通过统计参赛者在各个时间段的活跃情况,揭示了比赛期间的编程频率和时间分布特点。此外,报告还探讨了选手在准备过程中面临的挑战,如保持编程手感、学习逆向工程和PWN技术,以及熟悉Linux环境等。这些发现为未来的竞赛组织和培训提供了 valuable 的参考。 ... [详细]
  • 本文详细介绍了在MySQL中如何高效利用EXPLAIN命令进行查询优化。通过实例解析和步骤说明,文章旨在帮助读者深入理解EXPLAIN命令的工作原理及其在性能调优中的应用,内容通俗易懂且结构清晰,适合各水平的数据库管理员和技术人员参考学习。 ... [详细]
  • 2018 HDU 多校联合第五场 G题:Glad You Game(线段树优化解法)
    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6356在《Glad You Game》中,Steve 面临一个复杂的区间操作问题。该题可以通过线段树进行高效优化。具体来说,线段树能够快速处理区间更新和查询操作,从而大大提高了算法的效率。本文详细介绍了线段树的构建和维护方法,并给出了具体的代码实现,帮助读者更好地理解和应用这一数据结构。 ... [详细]
  • 深入理解排序算法:集合 1(编程语言中的高效排序工具) ... [详细]
  • MySQL:不仅仅是数据库那么简单
    MySQL不仅是一款高效、可靠的数据库管理系统,它还具备丰富的功能和扩展性,支持多种存储引擎,适用于各种应用场景。从简单的网站开发到复杂的企业级应用,MySQL都能提供强大的数据管理和优化能力,满足不同用户的需求。其开源特性也促进了社区的活跃发展,为技术进步提供了持续动力。 ... [详细]
  • Java 8 引入了 Stream API,这一新特性极大地增强了集合数据的处理能力。通过 Stream API,开发者可以更加高效、简洁地进行集合数据的遍历、过滤和转换操作。本文将详细解析 Stream API 的核心概念和常见用法,帮助读者更好地理解和应用这一强大的工具。 ... [详细]
author-avatar
邻居小明
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有