热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Torch中的RNN底层代码实现

理论篇代码篇Torch中的RNN【1】这个package包括了RNN,RL,通过这个package可以很容易构建RNN,RL的模型。安装

  • 理论篇
  • 代码篇

Torch中的RNN【1】这个package包括了RNN,RL,通过这个package可以很容易构建RNN,RL的模型。

安装:

luarocks install torch
luarocks install nn
luarocks install torchx
luarocks install dataload如果有CUDA:
luarocks install cutorch
luarocks install cunn记得安装:
luarocks install rnn

但是如果要使用nn.Reccurent,需要安装:【4】

理论篇

这一次主要是讲最简单的RNN,也就是Simple RNN。实现的话是根据这两篇论文:【6】,【7】

首先介绍一下Simple RNN的整个网络结构,再说一下ρ step 的BPTT。

整个网络可以用下图来表示:(这种网络的输入一部分是当前的输入,另外一部分来自于hidden layer的上一个输出,这种叫做Elman Network。另外一种网络是一部分来自于当前输入,另外一部分来自于整个网络的上一个输出)
这里写图片描述

  • 当前输入wt与上一个hidden layer的输出st1两个vector相加,得到真正输入到网络里面的东西。
    这里写图片描述
  • 接着是把输入送进一个logistic regression里面,得到hidden layer:st. st 一方面往输出那条路径走,另外一方面往缓存或者叫做Context里面存起来,称为下一个输入需要的一部分,替换st1
    这里写图片描述
    这里写图片描述
  • st输出该时刻的output: yt 一个Linear加上softmax,非常简单。
    这里写图片描述
    这里写图片描述

这样呢就把整个网络结构描述完了,接下来就是如何训练得到参数了。(其实RNN,LSTM还有很多小的trick,同样的算法,trick不一样,结果都会千差万别)

在另外的论文里面把这幅图给完整画了出来,显得更加清晰:
这里写图片描述

了解了整个网络的以后,需要定义loss,在进行BP的时候,首先定义loss function,一般采用的是SSE:dpk就是第p个sample,输出的feature第k个的label。y是prediction。
这里写图片描述

对于w的更新,都是采用梯度下降:
这里写图片描述

对于输出output部分进行求导:
这里写图片描述

再进一步输出output的linear regression部分的w进行求导:
这里写图片描述

接着是hidden layer进行求导:
这里写图片描述

对hidden layer的输入的input部分参数进行求导:
这里写图片描述

对hidden layer的上一个hidden layer的作为input的部分进行求导:
这里写图片描述

目前的loss为SSE的时候,一般采用logistic function作为输出的函数:
这里写图片描述
这里写图片描述

当然,也可以有别的loss function,对应的output function g也会做乡相应改变。

比如对于Gaussian Distribution:
这里写图片描述
这里写图片描述
这里写图片描述

使用cross-entropy作为loss:(g为logistic function)
这里写图片描述
这里写图片描述

对于分类问题,采用multinomial dostribution:
采用的是softmax作为g,然后loss function为:
这里写图片描述

这里写图片描述

这里写图片描述

在RNN中经常听到BPTT,就是让RNN在进行后向传递的时候不仅仅是当前这个时期,还有的是更多时刻:τ。

比如τ = 3,展开的图如下:展开了三次,那么进行BP的时候,就把各个参数往后相乘来更新w,这里需要注意vanishing gradient effect和explode gradient effect的东西,一个梯度衰减比如为0,一个梯度爆炸。
这里写图片描述

还有一种图可以表示梯度变化,红色的表示梯度的方向:
这里写图片描述

如果用公式来表示这个整个过程就是前向:
这里写图片描述

后向更新梯度:(每个时刻的梯度都会进行叠加到最后的w更新)
这里写图片描述

代码篇

这次描述的是Simple RNN,函数为nn.Recurrent 。在nn中有两个抽象类,一个是nn,用来构建网络,一个是Criterion【3】,用来提供比如cross entropy,reward。具体介绍可以看【2】,还有对应的论文。

