热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Tensorflow学习之MNIST数字集识别

Tensorflow学习日记之MNIST数字集识别Tensorflow学习之MNIST数字集识别(底部附上我的全部代码)前一阵突然被Tensorflow吸引了,正如Google发布

Tensorflow学习日记之MNIST数字集识别

Tensorflow学习之MNIST数字集识别

(底部附上我的全部代码)

前一阵突然被Tensorflow吸引了,正如Google发布人工智能系统TensorFlow文档中所说:
你正在阅读的项目可能会比 Android 系统更加深远地影响着世界!
没赶上Android开发的快车,当然不会放过Tensorflow这艘大轮船。
言归正传,不管学什么,都是有步骤地。

  1. 肯定是 What is it? 解决方法各种百度google。

  2. 看一遍Tensorflow文档,毕竟是官方的东西,看一遍没坏处。

  3. 在经历第二步之后会发现在数学算法方面有所缺失,理应补一下这个缺口,推荐回顾一下线性代数和概率论。

  4. 搭环境 这个之前的博客讲过。

  5. MNIST 数据集测试就是机器学习和深度学习当中的"Hello World"工程。跑一下
    Tensorflow之MNIST 数据集测试
    (底部附上我的全部代码)
    首先了解一下什么是MNIST
    MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:
    完成这个测试就好比编程入门有Hello World

参考《Tensorflow官方文档》
首先从官方网站下载数据集
数据集下载网址

下载下来的数据集被分成两部分:60000行的训练数据集( mnist.train )和10000行的测试数据集( mnist.test )。
这样的切分很重要,在机器学习模型设计时必须有一个单独的测试数据集不用于训练而是用来评估这个模
型的性能,从而更加容易把设计的模型推广到其他数据集上(泛化)。
数据集里面每一张图片包含28像素X28像素,每张照片对应一个标签。

在MNIST训练数据集中, mnist.train.images 是一个形状为 [60000, 784] 的张量,第一个维度数字用来
索引图片,第二个维度数字用来索引每张图片中的像素点。在此张量里的每一个元素,都表示某张图片里的某个
像素的强度值,值介于0和1之间。

代码部分-------------------------------------------------------------
(这个代码是使用了Tensorflow最简单的一个模型,识别率91%左右,文章末尾附上另一个代码,使用了多层卷积网络,识别率97%左右)


#  import cv2
#  Tensorflow已经包含了mnist案例的数据
# 利用Tensorflow 对 MNIST 进行读取和格式转换
from tensorflow.examples.tutorials.mnist import input_data

#将下载的MNIST数据放在MNIST_data文件下
# input_data.read_data_sets()函数可以自动检测指定目录下是否存在MNIST数据,如果存在,就不会下载了
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


# 笔记:Tensorflow依赖于一个高效的C++后端来进行计算。与后端的这个连接叫做session。一般而言,使用TensorFlow程序的流程是先创建一个图,然后在session中启动它。
# 导入Tensorflow
import tensorflow as tf

# 运行TensorFlow的InteractiveSession
sess = tf.InteractiveSession


# # 通过opencv打印一张照片和对应标签测试一下
# print(mnist.train.images.shape, mnist.train.labels.shape)
# image = mnist.train.images[4,:]
# #将图像数据还原成28*28的分辨率
# image = image.reshape(28,28)
# #打印对应的标签
# print(mnist.train.labels[4])
# cv2.imshow('uu', image)
# cv2.waitKey(0)
# 创建一个占位符,通过操作符号变量来描述这些可交互的操作单元
# None 代表图片数量未知,784=28*28


# 构建Softmax 回归模型
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
# x = tf.placeholder("float", [None, 784])

# 通过Variable的方式加入权重值和偏置量,Variable代表一个可修改的张量
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))


# sess.run(tf.global_variables_initializer())

# 实现我们的模型
y = tf.nn.softmax(tf.matmul(x, W)+b)

#为训练过程指定最小化误差用的损失函数
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

# -----------------------------------------------------
# 构建CNN模型,对图片运算来说,CNN模型算比较优秀的模型
# 训练模型
# 用最速下降法让交叉熵下降,步长为0.01
# 作用:往计算图上添加一个新操作,其中包括计算梯度,计算每个参数的步长变化,并且计算出新的参数值
# 往计算图上添加一个新操作,其中包括计算梯度,计算每个参数的步长变化,并且计算出新的参数值
# 每一步迭代,我们都会加载50个训练样本,然后执行一次train_step
# 并通过feed_dict将x 和 y_张量占位符用训练训练数据替代。
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# 变量需要通过seesion初始化后,才能在session中使用。
# init = tf.initialize_all_variables()

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

for i in range(1000):
    batch = mnist.train.next_batch(50)
    sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})
    # train_step.run(feed_dict={x: batch[0], y_: batch[1]})

#------------------------------------------------------

# 评估模型
# 用 tf.equal 来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# 返回一个布尔数组

# 将布尔值转换为浮点数来代表对、错,然后取平均值。
# 例如:[True, False, True, True]变为[1,0,1,1],计算出平均值为0.75
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 计算出在测试数据上的准确率
# ------------------------------------------------------
print("测试数据集正确率:")
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

Tensorflow学习之MNIST数字集识别
这是一个分界线-------------------------------------------------------------
下面是多层卷积网络版的

# coding:utf-8

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("MNIST_data",one_hot=True)


input = tf.placeholder(tf.float32,[None,784])
input_image = tf.reshape(input,[-1,28,28,1])

y = tf.placeholder(tf.float32,[None,10])

