热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于TensorFlow一次简单的RNN实现

由于发现网上大部分tensorflow的RNN教程都过于简答或者复杂,所以尝试一下从简单到深的在TF中写出RNN代码,这篇文章主要参考打是TensorFlow人工智能引擎入门教程之

由于发现网上大部分tensorflow的RNN教程都过于简答或者复杂,所以尝试一下从简单到深的在TF中写出RNN代码,这篇文章主要参考打是TensorFlow人工智能引擎入门教程之九 RNN/LSTM循环神经网络长短期记忆网络使用中使用的代码,但是由于代码版本较为古老,所以TF报错,参考解读tensorflow之rnn 对代码进行修改和实现,第一版实现来一个最简单打RNN模型。

RNN原理见参考资料

由于本次实验在jupyter中完成的,所以部分图片和输出不好更如知乎中,好一点的版本见:RNNStudy/simpleRNN.ipynb

记录步骤如下:

引入相关包

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import tensorflow as tf
#from tensorflow.nn import rnn, rnn_cell
import numpy as np

先来看输入数据,本次用打输入数据是MNIST打数据可以看到如下

print '输入数据:'
print mnist.train.images
print '输入数据打shape:'
print mnist.train.images.shape

可以看到其中784是图据28×28像素打图像,将其转化成图像观察一下如下图所示,

%pylab inline
%matplotlib inline
import pylab
im = mnist.train.images[1]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()

如果我们要用RNN来训练这个网络打话,则应该选择n_input = 28 ,n_steps = 28结构

a= np.asarray(range(20))
b = a.reshape(-1,2,2)
print '生成一列数据'
print a
print 'reshape函数的效果'
print b
c = np.transpose(b,[1,0,2])
d = c.reshape(-1,2)
print '--------c-----------'
print c
print '--------d-----------'
print d

定义一些模型打参数

''' To classify images using a reccurent neural network, we consider every image row as a sequence of pixels. Because MNIST image shape is 28*28px, we will then handle 28 sequences of 28 steps for every sample. '''
# Parameters
learning_rate = 0.001
training_iters = 100000
batch_size = 128
display_step = 100
# Network Parameters
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)

构建RNN打函数可以参考 :Neural Network开始我们先创建两个占位符placeholder,基本使用可以参考官方文档:基本使用 – TensorFlow 官方文档中文版

# tf Graph input
x = tf.placeholder("float32", [None, n_steps, n_input])
# Tensorflow LSTM cell requires 2x n_hidden length (state & cell)
y = tf.placeholder("float32", [None, n_classes])
# Define weights
weights = {
'hidden': tf.Variable(tf.random_normal([n_input, n_hidden])), # Hidden layer weights
'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
}
biases = {
'hidden': tf.Variable(tf.random_normal([n_hidden])),
'out': tf.Variable(tf.random_normal([n_classes]))
}

首先创建一个CELL这里需要打一个参数是隐藏单元打个数n_hidden,在创建完成后对其进行初始化

这里会造成一个BUG,后面说道

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=0.0, state_is_tuple=True)
_state = lstm_cell.zero_state(batch_size,tf.float32)

为了使得 原始数据打输入和模型匹配,我们对数据进行一系列变换,变换打结果如下,数据变化可以参考上面打小实验

a1 = tf.transpose(x, [1, 0, 2])
a2 = tf.reshape(a1, [-1, n_input])
a3 = tf.matmul(a2, weights['hidden']) + biases['hidden']
a4 = tf.split(0, n_steps, a3)
print '-----------------------'
print 'a1:'
print a1
print '-----------------------'
print 'a2:'
print a2
print '-----------------------'
print 'a3:'
print a3
print '-----------------------'
print 'a4:'
print a4

为了使得 原始数据打输入和模型匹配,我们对数据进行一系列变换,变换打结果如下这里主要是为了匹配tf.nn.rnn遮盖函数,函数可参考官方文档:Neural Network或者前面解读RNN那篇解读tensorflow之rnn

