热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

[578]TensorFlow练习3:RNN,RecurrentNeuralNetworks

前文《使用Python实现神经网络》和《TensorFlow练习1:对评论进行分类》都是简单的Feed-forwardNeuralNetworks(FNN前向反馈神经网络)。而RN

前文《使用Python实现神经网络》和《TensorFlow练习1: 对评论进行分类》都是简单的Feed-forward Neural Networks(FNN/前向反馈神经网络) 。而RNN(Recurrent Neural Networks)循环神经网络要相对复杂,它引入了循环,能够处理数据有前后关系的问题,常用在自然语言处理上。

RNN介绍:

  • Wiki:Recurrent neural network
  • Understanding-LSTMs
  • 循环神经网络(RNN, Recurrent Neural Networks)介绍
  • 唇语识别论文:https://arxiv.org/pdf/1611.05358v1.pdf
  • 自己动手做聊天机器人教程(入门级)

RNN的目的使用来处理序列数据。在传统的神经网络模型中,是从输入层到隐含层再到输出层,层与层之间是全连接的,每层之间的节点是无连接的。但是这种普通的神经网络对于很多问题却无能无力。例如,你要预测句子的下一个单词是什么,一般需要用到前面的单词,因为一个句子中前后单词并不是独立的。RNNs之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。

本帖在MNIST数据集上应用RNN,看看准确率和FNN相比有没有提高。

使用TensorFlow创建RNN

# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
# tensorflow自带了MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data# 下载mnist数据集
mnist = input_data.read_data_sets('./data', one_hot=True)
# 数字(label)只能是0-9,神经网络使用10个出口节点就可以编码表示0-9;
# 1 -> [0,1.0,0,0,0,0,0,0,0] one_hot表示只有一个出口节点是hot
# 2 -> [0,0.1,0,0,0,0,0,0,0]
# 5 -> [0,0,0,0,0,1.0,0,0,0]# 一张图片是28*28,FNN是一次性把数据输入到网络,RNN把它分成块
chunk_size = 28
chunk_n = 28rnn_size = 256
n_output_layer = 10 # 输出层X = tf.placeholder('float', [None, chunk_n, chunk_size])
Y = tf.placeholder('float')# 定义待训练的神经网络
def recurrent_neural_network(data):layer = {'w_': tf.Variable(tf.random_normal([rnn_size, n_output_layer])),'b_': tf.Variable(tf.random_normal([n_output_layer]))}lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)data = tf.transpose(data, [1, 0, 2])data = tf.reshape(data, [-1, chunk_size])data = tf.split(data,chunk_n,0)# outputs, status = tf.nn.rnn(lstm_cell, data, dtype=tf.float32)outputs, status = tf.contrib.rnn.static_rnn(lstm_cell, data, dtype=tf.float32)ouput = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])return ouput# 使用数据训练神经网络
def train_neural_network(X, Y):predict = recurrent_neural_network(X)cost_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=predict,labels=Y))optimizer = tf.train.AdamOptimizer().minimize(cost_func)epochs = 13with tf.Session() as session:# session.run(tf.initialize_all_variables())session.run(tf.global_variables_initializer())epoch_loss = 0for epoch in range(epochs):for i in range(int(mnist.train.num_examples / batch_size)):x, y = mnist.train.next_batch(batch_size)x = x.reshape([batch_size, chunk_n, chunk_size])_, c = session.run([optimizer, cost_func], feed_dict={X: x, Y: y})epoch_loss += cprint(epoch,' : ',epoch_loss)correct = tf.equal(tf.argmax(predict, 1), tf.argmax(Y, 1))accuracy = tf.reduce_mean(tf.cast(correct, 'float'))print('准确率: ', accuracy.eval({X: mnist.test.images.reshape(-1, chunk_n, chunk_size), Y: mnist.test.labels}))# 每次使用100条数据进行训练
batch_size = 100
train_neural_network(X, Y)

