热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

tensorflow笔记:卷积神经网络用于MNIST识别

导入相关包由于代码是在jupyternotebook中实现的,下面的‘%matplotlibinline’命令用于将图画在该页面上,不用jupyternotebook的话删掉改行代码即

导入相关包

由于代码是在jupyter notebook中实现的,下面的‘%matplotlib inline’命令用于将图画在该页面上,不用jupyter notebook的话删掉改行代码即可。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline # delete this line if don't use jupyter notebook

from tensorflow.examples.tutorials.mnist import input_data

导入数据

该部分用于下载和导入MNIST数据集,通过tensorflow内置的模块实现,其中‘input_data/’是用于存放MNIST数据的相对路径。
下面的X和Y_都不是具体的值,它们只是在tensorflow中声明的占位符(placeholder),可以在tensorflow运行某一计算时赋予相应的值,值的类型和规模等参数都可以在tf.placeholder()中声明,比如在下面的代码中,我们需要告诉tensorflow,模型的输入X的类型是float32,并且规模(shape)是[None,28,28,1]的一个张量(tensor),其中,None在下面的代码中表示数据一个batch的大小,28、28、1分别表示输入图片的宽、高和通道数。由于我们模型 的输出有10个类别,所以Y_的shape是[None,10]。

## load data(input_data/)
mnist = input_data.read_data_sets('input_data/',one_hot=True,reshape=False)
X = tf.placeholder(tf.float32,[None,28,28,1]) # input of the model
Y_ = tf.placeholder(tf.float32,[None,10]) # output of the model
Extracting input_data/train-images-idx3-ubyte.gz
Extracting input_data/train-labels-idx1-ubyte.gz
Extracting input_data/t10k-images-idx3-ubyte.gz
Extracting input_data/t10k-labels-idx1-ubyte.gz

参数初始化

该部分用于整个模型中参数的初始化,包括权重和偏置项,用到tensorflow中的Variable(),前两层是卷积层,以W2为例,四个维度[5,5,K,L]表示该层卷积核是5*5的,输入通道有K(=32)个,输出通道有L(=64)个

## Parameters initialization
K = 32
L = 64
M = 1024

W1 = tf.Variable(tf.truncated_normal([5,5,1,K],stddev=0.1))
B1 = tf.Variable(tf.constant(0.1,tf.float32,[K]))
W2 = tf.Variable(tf.truncated_normal([5,5,K,L],stddev=0.1))
B2 = tf.Variable(tf.constant(0.1,tf.float32,[L]))
W3 = tf.Variable(tf.truncated_normal([7*7*L,M],stddev=0.1))
B3 = tf.Variable(tf.constant(0.1,tf.float32,[M]))
W4 = tf.Variable(tf.truncated_normal([M,10],stddev=0.1))
B4 = tf.Variable(tf.constant(0.1,tf.float32,[10]))

模型构建

模型包括两个卷积层和一个全连接层,最后接一个softmax层用于多分类,并在全连接层后加了dropout,需要注意的是,在最后一个卷积层和全连接层前需要先将输出reshape一个行向量,否则会导致全连接层矩阵相乘时维度不一致。

## Model structure
keep_prob = tf.placeholder(tf.float32) # dropout: keep prob

