热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

RNN加法器的坑及对应

模式识别的pj2rnn加法器首先仔细看助教给的大体代码主要要求补齐的是forward函数:看这段代码让我一开始迷惑的问题以及相应解答:1.embedding是什么?embeddin

模式识别的pj2

rnn加法器

首先仔细看助教给的大体代码主要要求补齐的是forward函数:

class myPTRNNModel(nn.Module):
def __init__(self):
super().__init__()
self.embed_layer = nn.Embedding(10, 32)
self.rnn = nn.RNN(64, 64, 2)
self.dense = nn.Linear(64, 10)
def forward(self, num1, num2):
‘‘‘
Please finish your code here.
‘‘‘
num1 = self.embed_layer(num1)
num2 = self.embed_layer(num2)
input = torch.cat((num1, num2), 2)
#packed = pack_padded_sequence(input, encode_length.tolist(), batch_first=True)
r_out, (h_n, h_c) = self.rnn(input, None)
logits = self.dense(r_out)
return logits

看这段代码让我一开始迷惑的问题以及相应解答:



  1. embedding是什么?

    embedding可以看作文字编码的降维,比如onehot编码可以降维到更低



  2. 为什么要用embedding升维?从代码中可以看到,因为只有10个数字,所以为什么要升维到32呢?

    这是因为embedding的又一个作用体现了。对低维的数据进行升维时,可能把一些其他特征给放大了,或者把笼统的特征给分开了



  3. rnn在__init__定义的三个参数是什么?调用的时候又是什么?

    rnn的第一个参数是input_size,也就是输入向量的维度,比如现在的情况,输入的向量是32+32(因为两个要相加的数要串联);

    rnn的第二个参数是hidden_size,是指输出向量的维度,我们这里还是64;

    rnn的第三个参数是num_layer,也就是会进两个rnn层;

    调用的时候,往rnn里扔的两个东西第一个是串联好的input,第二个是hidden_state的初始化,我这里填写了none,就是全部初始化为0,这是最差的一种初始化方式。



  4. logits这里不需要只取最后的时间,因为在后面evaluate的时候单独做了处理。




训练结果

问题是没有办法高位进位,涉及高位进位就会accuracy是0

4位3000轮——0.15,只有5一下 3位3000轮——27.5,只有5以下


修改方向



  1. clipping the gradient

  2. 更改rnn模型:使用lstm,目前的确是越短的加法准确度越高,还可以考虑双向lstm(但是应该没用啊),此外别的加法用了decoder和encoder模型,这样可以解决进位问题吗?

  3. 各种门的初始化使用正交初始化

  4. 先调试训练集的准确度

  5. 即使网络规模小,只有一层rnn,加入dropout和l2正则化都会减轻过拟合

  6. learning rate的选取

  7. 直接串联会不会没有交叉的排列好?


代码里不懂的地方



  1. 有一个处理数据的reverse函数,这是因为加法只会低位影响高位,所以对序列转向,让低位数字先进rnn网络



other工作想法



  1. 看训练集的准确度

  2. 出训练集和测试集的图

  3. 有问题看pytorch的官方文档

  4. 还有相对路径这件破事没搞

  5. pytorch安装的坑:其中,-c pytorch参数指定了conda获取pytorch的channel,在此指定为conda自带的pytorch仓库。因此,只需要将-c pytorch语句去掉,就可以使用清华镜像源快速安装pytorch了。此为pytorch安装的坑

  6. for o in list(zip(datas[2], res))[:20]: print(o[0], o[1], o[0]==o[1])可以方便看训练的具体情况


推荐阅读
  • 第二十五天接口、多态
    1.java是面向对象的语言。设计模式:接口接口类是从java里衍生出来的,不是python原生支持的主要用于继承里多继承抽象类是python原生支持的主要用于继承里的单继承但是接 ... [详细]
  • 解决Bootstrap DataTable Ajax请求重复问题
    在最近的一个项目中,我们使用了JQuery DataTable进行数据展示,虽然使用起来非常方便,但在测试过程中发现了一个问题:当查询条件改变时,有时查询结果的数据不正确。通过FireBug调试发现,点击搜索按钮时,会发送两次Ajax请求,一次是原条件的请求,一次是新条件的请求。 ... [详细]
  • 网站访问全流程解析
    本文详细介绍了从用户在浏览器中输入一个域名(如www.yy.com)到页面完全展示的整个过程,包括DNS解析、TCP连接、请求响应等多个步骤。 ... [详细]
  • 微软推出Windows Terminal Preview v0.10
    微软近期发布了Windows Terminal Preview v0.10,用户可以在微软商店或GitHub上获取这一更新。该版本在2月份发布的v0.9基础上,新增了鼠标输入和复制Pane等功能。 ... [详细]
  • 解决Parallels Desktop错误15265的方法
    本文详细介绍了在使用Parallels Desktop时遇到错误15265的多种解决方案,包括检查网络连接、关闭代理服务器和修改主机文件等步骤。 ... [详细]
  • 自定义滚动条美化页面内容
    当页面内容超出显示范围时,为了提升用户体验和页面美观,通常会添加滚动条。如果默认的浏览器滚动条无法满足设计需求,我们可以自定义一个符合要求的滚动条。本文将详细介绍自定义滚动条的实现过程。 ... [详细]
  • importpymysql#一、直接连接mysql数据库'''coonpymysql.connect(host'192.168.*.*',u ... [详细]
  • Framework7:构建跨平台移动应用的高效框架
    Framework7 是一个开源免费的框架,适用于开发混合移动应用(原生与HTML混合)或iOS&Android风格的Web应用。此外,它还可以作为原型开发工具,帮助开发者快速创建应用原型。 ... [详细]
  • 本文介绍了如何使用 CMD 批处理脚本进行文件操作,包括将指定目录下的 PHP 文件重命名为 HTML 文件,并将这些文件复制到另一个目录。 ... [详细]
  • 两个条件,组合控制#if($query_string~*modviewthread&t(&extra(.*)))?$)#{#set$itid$1;#rewrite^ ... [详细]
  • 本文详细介绍了DMA控制器如何通过映射表处理来自外设的请求,包括映射表的设计和实现方法。 ... [详细]
  • 本文介绍了如何利用HTTP隧道技术在受限网络环境中绕过IDS和防火墙等安全设备,实现RDP端口的暴力破解攻击。文章详细描述了部署过程、攻击实施及流量分析,旨在提升网络安全意识。 ... [详细]
  • 本文详细介绍了如何利用Duilib界面库开发窗体动画效果,包括基本思路和技术细节。这些方法不仅适用于Duilib,还可以扩展到其他类似的界面开发工具。 ... [详细]
  • Spark中使用map或flatMap将DataSet[A]转换为DataSet[B]时Schema变为Binary的问题及解决方案
    本文探讨了在使用Spark的map或flatMap算子将一个数据集转换为另一个数据集时,遇到的Schema变为Binary的问题,并提供了详细的解决方案。 ... [详细]
  • 解决 Windows Server 2016 网络连接问题
    本文详细介绍了如何解决 Windows Server 2016 在使用无线网络 (WLAN) 和有线网络 (以太网) 时遇到的连接问题。包括添加必要的功能和安装正确的驱动程序。 ... [详细]
author-avatar
real存在尹
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有