热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于TensorFlow一次简单的RNN实现

由于发现网上大部分tensorflow的RNN教程都过于简答或者复杂,所以尝试一下从简单到深的在TF中写出RNN代码,这篇文章主要参考打是TensorFlow人工智能引擎入门教程之

由于发现网上大部分tensorflow的RNN教程都过于简答或者复杂,所以尝试一下从简单到深的在TF中写出RNN代码,这篇文章主要参考打是TensorFlow人工智能引擎入门教程之九 RNN/LSTM循环神经网络长短期记忆网络使用中使用的代码,但是由于代码版本较为古老,所以TF报错,参考解读tensorflow之rnn 对代码进行修改和实现,第一版实现来一个最简单打RNN模型。

RNN原理见参考资料

由于本次实验在jupyter中完成的,所以部分图片和输出不好更如知乎中,好一点的版本见:RNNStudy/simpleRNN.ipynb

记录步骤如下:

引入相关包

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import tensorflow as tf
#from tensorflow.nn import rnn, rnn_cell
import numpy as np

先来看输入数据,本次用打输入数据是MNIST打数据可以看到如下

print '输入数据:'
print mnist.train.images
print '输入数据打shape:'
print mnist.train.images.shape

可以看到其中784是图据28×28像素打图像,将其转化成图像观察一下如下图所示,

%pylab inline
%matplotlib inline
import pylab
im = mnist.train.images[1]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()

如果我们要用RNN来训练这个网络打话,则应该选择n_input = 28 ,n_steps = 28结构

a= np.asarray(range(20))
b = a.reshape(-1,2,2)
print '生成一列数据'
print a
print 'reshape函数的效果'
print b
c = np.transpose(b,[1,0,2])
d = c.reshape(-1,2)
print '--------c-----------'
print c
print '--------d-----------'
print d

定义一些模型打参数

''' To classify images using a reccurent neural network, we consider every image row as a sequence of pixels. Because MNIST image shape is 28*28px, we will then handle 28 sequences of 28 steps for every sample. '''
# Parameters
learning_rate = 0.001
training_iters = 100000
batch_size = 128
display_step = 100
# Network Parameters
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)

构建RNN打函数可以参考 :Neural Network开始我们先创建两个占位符placeholder,基本使用可以参考官方文档:基本使用 – TensorFlow 官方文档中文版

# tf Graph input
x = tf.placeholder("float32", [None, n_steps, n_input])
# Tensorflow LSTM cell requires 2x n_hidden length (state & cell)
y = tf.placeholder("float32", [None, n_classes])
# Define weights
weights = {
'hidden': tf.Variable(tf.random_normal([n_input, n_hidden])), # Hidden layer weights
'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
}
biases = {
'hidden': tf.Variable(tf.random_normal([n_hidden])),
'out': tf.Variable(tf.random_normal([n_classes]))
}

首先创建一个CELL这里需要打一个参数是隐藏单元打个数n_hidden,在创建完成后对其进行初始化

这里会造成一个BUG,后面说道

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=0.0, state_is_tuple=True)
_state = lstm_cell.zero_state(batch_size,tf.float32)

为了使得 原始数据打输入和模型匹配,我们对数据进行一系列变换,变换打结果如下,数据变化可以参考上面打小实验

a1 = tf.transpose(x, [1, 0, 2])
a2 = tf.reshape(a1, [-1, n_input])
a3 = tf.matmul(a2, weights['hidden']) + biases['hidden']
a4 = tf.split(0, n_steps, a3)
print '-----------------------'
print 'a1:'
print a1
print '-----------------------'
print 'a2:'
print a2
print '-----------------------'
print 'a3:'
print a3
print '-----------------------'
print 'a4:'
print a4

为了使得 原始数据打输入和模型匹配,我们对数据进行一系列变换,变换打结果如下这里主要是为了匹配tf.nn.rnn遮盖函数,函数可参考官方文档:Neural Network或者前面解读RNN那篇解读tensorflow之rnn

outputs, states = tf.nn.rnn(lstm_cell, a4, initial_state = _state)
print 'outputs[-1]'
print outputs[-1]
print '-----------------------'
a5 = tf.matmul(outputs[-1], weights['out']) + biases['out']
print 'a5:'
print a5
print '-----------------------'

定义cost,使用梯度下降求最优

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(a5, y))
#AdamOptimizer
#optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Adam Optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost) # Adam Optimizer
correct_pred = tf.equal(tf.argmax(a5,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
init = tf.initialize_all_variables()

进行模型训练,这里需要注意,由于我使用打是Jupyter,采取来交互式环境,所以在普通py中sess = tf.InteractiveSession() 这一句不一定正确,需要自己修改为tf.Session()

sess = tf.InteractiveSession()
sess.run(init)
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_xs = batch_xs.reshape((batch_size, n_steps, n_input))
# Fit training using batch data
sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys})
if step % display_step == 0:
# Calculate batch accuracy
acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys,})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_xs, y: batch_ys})
print "Iter " + str(step*batch_size) + ", Minibatch Loss= " + "{:.6f}".format(loss) + ", Training Accuracy= " + "{:.5f}".format(acc)
step += 1
print "Optimization Finished!"