conv1 = tf.nn.relu(tf.nn.conv2d(X,W1,strides=[1,1,1,1],padding='SAME')+B1)
pool1 = tf.nn.max_pool(conv1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
conv2 = tf.nn.relu(tf.nn.conv2d(pool1,W2,strides=[1,1,1,1],padding='SAME')+B2)
pool2 = tf.nn.max_pool(conv2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
pool2_flat = tf.reshape(pool2,[-1,7*7*L])
fc1 = tf.nn.relu(tf.matmul(pool2_flat,W3)+B3)
fc1_drop = tf.nn.dropout(fc1,keep_prob)
Ylogits = tf.matmul(fc1_drop,W4)+B4
Y = tf.nn.softmax(Ylogits)

损失函数

代码中采用的是交叉信息熵损失(cross entropy)

## Loss function: cross entropy 
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=Ylogits,labels=Y_)
cross_entropy = tf.reduce_mean(cross_entropy)*100.0

准确率

准确率等于预测正确的样本数目除以总测试样本数目,需要注意的是,第一行tf.equal()用于判断预测的类别和实际类别是否相同,相同返回True,不同返回False,是布尔型,因此在第二行代码计算准确率之前需要先转化成float型(通过tf.cast())函数实现。

## Accuracy
is_accuracy = tf.equal(tf.argmax(Y_,1),tf.argmax(Y,1))
accuracy = tf.reduce_mean(tf.cast(is_accuracy,tf.float32))

模型训练

该部分实现整个模型的训练和结果的保存,代码中应注意以下几点:
1. 每次迭代的学习率是变化的(指数衰减),因此学习率需要用lr = tf.placeholder(tf.float32)来声明。
2. 训练过程(包括参数初始化)都需要在Session()中执行,如:sess.run(init)表示运行参数初始化。
3. 训练相关数据(包括一个batch的训练样本、本次迭代的学习率、dropout中的keep_prob参数)和测试数据以字典形式通过feed_dict参数传给session。

## Model training
epoch = 1000
batch = 100

# some lists used for saving results
train_acc = []
train_loss = []
test_acc = []
test_loss = []

lr = tf.placeholder(tf.float32) # learning rate (variable)
optimizer = tf.train.AdamOptimizer(lr) # optimal method
train_step = optimizer.minimize(cross_entropy)

init = tf.global_variables_initializer() # initialize all variables

with tf.Session() as sess:
sess.run(init) # run initialization process
for i in xrange(epoch):
# learning rate
max_lr = 0.003
min_lr = 0.0001
decay_speed = 2000.0
learning_rate = min_lr+(max_lr-min_lr)*np.math.exp(-i/decay_speed)

# training step
batch_X,batch_Y = mnist.train.next_batch(batch) # load one batch training data
train_data = {X:batch_X, Y_:batch_Y,lr:learning_rate, keep_prob:0.5} # dictionary
sess.run(train_step,feed_dict=train_data) # train one step

# save results
acc,loss = sess.run([accuracy,cross_entropy],feed_dict=train_data)
train_acc.append(acc)
train_loss.append(loss)

test_data = {X:mnist.test.images,Y_:mnist.test.labels,keep_prob:1}
acc,loss = sess.run([accuracy,cross_entropy],feed_dict=test_data)
test_acc.append(acc)
test_loss.append(loss)

# print training process
if i%100 == 0:
print "epoch = %d, " %i, "test accuracy = %.4f," %test_acc[i], \
"test loss = %.6f" %test_loss[i], "learning rate = %.6f" %learning_rate

print "test accuracy = %.4f, " %test_acc[-1], "test loss = %.6f" %test_loss[-1]
epoch = 0,  test accuracy = 0.1135, test loss = 3775.333008 learning rate = 0.003000
epoch = 100, test accuracy = 0.9294, test loss = 22.953661 learning rate = 0.002859
epoch = 200, test accuracy = 0.9523, test loss = 15.348906 learning rate = 0.002724
epoch = 300, test accuracy = 0.9624, test loss = 11.892620 learning rate = 0.002596
epoch = 400, test accuracy = 0.9683, test loss = 9.448619 learning rate = 0.002474
epoch = 500, test accuracy = 0.9740, test loss = 7.751043 learning rate = 0.002359
epoch = 600, test accuracy = 0.9741, test loss = 7.173435 learning rate = 0.002248
epoch = 700, test accuracy = 0.9811, test loss = 5.837584 learning rate = 0.002144
epoch = 800, test accuracy = 0.9800, test loss = 5.772215 learning rate = 0.002044
epoch = 900, test accuracy = 0.9765, test loss = 7.619859 learning rate = 0.001949
test accuracy = 0.9833, test loss = 5.115160

训练过程可视化

包括训练过程中准确率和损失的变化情况

## Result visualization
# training accuracy and test accuracy
plt.figure()
plt.plot(train_acc,'r',label='train_acc')
plt.plot(test_acc,'b',label='test_acc')
plt.legend()
plt.axis([0,epoch,0,1])
plt.show()

# training loss and test loss
plt.figure()
plt.plot(train_loss,'r',label='train_loss')
plt.plot(test_loss,'b',label='test_loss')
plt.legend()
plt.axis([0,epoch,0,100])
plt.show()


accuracy
loss

参考文献
  1. Tensorflow and deep learning - without a PhD(需要翻墙)
  2. Tensorflow and deep learning - without a PhD视频+PPT
  3. Tensorflow and deep learning - without a PhD代码
  4. Tensorflow and deep learning - without a PhD翻译


推荐阅读
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 普通树(每个节点可以有任意数量的子节点)级序遍历 ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • 中国学者实现 CNN 全程可视化,详尽展示每次卷积、ReLU 和池化过程 ... [详细]
  • 深入解析经典卷积神经网络及其实现代码
    深入解析经典卷积神经网络及其实现代码 ... [详细]
  • 本文介绍了 Python 中的基本数据类型,包括不可变数据类型(数字、字符串、元组)和可变数据类型(列表、字典、集合),并详细解释了每种数据类型的使用方法和常见操作。 ... [详细]
  • 事件是程序各部分之间的一种通信方式,也是异步编程的一种实现形式。本文将详细介绍EventTarget接口及其相关方法,以及如何使用监听函数处理事件。 ... [详细]
  • WinMain 函数详解及示例
    本文详细介绍了 WinMain 函数的参数及其用途,并提供了一个具体的示例代码来解析 WinMain 函数的实现。 ... [详细]
  • 单片微机原理P3:80C51外部拓展系统
      外部拓展其实是个相对来说很好玩的章节,可以真正开始用单片机写程序了,比较重要的是外部存储器拓展,81C55拓展,矩阵键盘,动态显示,DAC和ADC。0.IO接口电路概念与存 ... [详细]
  • PTArchiver工作原理详解与应用分析
    PTArchiver工作原理及其应用分析本文详细解析了PTArchiver的工作机制,探讨了其在数据归档和管理中的应用。PTArchiver通过高效的压缩算法和灵活的存储策略,实现了对大规模数据的高效管理和长期保存。文章还介绍了其在企业级数据备份、历史数据迁移等场景中的实际应用案例,为用户提供了实用的操作建议和技术支持。 ... [详细]
  • 在探讨如何在Android的TextView中实现多彩文字与多样化字体效果时,本文提供了一种不依赖HTML技术的解决方案。通过使用SpannableString和相关的Span类,开发者可以轻松地为文本添加丰富的样式和颜色,从而提升用户体验。文章详细介绍了实现过程中的关键步骤和技术细节,帮助开发者快速掌握这一技巧。 ... [详细]
  • 能够感知你情绪状态的智能机器人即将问世 | 科技前沿观察
    本周科技前沿报道了多项重要进展,包括美国多所高校在机器人技术和自动驾驶领域的最新研究成果,以及硅谷大型企业在智能硬件和深度学习技术上的突破性进展。特别值得一提的是,一款能够感知用户情绪状态的智能机器人即将问世,为未来的人机交互带来了全新的可能性。 ... [详细]
  • 本文将深入探讨生成对抗网络(GAN)在计算机视觉领域的应用。作为该领域的经典模型,GAN通过生成器和判别器的对抗训练,能够高效地生成高质量的图像。本文不仅回顾了GAN的基本原理,还将介绍一些最新的进展和技术优化方法,帮助读者全面掌握这一重要工具。 ... [详细]
  • 浅层神经网络解析:本文详细探讨了两层神经网络(即一个输入层、一个隐藏层和一个输出层)的结构与工作原理。通过吴恩达教授的课程,读者将深入了解浅层神经网络的基本概念、参数初始化方法以及前向传播和反向传播的具体实现步骤。此外,文章还介绍了如何利用这些基础知识解决实际问题,并提供了丰富的实例和代码示例。 ... [详细]
  • 理工科男女不容错过的神奇资源网站
    十一长假即将结束,你的假期学习计划进展如何?无论你是在家中、思念家乡,还是身处异国他乡,理工科学生都不容错过一些神奇的资源网站。这些网站提供了丰富的学术资料、实验数据和技术文档,能够帮助你在假期中高效学习和提升专业技能。 ... [详细]
author-avatar
I-1ove-Y0u
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有