热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

[578]TensorFlow练习3:RNN,RecurrentNeuralNetworks

前文《使用Python实现神经网络》和《TensorFlow练习1:对评论进行分类》都是简单的Feed-forwardNeuralNetworks(FNN前向反馈神经网络)。而RN

前文《使用Python实现神经网络》和《TensorFlow练习1: 对评论进行分类》都是简单的Feed-forward Neural Networks(FNN/前向反馈神经网络) 。而RNN(Recurrent Neural Networks)循环神经网络要相对复杂,它引入了循环,能够处理数据有前后关系的问题,常用在自然语言处理上。

RNN介绍:

  • Wiki:Recurrent neural network
  • Understanding-LSTMs
  • 循环神经网络(RNN, Recurrent Neural Networks)介绍
  • 唇语识别论文:https://arxiv.org/pdf/1611.05358v1.pdf
  • 自己动手做聊天机器人教程(入门级)

RNN的目的使用来处理序列数据。在传统的神经网络模型中,是从输入层到隐含层再到输出层,层与层之间是全连接的,每层之间的节点是无连接的。但是这种普通的神经网络对于很多问题却无能无力。例如,你要预测句子的下一个单词是什么,一般需要用到前面的单词,因为一个句子中前后单词并不是独立的。RNNs之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。

本帖在MNIST数据集上应用RNN,看看准确率和FNN相比有没有提高。

使用TensorFlow创建RNN

# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
# tensorflow自带了MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data# 下载mnist数据集
mnist = input_data.read_data_sets('./data', one_hot=True)
# 数字(label)只能是0-9,神经网络使用10个出口节点就可以编码表示0-9;
# 1 -> [0,1.0,0,0,0,0,0,0,0] one_hot表示只有一个出口节点是hot
# 2 -> [0,0.1,0,0,0,0,0,0,0]
# 5 -> [0,0,0,0,0,1.0,0,0,0]# 一张图片是28*28,FNN是一次性把数据输入到网络,RNN把它分成块
chunk_size = 28
chunk_n = 28rnn_size = 256
n_output_layer = 10 # 输出层X = tf.placeholder('float', [None, chunk_n, chunk_size])
Y = tf.placeholder('float')# 定义待训练的神经网络
def recurrent_neural_network(data):layer = {'w_': tf.Variable(tf.random_normal([rnn_size, n_output_layer])),'b_': tf.Variable(tf.random_normal([n_output_layer]))}lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)data = tf.transpose(data, [1, 0, 2])data = tf.reshape(data, [-1, chunk_size])data = tf.split(data,chunk_n,0)# outputs, status = tf.nn.rnn(lstm_cell, data, dtype=tf.float32)outputs, status = tf.contrib.rnn.static_rnn(lstm_cell, data, dtype=tf.float32)ouput = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])return ouput# 使用数据训练神经网络
def train_neural_network(X, Y):predict = recurrent_neural_network(X)cost_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=predict,labels=Y))optimizer = tf.train.AdamOptimizer().minimize(cost_func)epochs = 13with tf.Session() as session:# session.run(tf.initialize_all_variables())session.run(tf.global_variables_initializer())epoch_loss = 0for epoch in range(epochs):for i in range(int(mnist.train.num_examples / batch_size)):x, y = mnist.train.next_batch(batch_size)x = x.reshape([batch_size, chunk_n, chunk_size])_, c = session.run([optimizer, cost_func], feed_dict={X: x, Y: y})epoch_loss += cprint(epoch,' : ',epoch_loss)correct = tf.equal(tf.argmax(predict, 1), tf.argmax(Y, 1))accuracy = tf.reduce_mean(tf.cast(correct, 'float'))print('准确率: ', accuracy.eval({X: mnist.test.images.reshape(-1, chunk_n, chunk_size), Y: mnist.test.labels}))# 每次使用100条数据进行训练
batch_size = 100
train_neural_network(X, Y)

执行结果:

0 : 198.8345406986773
1 : 262.8327137650922
2 : 305.82587507134303
3 : 338.99255684157833
4 : 364.6154695186997
5 : 387.91993296472356
6 : 406.719897063449
7 : 424.98356476955814
8 : 439.7244168409379
9 : 452.39767731912434
10 : 464.7693197474582
11 : 475.5920020698977
12 : 484.6569202870887
准确率: 0.9882

比FNN提高了3个百分点。