# input 代表输入,filter 代表卷积核
def conv2d(input,filter):
    return tf.nn.conv2d(input,filter,strides=[1,1,1,1],padding='SAME')
# 池化层
def max_pool(input):
    return tf.nn.max_pool(input,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

# 初始化卷积核或者是权重数组的值
def weight_variable(shape):
    initial = tf.truncated_normal(shape,stddev=0.1)
    return tf.Variable(initial)

# 初始化bias的值
def bias_variable(shape):
    return tf.Variable(tf.zeros(shape))

#[filter_height, filter_width, in_channels, out_channels]
#定义了卷积核
filter = [3,3,1,32]

filter_conv1 = weight_variable(filter)
b_conv1 = bias_variable([32])
# 创建卷积层,进行卷积操作,并通过Relu**,然后池化
h_conv1 = tf.nn.relu(conv2d(input_image,filter_conv1)+b_conv1)
h_pool1 = max_pool(h_conv1)

h_flat = tf.reshape(h_pool1,[-1,14*14*32])

W_fc1 = weight_variable([14*14*32,768])
b_fc1 = bias_variable([768])
h_fc1 = tf.matmul(h_flat,W_fc1) + b_fc1

W_fc2 = weight_variable([768,10])
b_fc2 = bias_variable([10])

y_hat = tf.matmul(h_fc1,W_fc2) + b_fc2



cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=y_hat ))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(y_hat,1),tf.argmax(y,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(10000):

        batch_x,batch_y = mnist.train.next_batch(50)

        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={input:batch_x,y:batch_y})
            print("step %d,train accuracy %g " %(i,train_accuracy))

        train_step.run(feed_dict={input:batch_x,y:batch_y})

        # sess.run(train_step,feed_dict={x:batch_x,y:batch_y})

    print("test accuracy %g " % accuracy.eval(feed_dict={input:mnist.test.images,y:mnist.test.labels}))

运行结果:
Tensorflow学习之MNIST数字集识别
Tensorflow学习之MNIST数字集识别


推荐阅读
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 2018年人工智能大数据的爆发,学Java还是Python?
    本文介绍了2018年人工智能大数据的爆发以及学习Java和Python的相关知识。在人工智能和大数据时代,Java和Python这两门编程语言都很优秀且火爆。选择学习哪门语言要根据个人兴趣爱好来决定。Python是一门拥有简洁语法的高级编程语言,容易上手。其特色之一是强制使用空白符作为语句缩进,使得新手可以快速上手。目前,Python在人工智能领域有着广泛的应用。如果对Java、Python或大数据感兴趣,欢迎加入qq群458345782。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 本文介绍了Python语言程序设计中文件和数据格式化的操作,包括使用np.savetext保存文本文件,对文本文件和二进制文件进行统一的操作步骤,以及使用Numpy模块进行数据可视化编程的指南。同时还提供了一些关于Python的测试题。 ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 基于layUI的图片上传前预览功能的2种实现方式
    本文介绍了基于layUI的图片上传前预览功能的两种实现方式:一种是使用blob+FileReader,另一种是使用layUI自带的参数。通过选择文件后点击文件名,在页面中间弹窗内预览图片。其中,layUI自带的参数实现了图片预览功能。该功能依赖于layUI的上传模块,并使用了blob和FileReader来读取本地文件并获取图像的base64编码。点击文件名时会执行See()函数。摘要长度为169字。 ... [详细]
  • 本文介绍了Java工具类库Hutool,该工具包封装了对文件、流、加密解密、转码、正则、线程、XML等JDK方法的封装,并提供了各种Util工具类。同时,还介绍了Hutool的组件,包括动态代理、布隆过滤、缓存、定时任务等功能。该工具包可以简化Java代码,提高开发效率。 ... [详细]
  • 解决Cydia数据库错误:could not open file /var/lib/dpkg/status 的方法
    本文介绍了解决iOS系统中Cydia数据库错误的方法。通过使用苹果电脑上的Impactor工具和NewTerm软件,以及ifunbox工具和终端命令,可以解决该问题。具体步骤包括下载所需工具、连接手机到电脑、安装NewTerm、下载ifunbox并注册Dropbox账号、下载并解压lib.zip文件、将lib文件夹拖入Books文件夹中,并将lib文件夹拷贝到/var/目录下。以上方法适用于已经越狱且出现Cydia数据库错误的iPhone手机。 ... [详细]
  • 本文介绍了高校天文共享平台的开发过程中的思考和规划。该平台旨在为高校学生提供天象预报、科普知识、观测活动、图片分享等功能。文章分析了项目的技术栈选择、网站前端布局、业务流程、数据库结构等方面,并总结了项目存在的问题,如前后端未分离、代码混乱等。作者表示希望通过记录和规划,能够理清思路,进一步完善该平台。 ... [详细]
  • 分享2款网站程序源码/主题等后门检测工具
    本文介绍了2款用于检测网站程序源码和主题中是否存在后门的工具,分别是WebShellkiller和D盾_Web查杀。WebShellkiller是一款支持webshell和暗链扫描的工具,采用多重检测引擎和智能检测模型,能够更精准地检测出已知和未知的后门文件。D盾_Web查杀则使用自行研发的代码分析引擎,能够分析更为隐藏的WebShell后门行为。 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 前言:拿到一个案例,去分析:它该是做分类还是做回归,哪部分该做分类,哪部分该做回归,哪部分该做优化,它们的目标值分别是什么。再挑影响因素,哪些和分类有关的影响因素,哪些和回归有关的 ... [详细]
author-avatar
c33454059
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有