执行结果:

0 : 198.8345406986773
1 : 262.8327137650922
2 : 305.82587507134303
3 : 338.99255684157833
4 : 364.6154695186997
5 : 387.91993296472356
6 : 406.719897063449
7 : 424.98356476955814
8 : 439.7244168409379
9 : 452.39767731912434
10 : 464.7693197474582
11 : 475.5920020698977
12 : 484.6569202870887
准确率: 0.9882

比FNN提高了3个百分点。


推荐阅读
  • 深入理解Tornado模板系统
    本文详细介绍了Tornado框架中模板系统的使用方法。Tornado自带的轻量级、高效且灵活的模板语言位于tornado.template模块,支持嵌入Python代码片段,帮助开发者快速构建动态网页。 ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • python image stiching_Python自然语言处理,词云图生成
    自然语言处理本节介绍如何使用Python中的库,生成词云图,涉及自然语言处理的相关问题,自然语言处理是计算机科学领域与人工智能领域中的一个 ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • 技术分享:从动态网站提取站点密钥的解决方案
    本文探讨了如何从动态网站中提取站点密钥,特别是针对验证码(reCAPTCHA)的处理方法。通过结合Selenium和requests库,提供了详细的代码示例和优化建议。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
  • 本文探讨了如何在给定整数N的情况下,找到两个不同的整数a和b,使得它们的和最大,并且满足特定的数学条件。 ... [详细]
  • 本文介绍如何使用 Python 提取和替换 .docx 文件中的图片。.docx 文件本质上是压缩文件,通过解压可以访问其中的图片资源。此外,我们还将探讨使用第三方库 docx 的方法来简化这一过程。 ... [详细]
  • 本文探讨了如何在一个Python脚本中定义一个方法来生成特定URL,并在Robot Framework测试环境中调用此方法,通过环境变量启动测试案例。文中还提供了一个具体的实例,展示了正确的调用方式及可能遇到的问题解决方案。 ... [详细]
  • 自然语言处理(NLP)——LDA模型:对电商购物评论进行情感分析
    目录一、2020数学建模美赛C题简介需求评价内容提供数据二、解题思路三、LDA简介四、代码实现1.数据预处理1.1剔除无用信息1.1.1剔除掉不需要的列1.1.2找出无效评论并剔除 ... [详细]
  • Python默认字符解析:深入理解Python中的字符串处理
    在Python中,字符串是编程中最基本且常用的数据类型之一。尽管许多初学者是从C语言开始接触字符串,通常通过经典的“Hello, World!”程序入门,但Python对字符串的处理方式更为灵活和强大。本文将深入探讨Python中的字符串处理机制,包括字符串的创建、操作、格式化以及编码解码等方面,帮助读者全面理解Python字符串的特性和应用。 ... [详细]
  • 2019年斯坦福大学CS224n课程笔记:深度学习在自然语言处理中的应用——Word2Vec与GloVe模型解析
    本文详细解析了2019年斯坦福大学CS224n课程中关于深度学习在自然语言处理(NLP)领域的应用,重点探讨了Word2Vec和GloVe两种词嵌入模型的原理与实现方法。通过具体案例分析,深入阐述了这两种模型在提升NLP任务性能方面的优势与应用场景。 ... [详细]
  • 本文详细介绍 Go+ 编程语言中的上下文处理机制,涵盖其基本概念、关键方法及应用场景。Go+ 是一门结合了 Go 的高效工程开发特性和 Python 数据科学功能的编程语言。 ... [详细]
  • 本文介绍了如何在C#中启动一个应用程序,并通过枚举窗口来获取其主窗口句柄。当使用Process类启动程序时,我们通常只能获得进程的句柄,而主窗口句柄可能为0。因此,我们需要使用API函数和回调机制来准确获取主窗口句柄。 ... [详细]
author-avatar
kuqu00
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有