测试模型准确率

test_len = batch_size
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
# Evaluate model
correct_pred = tf.equal(tf.argmax(a5,1), tf.argmax(y,1))
print "Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label})

在这里测试准确率有一个BUG,test_len 必须和batch_size相等,这是由于前面在初始化模型打时候选择batch_size作为参数,导致a5输出一直是一个batch_size行打矩阵,若est_len 和batch_size不想等,accuracy计算会报错。 由于暂时没想到简单打解决方法,所以待下次处理。

python参考资料:

解读tensorflow之rnn

RNN以及LSTM的介绍和公式梳理

TensorFlow人工智能引擎入门教程之九 RNN/LSTM循环神经网络长短期记忆网络使用

LSTM模型理论总结(产生、发展和性能等)

解析Tensorflow官方PTB模型的demo


推荐阅读
  • MySQL Decimal 类型的最大值解析及其在数据处理中的应用艺术
    在关系型数据库中,表的设计与SQL语句的编写对性能的影响至关重要,甚至可占到90%以上。本文将重点探讨MySQL中Decimal类型的最大值及其在数据处理中的应用技巧,通过实例分析和优化建议,帮助读者深入理解并掌握这一重要知识点。 ... [详细]
  • Ihavetwomethodsofgeneratingmdistinctrandomnumbersintherange[0..n-1]我有两种方法在范围[0.n-1]中生 ... [详细]
  • POJ 2482 星空中的星星:利用线段树与扫描线算法解决
    在《POJ 2482 星空中的星星》问题中,通过运用线段树和扫描线算法,可以高效地解决星星在窗口内的计数问题。该方法不仅能够快速处理大规模数据,还能确保时间复杂度的最优性,适用于各种复杂的星空模拟场景。 ... [详细]
  • 兆芯X86 CPU架构的演进与现状(国产CPU系列)
    本文详细介绍了兆芯X86 CPU架构的发展历程,从公司成立背景到关键技术授权,再到具体芯片架构的演进,全面解析了兆芯在国产CPU领域的贡献与挑战。 ... [详细]
  • 本文详细介绍了如何使用Python的多进程技术来高效地分块读取超大文件,并将其输出为多个文件。通过这种方式,可以显著提高读取速度和处理效率。 ... [详细]
  • 本文节选自《NLTK基础教程——用NLTK和Python库构建机器学习应用》一书的第1章第1.2节,作者Nitin Hardeniya。本文将带领读者快速了解Python的基础知识,为后续的机器学习应用打下坚实的基础。 ... [详细]
  • JUC(三):深入解析AQS
    本文详细介绍了Java并发工具包中的核心类AQS(AbstractQueuedSynchronizer),包括其基本概念、数据结构、源码分析及核心方法的实现。 ... [详细]
  • 本文介绍如何使用OpenCV和线性支持向量机(SVM)模型来开发一个简单的人脸识别系统,特别关注在只有一个用户数据集时的处理方法。 ... [详细]
  • 本文详细介绍了在 Ubuntu 系统上搭建 Hadoop 集群时遇到的 SSH 密钥认证问题及其解决方案。通过本文,读者可以了解如何在多台虚拟机之间实现无密码 SSH 登录,从而顺利启动 Hadoop 集群。 ... [详细]
  • 在多线程并发环境中,普通变量的操作往往是线程不安全的。本文通过一个简单的例子,展示了如何使用 AtomicInteger 类及其核心的 CAS 无锁算法来保证线程安全。 ... [详细]
  • 本文详细介绍了 PHP 中对象的生命周期、内存管理和魔术方法的使用,包括对象的自动销毁、析构函数的作用以及各种魔术方法的具体应用场景。 ... [详细]
  • 开机自启动的几种方式
    0x01快速自启动目录快速启动目录自启动方式源于Windows中的一个目录,这个目录一般叫启动或者Startup。位于该目录下的PE文件会在开机后进行自启动 ... [详细]
  • DVWA学习笔记系列:深入理解CSRF攻击机制
    DVWA学习笔记系列:深入理解CSRF攻击机制 ... [详细]
  • 在Linux系统中,网络配置是至关重要的任务之一。本文详细解析了Firewalld和Netfilter机制,并探讨了iptables的应用。通过使用`ip addr show`命令来查看网卡IP地址(需要安装`iproute`包),当网卡未分配IP地址或处于关闭状态时,可以通过`ip link set`命令进行配置和激活。此外,文章还介绍了如何利用Firewalld和iptables实现网络流量控制和安全策略管理,为系统管理员提供了实用的操作指南。 ... [详细]
  • 经过两天的努力,终于成功解决了半平面交模板题POJ3335的问题。原来是在`OnLeft`函数中漏掉了关键的等于号。通过这次训练,不仅加深了对半平面交算法的理解,还提升了调试和代码实现的能力。未来将继续深入研究计算几何的其他核心问题,进一步巩固和拓展相关知识。 ... [详细]
author-avatar
李长倩63399
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有