在【3】中提出了一个简单的例子:目前下面的nn.Recurrent已经不在Torch的库中,所以要使用的话,就去安装这个人写的【4】

这里面的实现的RNN是最简单的,连hidden layer都没有,直接transfer就是输出了。

nn.Recurrent(start, input, feedback, [transfer, rho, merge])
-- start:对初始t=0的input进行处理
-- input:对t~=0的时候input进行处理
-- feedback:对s(t)进行处理缓存到了s(t-1)
-- transfer:对输出进行处理的函数
-- rho:进行BPTT的steps的数目
-- merge:对input x(t)和上一个时刻的输出s(t-1)进行融合

-- generate some dummy inputs and gradOutputs sequences
-- 生成dummy input
inputs, gradOutputs = {}, {}
for step=1,rho doinputs[step] = torch.randn(batchSize,inputSize)gradOutputs[step] = torch.randn(batchSize,inputSize)
end-- 调用RNN
-- an AbstractRecurrent instance
rnn = nn.Recurrent(hiddenSize, -- size of the input layer(隐层的size)nn.Linear(inputSize,outputSize), -- input layer(输入层进行linear regression) nn.Linear(outputSize, outputSize), -- recurrent layer 输出层的linear regressionnn.Sigmoid(), -- transfer function,把输入通过linear regression之后的结果送到这个函数得到s(t),这个函数也可以改成ReLU别的激活函数rho -- maximum number of time-steps for BPTT,进行BPTT时候的steps
)-- feed-forward and backpropagate through time like this :
for step=1,rho
dornn:forward(inputs[step])rnn:backward(inputs[step], gradOutputs[step])
endrnn:backwardThroughTime() -- call backward on the internal modules
gradInputs = rnn.gradInputs
rnn:updateParameters(0.1)
rnn:forget() -- resets the time-step counter

对完整的nn.Reccurent的理解:【10】

