热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

tensorflow学习(1.CNN简单实现MNIST)

先看代码#tf可以认为是全局变量,从该变量为类,从中取input_data变量importtensorflow.examples.tutorials

先看代码

#tf可以认为是全局变量,从该变量为类,从中取input_data变量
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf
#读取数据集
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
"""
#softmax方法进行训练
#这里是变量的占位符,一般是输入输出使用该部分
x=tf.placeholder(tf.float32,[None,784])
y_=tf.placeholder("float",[None,10])#定义参数变量
W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
y=tf.nn.softmax(tf.matmul(x,W)+b)#评价函数
cross_entropy=-tf.reduce_sum(y_*tf.log(y))
train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#启动模型,Session建立这样一个对象,然后指定某种操作,并实际进行该步
init=tf.initialize_all_variables()
sess=tf.Session()
sess.run(init)#数据读取部分
for i in range(1000):batch_xs, batch_ys = mnist.train.next_batch(50)#run第一个参数是fetch,可以是tensor也可以是Operation,第二个feed_dict是替换tensor的值sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})print(batch_xs,batch_ys,i)correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))"""#这里用CNN方法进行训练
#函数定义部分
def weight_variable(shape):initial=tf.truncated_normal(shape,stddev=0.1)#随机权重赋值,不过truncated_normal代表如果是2倍标准差之外的结果重新选取该值return tf.Variable(initial)def bias_variable(shape):initial=tf.constant(0.1,shape=shape)#偏置项return tf.Variable(initial)def conv2d(x,W):return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')#SAME表示输出补边,这里输出与输入尺寸一致def max_pool_2x2(x):return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')#ksize代表池化范围的大小,stride为扫描步长# 这里是变量的占位符,一般是输入输出使用该部分
x=tf.placeholder(tf.float32,[None,784])
y_=tf.placeholder("float",[None,10])
x_image=tf.reshape(x,[-1,28,28,1])#-1表示自动计算该维度
#建立第一层
W_conv1=weight_variable([5,5,1,32])
b_conv1=bias_variable([32])
h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1)
#第二层
W_conv2=weight_variable([5,5,32,64])
b_conv2=bias_variable([64])
h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2)#第三层,而且这里是全连接层
W_fc1=weight_variable([7*7*64,1024])
b_fc1=bias_variable([1024])h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
#dropout,注意这里也是有一个输入参数的,和x以及y一样
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)W_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)# 评价函数
cross_entropy=-tf.reduce_sum(y_*tf.log(y_conv))
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))# 启动模型,Session建立这样一个对象,然后指定某种操作,并实际进行该步
init=tf.initialize_all_variables()
sess=tf.Session()
sess.run(init)#数据读取部分
for i in range(1000):batch_xs, batch_ys = mnist.train.next_batch(50)#这里貌似是代表读取50张图像数据#run第一个参数是fetch,可以是tensor也可以是Operation,第二个feed_dict是替换tensor的值'''if i % 10 == 0:train_accuracy = accuracy.eval(feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})print("step:%d,accuracy:%g" % (i, train_accuracy))'''sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})#sess.run第一个参数是想要运行的位置,一般有train,accuracy,initdeng#第二个参数feed_dict,一般是输入参数,该代码里有x,y以及drop的参数if i%20==0 :print(i)print("train accuracy:%g"%sess.run(accuracy, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5}))
print("test accuracy:%g"%sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1}))

运行结果:

 

说明一下代码结构:

1.数据的读取(读取直接用MNIST,MNIST的数据结构x是一维的图像,size为(1,28*28),y是一维向量,size为(1,10),只要将数据的读取部分单独取出,就可以有比较清晰的代码了)

2.结构的建立(一般利用函数去定义池化等操作,这样能够有比较清晰的代码结构)

3.结构保存及测试(训练完数据后需要存储网络结构,该部分本文没有说明,正在学习)

本文作者也只是刚刚接触,针对实际项目会关注的一些问题,进行注释说明,如有错误,请指出说明

http://www.jeyzhang.com/tensorflow-learning-notes-2.html本链接的教程也比较清晰

 


推荐阅读
  • springMVC JRS303验证 ... [详细]
  • 本文详细探讨了在微服务架构中,使用Feign进行远程调用时出现的请求头丢失问题,并提供了具体的解决方案。重点讨论了单线程和异步调用两种场景下的处理方法。 ... [详细]
  • 软件工程课堂测试2
    要做一个简单的保存网页界面,首先用jsp写出保存界面,本次界面比较简单,首先是三个提示语,后面是三个输入框,然 ... [详细]
  • iOS 开发技巧:TabBarController 自定义与本地通知设置
    本文介绍了如何在 iOS 中自定义 TabBarController 的背景颜色和选中项的颜色,以及如何使用本地通知设置应用程序图标上的提醒个数。通过这些技巧,可以提升应用的用户体验。 ... [详细]
  • 烤鸭|本文_Spring之Bean的生命周期详解
    烤鸭|本文_Spring之Bean的生命周期详解 ... [详细]
  • 深入浅出TensorFlow数据读写机制
    本文详细介绍TensorFlow中的数据读写操作,包括TFRecord文件的创建与读取,以及数据集(dataset)的相关概念和使用方法。 ... [详细]
  • 本文将介绍如何利用Python爬虫技术抓取国内主流在线学习平台的数据,并以51CTO学院为例,进行详细的技术解析和实践操作。 ... [详细]
  • 本文探讨了如何利用HTML5和JavaScript在浏览器中进行本地文件的读取和写入操作,并介绍了获取本地文件路径的方法。HTML5提供了一系列API,使得这些操作变得更加简便和安全。 ... [详细]
  • 本文详细介绍了Java中实现异步调用的多种方式,包括线程创建、Future接口、CompletableFuture类以及Spring框架的@Async注解。通过代码示例和深入解析,帮助读者理解并掌握这些技术。 ... [详细]
  • 一个登陆界面
    预览截图html部分123456789101112用户登入1314邮箱名称邮箱为空15密码密码为空16登 ... [详细]
  • 本文深入探讨了 Delphi 中类对象成员的核心概念,包括 System 单元的基础知识、TObject 类的定义及其方法、TClass 的作用以及对象的消息处理机制。文章不仅解释了这些概念的基本原理,还提供了丰富的补充和专业解答,帮助读者全面理解 Delphi 的面向对象编程。 ... [详细]
  • 本文介绍了如何使用JFreeChart库创建一个美观且功能丰富的环形图。通过设置主题、字体和颜色等属性,可以生成符合特定需求的图表。 ... [详细]
  • 理解与应用:独热编码(One-Hot Encoding)
    本文详细介绍了独热编码(One-Hot Encoding)与哑变量编码(Dummy Encoding)两种方法,用于将分类变量转换为数值形式,以便于机器学习算法处理。文章不仅解释了这两种编码方式的基本原理,还探讨了它们在实际应用中的差异及选择依据。 ... [详细]
  • 本文详细介绍了JSP(Java Server Pages)的九大内置对象及其功能,探讨了JSP与Servlet之间的关系及差异,并提供了实际编码示例。此外,还讨论了网页开发中常见的编码转换问题以及JSP的两种页面跳转方式。 ... [详细]
  • 本文将指导如何在JFinal框架中快速搭建一个简易的登录系统,包括环境配置、数据库设计、项目结构规划及核心代码实现等环节。 ... [详细]
author-avatar
爱中华爱美丽
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有