outputs, states = tf.nn.rnn(lstm_cell, a4, initial_state = _state)
print 'outputs[-1]'
print outputs[-1]
print '-----------------------'
a5 = tf.matmul(outputs[-1], weights['out']) + biases['out']
print 'a5:'
print a5
print '-----------------------'

定义cost,使用梯度下降求最优

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(a5, y))
#AdamOptimizer
#optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Adam Optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost) # Adam Optimizer
correct_pred = tf.equal(tf.argmax(a5,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
init = tf.initialize_all_variables()

进行模型训练,这里需要注意,由于我使用打是Jupyter,采取来交互式环境,所以在普通py中sess = tf.InteractiveSession() 这一句不一定正确,需要自己修改为tf.Session()

sess = tf.InteractiveSession()
sess.run(init)
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_xs = batch_xs.reshape((batch_size, n_steps, n_input))
# Fit training using batch data
sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys})
if step % display_step == 0:
# Calculate batch accuracy
acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys,})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_xs, y: batch_ys})
print "Iter " + str(step*batch_size) + ", Minibatch Loss= " + "{:.6f}".format(loss) + ", Training Accuracy= " + "{:.5f}".format(acc)
step += 1
print "Optimization Finished!"

测试模型准确率

test_len = batch_size
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
# Evaluate model
correct_pred = tf.equal(tf.argmax(a5,1), tf.argmax(y,1))
print "Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label})

在这里测试准确率有一个BUG,test_len 必须和batch_size相等,这是由于前面在初始化模型打时候选择batch_size作为参数,导致a5输出一直是一个batch_size行打矩阵,若est_len 和batch_size不想等,accuracy计算会报错。 由于暂时没想到简单打解决方法,所以待下次处理。

python参考资料:

解读tensorflow之rnn

RNN以及LSTM的介绍和公式梳理

TensorFlow人工智能引擎入门教程之九 RNN/LSTM循环神经网络长短期记忆网络使用

LSTM模型理论总结(产生、发展和性能等)

解析Tensorflow官方PTB模型的demo


推荐阅读
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • 本文由编程笔记#小编为大家整理,主要介绍了logistic回归(线性和非线性)相关的知识,包括线性logistic回归的代码和数据集的分布情况。希望对你有一定的参考价值。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 本文介绍了使用cacti监控mssql 2005运行资源情况的操作步骤,包括安装必要的工具和驱动,测试mssql的连接,配置监控脚本等。通过php连接mssql来获取SQL 2005性能计算器的值,实现对mssql的监控。详细的操作步骤和代码请参考附件。 ... [详细]
  • 本文介绍了一个适用于PHP应用快速接入TRX和TRC20数字资产的开发包,该开发包支持使用自有Tron区块链节点的应用场景,也支持基于Tron官方公共API服务的轻量级部署场景。提供的功能包括生成地址、验证地址、查询余额、交易转账、查询最新区块和查询交易信息等。详细信息可参考tron-php的Github地址:https://github.com/Fenguoz/tron-php。 ... [详细]
  • 本文介绍了在rhel5.5操作系统下搭建网关+LAMP+postfix+dhcp的步骤和配置方法。通过配置dhcp自动分配ip、实现外网访问公司网站、内网收发邮件、内网上网以及SNAT转换等功能。详细介绍了安装dhcp和配置相关文件的步骤,并提供了相关的命令和配置示例。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 目录实现效果:实现环境实现方法一:基本思路主要代码JavaScript代码总结方法二主要代码总结方法三基本思路主要代码JavaScriptHTML总结实 ... [详细]
  • 使用在线工具jsonschema2pojo根据json生成java对象
    本文介绍了使用在线工具jsonschema2pojo根据json生成java对象的方法。通过该工具,用户只需将json字符串复制到输入框中,即可自动将其转换成java对象。该工具还能解析列表式的json数据,并将嵌套在内层的对象也解析出来。本文以请求github的api为例,展示了使用该工具的步骤和效果。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
author-avatar
李长倩63399
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有