热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

cnn图像识别python代码_pythontensorflow基于cnn实现手写数字识别

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下#-*-coding:utf-8-*-importtensorflowastffromte

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

# 加载数据集

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 以交互式方式启动session

# 如果不使用交互式session,则在启动session前必须

# 构建整个计算图,才能启动该计算图

sess = tf.InteractiveSession()

"""构建计算图"""

# 通过占位符来为输入图像和目标输出类别创建节点

# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误

x = tf.placeholder("float", shape=[None, 784]) # 原始输入

y_ = tf.placeholder("float", shape=[None, 10]) # 目标值

# 为了不在建立模型的时候反复做初始化操作,

# 我们定义两个函数用于初始化

def weight_variable(shape):

# 截尾正态分布,stddev是正态分布的标准偏差

initial = tf.truncated_normal(shape=shape, stddev=0.1)

return tf.Variable(initial)

def bias_variable(shape):

initial = tf.constant(0.1, shape=shape)

return tf.Variable(initial)

# 卷积核池化,步长为1,0边距

def conv2d(x, W):

return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):

return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],

strides=[1, 2, 2, 1], padding='SAME')

"""第一层卷积"""

# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积

# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数

W_conv1 = weight_variable([5, 5, 1, 32])

# 每一个输出通道都有一个偏置量

b_conv1 = bias_variable([32])

# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高

# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)

x_image = tf.reshape(x, [-1, 28, 28, 1])

# 第一层的卷积结果,使用Relu作为激活函数

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))

# 第一层卷积后的池化结果

h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""

W_conv2 = weight_variable([5, 5, 32, 64])

b_conv2 = bias_variable([64])

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""

# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层

W_fc1 = weight_variable([7*7*64, 1024])

b_fc1 = bias_variable([1024])

# 将最后的池化层输出张量reshape成一维向量

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])

# 全连接层的输出

h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""

# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率

# 在训练的过程中启用dropout,在测试过程中关闭dropout

keep_prob = tf.placeholder("float")

h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""

W_fc2 = weight_variable([1024, 10])

b_fc2 = bias_variable([10])

# 模型预测输出

y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# 交叉熵损失

cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

# 模型训练,使用AdamOptimizer来做梯度最速下降

train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 正确预测,得到True或False的List

correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))

# 将布尔值转化成浮点数,取平均值作为精确度

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 在session中先初始化变量才能在session中调用

sess.run(tf.global_variables_initializer())

# 迭代优化模型

for i in range(2000):

# 每次取50个样本进行训练

batch = mnist.train.next_batch(50)

if i%100 == 0:

train_accuracy = accuracy.eval(feed_dict={

x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout

print("step %d, training accuracy %g" % (i, train_accuracy))

train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})

print("test accuracy %g" % accuracy.eval(feed_dict={

x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

做了2000次迭代,在测试集上的识别精度能够到0.9772……

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

本文标题: python tensorflow基于cnn实现手写数字识别

本文地址: http://www.cppcns.com/jiaoben/python/216273.html



推荐阅读
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • Python爬虫中使用正则表达式的方法和注意事项
    本文介绍了在Python爬虫中使用正则表达式的方法和注意事项。首先解释了爬虫的四个主要步骤,并强调了正则表达式在数据处理中的重要性。然后详细介绍了正则表达式的概念和用法,包括检索、替换和过滤文本的功能。同时提到了re模块是Python内置的用于处理正则表达式的模块,并给出了使用正则表达式时需要注意的特殊字符转义和原始字符串的用法。通过本文的学习,读者可以掌握在Python爬虫中使用正则表达式的技巧和方法。 ... [详细]
  • 延迟注入工具(python)的SQL脚本
    本文介绍了一个延迟注入工具(python)的SQL脚本,包括使用urllib2、time、socket、threading、requests等模块实现延迟注入的方法。该工具可以通过构造特定的URL来进行注入测试,并通过延迟时间来判断注入是否成功。 ... [详细]
  • 树莓派语音控制的配置方法和步骤
    本文介绍了在树莓派上实现语音控制的配置方法和步骤。首先感谢博主Eoman的帮助,文章参考了他的内容。树莓派的配置需要通过sudo raspi-config进行,然后使用Eoman的控制方法,即安装wiringPi库并编写控制引脚的脚本。具体的安装步骤和脚本编写方法在文章中详细介绍。 ... [详细]
  • 本文介绍了使用Python解析C语言结构体的方法,包括定义基本类型和结构体类型的字典,并提供了一个示例代码,展示了如何解析C语言结构体。 ... [详细]
  • 本文介绍了使用Python编写购物程序的实现步骤和代码示例。程序启动后,用户需要输入工资,并打印商品列表。用户可以根据商品编号选择购买商品,程序会检测余额是否充足,如果充足则直接扣款,否则提醒用户。用户可以随时退出程序,在退出时打印已购买商品的数量和余额。附带了完整的代码示例。 ... [详细]
  • 本文介绍了使用Spark实现低配版高斯朴素贝叶斯模型的原因和原理。随着数据量的增大,单机上运行高斯朴素贝叶斯模型会变得很慢,因此考虑使用Spark来加速运行。然而,Spark的MLlib并没有实现高斯朴素贝叶斯模型,因此需要自己动手实现。文章还介绍了朴素贝叶斯的原理和公式,并对具有多个特征和类别的模型进行了讨论。最后,作者总结了实现低配版高斯朴素贝叶斯模型的步骤。 ... [详细]
author-avatar
迦迦奥特曼_897
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有