assert(not nn.Recurrent, "update nnx package : luarocks install nnx")
local Recurrent, parent = torch.class('nn.Recurrent', 'nn.AbstractRecurrent')-- 把各个module放到RNN的对应位置
-- start是对最开始t=0输入inut做的处理
-- input是对t~=0的时刻进行input的处理
-- feedback是对s(t)进行处理缓存到s(t-1)的函数
-- transfer是对最后的输出的activation function
-- rho:是进行BPTT的时间
-- merge:对于输入x(t)和上一个时刻的hidden layer的输出s(t-1)的融合方法
function Recurrent:__init(start, input, feedback, transfer, rho, merge)parent.__init(self, rho)local ts = torch.type(start)if ts == 'torch.LongStorage' or ts == 'number' thenstart = nn.Add(start)elseif ts == 'table' thenstart = nn.Add(torch.LongStorage(start))elseif not torch.isTypeOf(start, 'nn.Module') thenerror"Recurrent : expecting arg 1 of type nn.Module, torch.LongStorage, number or table"endself.startModule = startself.inputModule = inputself.feedbackModule = feedbackself.transferModule = transfer or nn.Sigmoid()self.mergeModule = merge or nn.CAddTable()self.modules = {self.startModule, self.inputModule, self.feedbackModule, self.transferModule, self.mergeModule}self:buildInitialModule()self:buildRecurrentModule()self.sharedClones[2] = self.recurrentModule
end-- 对最开始t=0的时候构建模型
-- build module used for the first step (steps == 1)
function Recurrent:buildInitialModule()self.initialModule = nn.Sequential()self.initialModule:add(self.inputModule:sharedClone())self.initialModule:add(self.startModule)self.initialModule:add(self.transferModule:sharedClone())
end-- build module used for the other steps (steps > 1)
-- 构建整个模型
function Recurrent:buildRecurrentModule()local parallelModule = nn.ParallelTable()parallelModule:add(self.inputModule)parallelModule:add(self.feedbackModule)self.recurrentModule = nn.Sequential()self.recurrentModule:add(parallelModule)self.recurrentModule:add(self.mergeModule)self.recurrentModule:add(self.transferModule)
end-- 更新输出
function Recurrent:updateOutput(input)-- output(t) = transfer(feedback(output_(t-1)) + input(input_(t)))local outputif self.step == 1 thenoutput = self.initialModule:updateOutput(input)elseif self.train ~= false then-- set/save the output statesself:recycle()local recurrentModule = self:getStepModule(self.step)-- self.output is the previous output of this moduleoutput = recurrentModule:updateOutput{input, self.outputs[self.step-1]}else-- self.output is the previous output of this moduleoutput = self.recurrentModule:updateOutput{input, self.outputs[self.step-1]}endendself.outputs[self.step] = outputself.output = outputself.step = self.step + 1self.gradPrevOutput = nilself.updateGradInputStep = nilself.accGradParametersStep = nilreturn self.output
end-- 求解梯度,没有累加
function Recurrent:_updateGradInput(input, gradOutput)assert(self.step > 1, "expecting at least one updateOutput")local step = self.updateGradInputStep - 1local gradInputif self.gradPrevOutput thenself._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)gradOutput = self._gradOutputs[step]endlocal output = self.outputs[step-1]if step > 1 thenlocal recurrentModule = self:getStepModule(step)gradInput, self.gradPrevOutput = unpack(recurrentModule:updateGradInput({input, output}, gradOutput))elseif step == 1 thengradInput = self.initialModule:updateGradInput(input, gradOutput)elseerror"non-positive time-step"endreturn gradInput
end-- 求解梯度,但是会把t steps的梯度相加
function Recurrent:_accGradParameters(input, gradOutput, scale)local step = self.accGradParametersStep - 1local gradOutput = (step == self.step-1) and gradOutput or self._gradOutputs[step]local output = self.outputs[step-1]if step > 1 thenlocal recurrentModule = self:getStepModule(step)recurrentModule:accGradParameters({input, output}, gradOutput, scale)elseif step == 1 thenself.initialModule:accGradParameters(input, gradOutput, scale)elseerror"non-positive time-step"end
endfunction Recurrent:recycle()return parent.recycle(self, 1)
endfunction Recurrent:forget()return parent.forget(self, 1)
endfunction Recurrent:includingSharedClones(f)local modules = self.modulesself.modules = {}local sharedClones = self.sharedClonesself.sharedClones = nillocal initModule = self.initialModuleself.initialModule = nilfor i,modules in ipairs{modules, sharedClones, {initModule}} dofor j, module in pairs(modules) dotable.insert(self.modules, module)endendlocal r = f()self.modules = modulesself.sharedClones = sharedClonesself.initialModule = initModulereturn r
endfunction Recurrent:reinforce(reward)if torch.type(reward) == 'table' then-- multiple rewards, one per time-steplocal rewards = rewardfor step, reward in ipairs(rewards) doif step == 1 thenself.initialModule:reinforce(reward)elselocal sm = self:getStepModule(step)sm:reinforce(reward)endendelse-- one reward broadcast to all time-stepsreturn self:includingSharedClones(function()return parent.reinforce(self, reward)end)end
endfunction Recurrent:maskZero()error("Recurrent doesn't support maskZero as it uses a different ".."module for the first time-step. Use nn.Recurrence instead.")
endfunction Recurrent:trimZero()error("Recurrent doesn't support trimZero as it uses a different ".."module for the first time-step. Use nn.Recurrence instead.")
end-- 把模型打印出来
-- 比如我调用的是:
-- nn.Recurrent(256, nn.Identity(), nn.Linear(256, 256), nn['ReLU'](), 99999)
-- [[[
{input(t), output(t-1)} -> (1) -> (2) -> (3) -> output(t)](1): {input(t)|`-> (t==0): nn.Add|`-> (t~=0): nn.Identityoutput(t-1)|`-> nn.Linear(256 -> 256)}(2): nn.CAddTable(3): nn.ReLU}
---]]
function Recurrent:__tostring__()local tab = ' 'local line = '\n'local next = ' -> 'local str = torch.type(self)str = str .. ' {' .. line .. tab .. '[{input(t), output(t-1)}'for i=1,3 dostr = str .. next .. '(' .. i .. ')'endstr = str .. next .. 'output(t)]'local tab = ' 'local line = '\n 'local next = ' |`-> 'local ext = ' | 'local last = ' ... -> 'str = str .. line .. '(1): ' .. ' {' .. line .. tab .. 'input(t)'str = str .. line .. tab .. next .. '(t==0): ' .. tostring(self.startModule):gsub('\n', '\n' .. tab .. ext)str = str .. line .. tab .. next .. '(t~=0): ' .. tostring(self.inputModule):gsub('\n', '\n' .. tab .. ext)str = str .. line .. tab .. 'output(t-1)'str = str .. line .. tab .. next .. tostring(self.feedbackModule):gsub('\n', line .. tab .. ext)str = str .. line .. "}"local tab = ' 'local line = '\n'local next = ' -> 'str = str .. line .. tab .. '(' .. 2 .. '): ' .. tostring(self.mergeModule):gsub(line, line .. tab)str = str .. line .. tab .. '(' .. 3 .. '): ' .. tostring(self.transferModule):gsub(line, line .. tab)str = str .. line .. '}'return str
end

转载请注明出处: http://blog.csdn.net/c602273091/article/details/78975636

参考链接:
【1】RNN地址: https://github.com/torch/rnn
【2】nn Package: https://arxiv.org/pdf/1511.07889.pdf
【3】RNN Code: https://github.com/torch/rnn/blob/master/doc/recurrent.md#rnn.Recurrence
【4】nn.Reccurent: https://github.com/Element-Research/rnn/blob/master/Recurrent.lua
【5】nn RNN: https://github.com/Element-Research/rnn
【6】Recurrent neural network based language model: http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf
【7】 A guide to recurrent neural networks and backpropagation: http://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=CDD081815C5FAC4835EF27B81EEA5F8C?doi=10.1.1.3.9311&rep=rep1&type=pdf
【8】STATISTICAL LANGUAGE MODELS BASED ON NEURAL NETWORKS: (3.2~3.3)http://www.fit.vutbr.cz/%7Eimikolov/rnnlm/thesis.pdf
【9】TRAINING RECURRENT NEURAL NETWORKS:(2.5~2.8) http://www.cs.utoronto.ca/%7Eilya/pubs/ilya_sutskever_phd_thesis.pdf
【10】nn.Reccurent: https://github.com/Element-Research/rnn/blob/master/Recurrent.lua


推荐阅读
  • 在 Vue 应用开发中,页面状态管理和跨页面数据传递是常见需求。本文将详细介绍 Vue Router 提供的两种有效方式,帮助开发者高效地实现页面间的数据交互与状态同步,同时分享一些最佳实践和注意事项。 ... [详细]
  • ButterKnife 是一款用于 Android 开发的注解库,主要用于简化视图和事件绑定。本文详细介绍了 ButterKnife 的基础用法,包括如何通过注解实现字段和方法的绑定,以及在实际项目中的应用示例。此外,文章还提到了截至 2016 年 4 月 29 日,ButterKnife 的最新版本为 8.0.1,为开发者提供了最新的功能和性能优化。 ... [详细]
  • Spring框架中的面向切面编程(AOP)技术详解
    面向切面编程(AOP)是Spring框架中的关键技术之一,它通过将横切关注点从业务逻辑中分离出来,实现了代码的模块化和重用。AOP的核心思想是将程序运行过程中需要多次处理的功能(如日志记录、事务管理等)封装成独立的模块,即切面,并在特定的连接点(如方法调用)动态地应用这些切面。这种方式不仅提高了代码的可维护性和可读性,还简化了业务逻辑的实现。Spring AOP利用代理机制,在不修改原有代码的基础上,实现了对目标对象的增强。 ... [详细]
  • 探索聚类分析中的K-Means与DBSCAN算法及其应用
    聚类分析是一种用于解决样本或特征分类问题的统计分析方法,也是数据挖掘领域的重要算法之一。本文主要探讨了K-Means和DBSCAN两种聚类算法的原理及其应用场景。K-Means算法通过迭代优化簇中心来实现数据点的划分,适用于球形分布的数据集;而DBSCAN算法则基于密度进行聚类,能够有效识别任意形状的簇,并且对噪声数据具有较好的鲁棒性。通过对这两种算法的对比分析,本文旨在为实际应用中选择合适的聚类方法提供参考。 ... [详细]
  • 在《ChartData类详解》一文中,我们将深入探讨 MPAndroidChart 中的 ChartData 类。本文将详细介绍如何设置图表颜色(Setting Colors)以及如何格式化数据值(Formatting Data Values),通过 ValueFormatter 的使用来提升图表的可读性和美观度。此外,我们还将介绍一些高级配置选项,帮助开发者更好地定制和优化图表展示效果。 ... [详细]
  • 在Android开发中,当TextView的高度固定且内容超出时,可以通过设置其内置的滚动条属性来实现垂直滚动功能。具体来说,可以通过配置`android:scrollbars="vertical"`来启用垂直滚动,确保用户能够查看完整的内容。此外,为了优化用户体验,建议结合`setMovementMethod(ScrollerMovementMethod.getInstance())`方法,使滚动操作更加流畅和自然。 ... [详细]
  • 本文深入解析了WCF Binding模型中的绑定元素,详细介绍了信道、信道管理器、信道监听器和信道工厂的概念与作用。从对象创建的角度来看,信道管理器负责信道的生成。具体而言,客户端的信道通过信道工厂进行实例化,而服务端则通过信道监听器来接收请求。文章还探讨了这些组件之间的交互机制及其在WCF通信中的重要性。 ... [详细]
  • 使用 ListView 浏览安卓系统中的回收站文件 ... [详细]
  • 如何在C#中配置组合框的背景颜色? ... [详细]
  • 使用 Vuex 管理表单状态:当输入框失去焦点时自动恢复初始值 ... [详细]
  • 并发编程入门:初探多任务处理技术
    并发编程入门:探索多任务处理技术并发编程是指在单个处理器上高效地管理多个任务的执行过程。其核心在于通过合理分配和协调任务,提高系统的整体性能。主要应用场景包括:1) 将复杂任务分解为多个子任务,并分配给不同的线程,实现并行处理;2) 通过同步机制确保线程间协调一致,避免资源竞争和数据不一致问题。此外,理解并发编程还涉及锁机制、线程池和异步编程等关键技术。 ... [详细]
  • 本课程深入探讨了 Python 中自定义序列类的实现方法,涵盖从基础概念到高级技巧的全面解析。通过实例演示,学员将掌握如何创建支持切片操作的自定义序列对象,并了解 `bisect` 模块在序列处理中的应用。适合希望提升 Python 编程技能的中高级开发者。 ... [详细]
  • 本文探讨了Android系统中支持的图像格式及其在不同版本中的兼容性问题,重点涵盖了存储、HTTP传输、相机功能以及SparseArray的应用。文章详细分析了从Android 10 (API 29) 到Android 11 的存储规范变化,并讨论了这些变化对图像处理的影响。此外,还介绍了如何通过系统升级和代码优化来解决版本兼容性问题,以确保应用程序在不同Android版本中稳定运行。 ... [详细]
  • 本文介绍了如何利用Apache POI库高效读取Excel文件中的数据。通过实际测试,除了分数被转换为小数存储外,其他数据均能正确读取。若在使用过程中发现任何问题,请及时留言反馈,以便我们进行更新和改进。 ... [详细]
  • 本文深入探讨了CGLIB BeanCopier在Bean对象复制中的应用及其优化技巧。相较于Spring的BeanUtils和Apache的BeanUtils,CGLIB BeanCopier在性能上具有显著优势。通过详细分析其内部机制和使用场景,本文提供了多种优化方法,帮助开发者在实际项目中更高效地利用这一工具。此外,文章还讨论了CGLIB BeanCopier在复杂对象结构和大规模数据处理中的表现,为读者提供了实用的参考和建议。 ... [详细]
author-avatar
mr.sun
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有