热门标签 | HotTags
当前位置:  开发笔记 > 后端 > 正文

Pytorch之PackedSequence

在训练神经网络时,基本都使用SGD相关算法,网络输入是一个mini-batch的数据。但是NLP领域中不论什么任务,大多面临一个问题:每个样本(文本数据)长度基本不相同。如果你是P

在训练神经网络时,基本都使用SGD相关算法,网络输入是一个mini-batch的数据。但是NLP领域中不论什么任务,大多面临一个问题:每个样本(文本数据)长度基本不相同。

如果你是Pytorch用户,并且要用RNN模型建模,建议你直接使用Pytorch提供的PackSequence来解决这个问题。

torch.nn.utils.rnn.PackedSequence

PackedSequence将长度不同的文本数据封装成一个batch,可以直接作为RNN的输入。既然这么好用,那么如何创建PackedSequence呢?Pytorch提供了pack_padded_sequence()方法,用于创建PackedSequence。

直接看实例吧:

《Pytorch之PackedSequence》
《Pytorch之PackedSequence》

《Pytorch之PackedSequence》
《Pytorch之PackedSequence》

《Pytorch之PackedSequence》
《Pytorch之PackedSequence》
《Pytorch之PackedSequence》
《Pytorch之PackedSequence》
示例代码地址github.com

总结:
如何将一个batch_size的数据封装为RNN可接受的PackedSequence?

  • 将数据按照序列长度由大到小排序
  • lengths=[每条样本的序列长度],T=最大的序列长度,B=batch_size
  • 将数据转为Tensor类型
  • 调用pack_padded_sequence()方法,得到PackedSequence实例


推荐阅读
author-avatar
okkkokkokkkokka
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有