热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

52自然语言处理NLPseq2seqattention的提出计算方式及pytorch实现

1、seq2seqattention的论文出处主要是阅读经典论文《NeuralMachineTranslationbyJointlyLearningtoAlignandTrans

1、seq2seq attention的论文出处

主要是阅读经典论文《Neural Machine Translation by Jointly Learning to Align and Translate》https://arxiv.org/pdf/1409.0473.pdf

这篇论文中机器翻译采用seq2seq的encoder-decoder模型构建,将输入句子encoding成一个固定长度的向量,然后输入到decoder解码生成译文。


attention机制引入要解决的问题

encoder得到的固定向量直接送入decoder,不利于

(1)较长句子的信息传递 (2)提取decoder中目标单词需要关注原文中的语义信息




2、seq2seq decoder的构成

直接截取一下论文中的讲法吧,如下图:

式子(4)的意思就是解码器生成出来的每个token,由

1、y_{i-1}就是前一个token的特征向量;

2、s_{i} 就是当前的隐状态,这里的s_{i} 也是通过RNN得到的,可以把y_{i-1}c_{i}看成这一步RNN的输入特征向量;

3、c_{i} 就是上下文的特征,没有用attention之前就是encoder得到的固定向量。现在用了attention机制这篇文章中经典的attention机制就是(5)式得到的。

关于c_{i} 的计算中,简单来说就是encoder中各层隐状态h_{j} 与decoder的隐状态s_{i-1}做相似度(可以用一个MLP来实现等,下文再具体说明各种算法),然后求softmax得到权重系数\alpha _{ij},然后做加权和。当然这只是一种做法,实际怎么做可以有别的做法。




3、seq2seq attention的经典形式

Effective Approaches to Attention-based Neural Machine Translation 论文中,将注意力机制大致分为了全局(global)注意力和局部(local)注意力。

全局注意力指的是注意力分布在所有encoder得到隐状态中;局部注意力指注意力只存在于一些隐状态中。例如,图像的average pooling可视为全局注意力;max pooling可视为局部注意力

论文链接:https://arxiv.org/abs/1508.04025


3.1 global attention

直接上论文中的内容好了,这边global attention的意思和上一篇论文的一样,只是论文中用的符号不一样,应该一看就能知道,实际一个意思。

这里decoder的target hidden和encoder中的source hidden的score怎么算呢?文中介绍三种算法:

这里把score称为基于内容的函数,原因是这里的计算只考虑了隐状态间内容的相关性,并不包含时序信息。

(1)dot:transformer中的Q、K就是用dot计算

(2)general:俗称乘法注意力机制

(3)concat:俗称加法注意力机制

当然文中也提到了location based的global attention不过,似乎(不太确定可能本人孤陋寡闻基本没用过)是历史遗留的产物,大家可以自行阅读。


3.2 local attention

此处,出于这篇文章的完整性,提及一下文章还提到了local attention。提出原因是global attention的计算量大。以机器翻译任务为例,如果原文和译文语种差异较大采用global attention,但如果句法上对应较好,可以采用local attention尝试。




4、pytorch实现Seq2Seq中dot型attention的注意力

这里实现一个3.1中dot型的注意力,输入为encoder的各层隐状态encoder_states以及当前的decoder隐状态decoder_state_t,输出为注意力加权后的上下文状态c

class Seq2SeqAttentionMechanism(nn.Module):def __init__(self):super(Seq2SeqAttentionMechanism, self).__init__()def forward(self, decoder_state_t, encoder_states):bs, source_length, hidden_size = encoder_states.shapedecoder_state_t = decoder_state_t.unsqueeze(1)decoder_state_t = torch.tile(decoder_state_t, dims = (1, source_length, 1))score = torch.sum(decoder_state_t * encoder_states, dim = -1) #[bs, source_length]attn_prob = F.softmax(score, dim = -1) #[bs, source_length]context = torch.sum(attn_prob.unsqueeze(-1) * encoder_states, 1) #[bs, hidden_size]return attn_prob, context


推荐阅读
author-avatar
wb91cmy
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有