热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

[578]TensorFlow练习3:RNN,RecurrentNeuralNetworks

前文《使用Python实现神经网络》和《TensorFlow练习1:对评论进行分类》都是简单的Feed-forwardNeuralNetworks(FNN前向反馈神经网络)。而RN

前文《使用Python实现神经网络》和《TensorFlow练习1: 对评论进行分类》都是简单的Feed-forward Neural Networks(FNN/前向反馈神经网络) 。而RNN(Recurrent Neural Networks)循环神经网络要相对复杂,它引入了循环,能够处理数据有前后关系的问题,常用在自然语言处理上。

RNN介绍:

  • Wiki:Recurrent neural network
  • Understanding-LSTMs
  • 循环神经网络(RNN, Recurrent Neural Networks)介绍
  • 唇语识别论文:https://arxiv.org/pdf/1611.05358v1.pdf
  • 自己动手做聊天机器人教程(入门级)

RNN的目的使用来处理序列数据。在传统的神经网络模型中,是从输入层到隐含层再到输出层,层与层之间是全连接的,每层之间的节点是无连接的。但是这种普通的神经网络对于很多问题却无能无力。例如,你要预测句子的下一个单词是什么,一般需要用到前面的单词,因为一个句子中前后单词并不是独立的。RNNs之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。

本帖在MNIST数据集上应用RNN,看看准确率和FNN相比有没有提高。

使用TensorFlow创建RNN

# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
# tensorflow自带了MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data# 下载mnist数据集
mnist = input_data.read_data_sets('./data', one_hot=True)
# 数字(label)只能是0-9,神经网络使用10个出口节点就可以编码表示0-9;
# 1 -> [0,1.0,0,0,0,0,0,0,0] one_hot表示只有一个出口节点是hot
# 2 -> [0,0.1,0,0,0,0,0,0,0]
# 5 -> [0,0,0,0,0,1.0,0,0,0]# 一张图片是28*28,FNN是一次性把数据输入到网络,RNN把它分成块
chunk_size = 28
chunk_n = 28rnn_size = 256
n_output_layer = 10 # 输出层X = tf.placeholder('float', [None, chunk_n, chunk_size])
Y = tf.placeholder('float')# 定义待训练的神经网络
def recurrent_neural_network(data):layer = {'w_': tf.Variable(tf.random_normal([rnn_size, n_output_layer])),'b_': tf.Variable(tf.random_normal([n_output_layer]))}lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)data = tf.transpose(data, [1, 0, 2])data = tf.reshape(data, [-1, chunk_size])data = tf.split(data,chunk_n,0)# outputs, status = tf.nn.rnn(lstm_cell, data, dtype=tf.float32)outputs, status = tf.contrib.rnn.static_rnn(lstm_cell, data, dtype=tf.float32)ouput = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])return ouput# 使用数据训练神经网络
def train_neural_network(X, Y):predict = recurrent_neural_network(X)cost_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=predict,labels=Y))optimizer = tf.train.AdamOptimizer().minimize(cost_func)epochs = 13with tf.Session() as session:# session.run(tf.initialize_all_variables())session.run(tf.global_variables_initializer())epoch_loss = 0for epoch in range(epochs):for i in range(int(mnist.train.num_examples / batch_size)):x, y = mnist.train.next_batch(batch_size)x = x.reshape([batch_size, chunk_n, chunk_size])_, c = session.run([optimizer, cost_func], feed_dict={X: x, Y: y})epoch_loss += cprint(epoch,' : ',epoch_loss)correct = tf.equal(tf.argmax(predict, 1), tf.argmax(Y, 1))accuracy = tf.reduce_mean(tf.cast(correct, 'float'))print('准确率: ', accuracy.eval({X: mnist.test.images.reshape(-1, chunk_n, chunk_size), Y: mnist.test.labels}))# 每次使用100条数据进行训练
batch_size = 100
train_neural_network(X, Y)

执行结果:

0 : 198.8345406986773
1 : 262.8327137650922
2 : 305.82587507134303
3 : 338.99255684157833
4 : 364.6154695186997
5 : 387.91993296472356
6 : 406.719897063449
7 : 424.98356476955814
8 : 439.7244168409379
9 : 452.39767731912434
10 : 464.7693197474582
11 : 475.5920020698977
12 : 484.6569202870887
准确率: 0.9882

比FNN提高了3个百分点。


推荐阅读
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 本文介绍了如何在给定的有序字符序列中插入新字符,并保持序列的有序性。通过示例代码演示了插入过程,以及插入后的字符序列。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 个人学习使用:谨慎参考1Client类importcom.thoughtworks.gauge.Step;importcom.thoughtworks.gauge.T ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 学习Java异常处理之throws之抛出并捕获异常(9)
    任务描述本关任务:在main方法之外创建任意一个方法接收给定的两个字符串,把第二个字符串的长度减1生成一个整数值,输出第一个字符串长度是 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • 统一知识图谱学习和建议:更好地理解用户偏好
    本文介绍了一种将知识图谱纳入推荐系统的方法,以提高推荐的准确性和可解释性。与现有方法不同的是,本方法考虑了知识图谱的不完整性,并在知识图谱中传输关系信息,以更好地理解用户的偏好。通过大量实验,验证了本方法在推荐任务和知识图谱完成任务上的优势。 ... [详细]
  • Opencv提供了几种分类器,例程里通过字符识别来进行说明的1、支持向量机(SVM):给定训练样本,支持向量机建立一个超平面作为决策平面,使得正例和反例之间的隔离边缘被最大化。函数原型:训练原型cv ... [详细]
author-avatar
kuqu00
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有