推荐阅读
  • 自然语言处理(NLP)——LDA模型:对电商购物评论进行情感分析
    目录一、2020数学建模美赛C题简介需求评价内容提供数据二、解题思路三、LDA简介四、代码实现1.数据预处理1.1剔除无用信息1.1.1剔除掉不需要的列1.1.2找出无效评论并剔除 ... [详细]
  • python模块之正则
    re模块可以读懂你写的正则表达式根据你写的表达式去执行任务用re去操作正则正则表达式使用一些规则来检测一些字符串是否符合个人要求,从一段字符串中找到符合要求的内容。在 ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • 利用python爬取豆瓣电影Top250的相关信息,包括电影详情链接,图片链接,影片中文名,影片外国名,评分,评价数,概况,导演,主演,年份,地区,类别这12项内容,然后将爬取的信息写入Exce ... [详细]
  • Visual Studio Code (VSCode) 是一款功能强大的源代码编辑器,支持多种编程语言,具备丰富的扩展生态。本文将详细介绍如何在 macOS 上安装、配置并使用 VSCode。 ... [详细]
  • 解决问题:1、批量读取点云las数据2、点云数据读与写出3、csf滤波分类参考:https:github.comsuyunzzzCSF论文题目ÿ ... [详细]
  • 通过使用 `pandas` 库中的 `scatter_matrix` 函数,可以有效地绘制出多个特征之间的两两关系。该函数不仅能够生成散点图矩阵,还能通过参数如 `frame`、`alpha`、`c`、`figsize` 和 `ax` 等进行自定义设置,以满足不同的可视化需求。此外,`diagonal` 参数允许用户选择对角线上的图表类型,例如直方图或密度图,从而提供更多的数据洞察。 ... [详细]
  • HTML 页面中调用 JavaScript 函数生成随机数值并自动展示
    在HTML页面中,通过调用JavaScript函数生成随机数值,并将其自动展示在页面上。具体实现包括构建HTML页面结构,定义JavaScript函数以生成随机数,以及在页面加载时自动调用该函数并将结果呈现给用户。 ... [详细]
  • Ihavetwomethodsofgeneratingmdistinctrandomnumbersintherange[0..n-1]我有两种方法在范围[0.n-1]中生 ... [详细]
  • 利用REM实现移动端布局的高效适配技巧
    在移动设备上实现高效布局适配时,使用rem单位已成为一种流行且有效的技术。本文将分享过去一年中使用rem进行布局适配的经验和心得。rem作为一种相对单位,能够根据根元素的字体大小动态调整,从而确保不同屏幕尺寸下的布局一致性。通过合理设置根元素的字体大小,开发者可以轻松实现响应式设计,提高用户体验。此外,文章还将探讨一些常见的问题和解决方案,帮助开发者更好地掌握这一技术。 ... [详细]
  • 大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式
    大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式 ... [详细]
  • 本文介绍了如何使用 Node.js 和 Express(4.x 及以上版本)构建高效的文件上传功能。通过引入 `multer` 中间件,可以轻松实现文件上传。首先,需要通过 `npm install multer` 安装该中间件。接着,在 Express 应用中配置 `multer`,以处理多部分表单数据。本文详细讲解了 `multer` 的基本用法和高级配置,帮助开发者快速搭建稳定可靠的文件上传服务。 ... [详细]
  • 如何将TS文件转换为M3U8直播流:HLS与M3U8格式详解
    在视频传输领域,MP4虽然常见,但在直播场景中直接使用MP4格式存在诸多问题。例如,MP4文件的头部信息(如ftyp、moov)较大,导致初始加载时间较长,影响用户体验。相比之下,HLS(HTTP Live Streaming)协议及其M3U8格式更具优势。HLS通过将视频切分成多个小片段,并生成一个M3U8播放列表文件,实现低延迟和高稳定性。本文详细介绍了如何将TS文件转换为M3U8直播流,包括技术原理和具体操作步骤,帮助读者更好地理解和应用这一技术。 ... [详细]
  • Python 序列图分割与可视化编程入门教程
    本文介绍了如何使用 Python 进行序列图的快速分割与可视化。通过一个实际案例,详细展示了从需求分析到代码实现的全过程。具体包括如何读取序列图数据、应用分割算法以及利用可视化库生成直观的图表,帮助非编程背景的用户也能轻松上手。 ... [详细]
  • 点云技术初探(三):PCL基础知识与学习路径指南本文首先介绍了点云库(PCL)的基本概念,PCL是一个在前人点云研究成果基础上发展而来的大型跨平台开源C++编程库,旨在为点云数据处理提供全面的支持。文章详细阐述了PCL的核心功能及其在三维数据处理、特征提取、分割与配准等方面的应用,并为初学者提供了系统的学习路径和资源推荐,帮助读者快速掌握PCL的使用方法。 ... [详细]
author-avatar
kuqu00
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有