作者:wb91cmy | 来源:互联网 | 2023-08-16 17:57
1、seq2seqattention的论文出处主要是阅读经典论文《NeuralMachineTranslationbyJointlyLearningtoAlignandTrans
1、seq2seq attention的论文出处
主要是阅读经典论文《Neural Machine Translation by Jointly Learning to Align and Translate》https://arxiv.org/pdf/1409.0473.pdf
这篇论文中机器翻译采用seq2seq的encoder-decoder模型构建,将输入句子encoding成一个固定长度的向量,然后输入到decoder解码生成译文。
attention机制引入要解决的问题
encoder得到的固定向量直接送入decoder,不利于
(1)较长句子的信息传递 (2)提取decoder中目标单词需要关注原文中的语义信息
2、seq2seq decoder的构成
直接截取一下论文中的讲法吧,如下图:
式子(4)的意思就是解码器生成出来的每个token,由
1、就是前一个token的特征向量;
2、 就是当前的隐状态,这里的 也是通过RNN得到的,可以把和看成这一步RNN的输入特征向量;
3、 就是上下文的特征,没有用attention之前就是encoder得到的固定向量。现在用了attention机制这篇文章中经典的attention机制就是(5)式得到的。
关于 的计算中,简单来说就是encoder中各层隐状态 与decoder的隐状态做相似度(可以用一个MLP来实现等,下文再具体说明各种算法),然后求softmax得到权重系数,然后做加权和。当然这只是一种做法,实际怎么做可以有别的做法。
3、seq2seq attention的经典形式
Effective Approaches to Attention-based Neural Machine Translation 论文中,将注意力机制大致分为了全局(global)注意力和局部(local)注意力。
全局注意力指的是注意力分布在所有encoder得到隐状态中;局部注意力指注意力只存在于一些隐状态中。例如,图像的average pooling可视为全局注意力;max pooling可视为局部注意力
论文链接:https://arxiv.org/abs/1508.04025
3.1 global attention
直接上论文中的内容好了,这边global attention的意思和上一篇论文的一样,只是论文中用的符号不一样,应该一看就能知道,实际一个意思。
这里decoder的target hidden和encoder中的source hidden的score怎么算呢?文中介绍三种算法:
这里把score称为基于内容的函数,原因是这里的计算只考虑了隐状态间内容的相关性,并不包含时序信息。
(1)dot:transformer中的Q、K就是用dot计算
(2)general:俗称乘法注意力机制
(3)concat:俗称加法注意力机制
当然文中也提到了location based的global attention不过,似乎(不太确定可能本人孤陋寡闻基本没用过)是历史遗留的产物,大家可以自行阅读。
3.2 local attention
此处,出于这篇文章的完整性,提及一下文章还提到了local attention。提出原因是global attention的计算量大。以机器翻译任务为例,如果原文和译文语种差异较大采用global attention,但如果句法上对应较好,可以采用local attention尝试。
4、pytorch实现Seq2Seq中dot型attention的注意力
这里实现一个3.1中dot型的注意力,输入为encoder的各层隐状态encoder_states以及当前的decoder隐状态decoder_state_t,输出为注意力加权后的上下文状态c
class Seq2SeqAttentionMechanism(nn.Module):def __init__(self):super(Seq2SeqAttentionMechanism, self).__init__()def forward(self, decoder_state_t, encoder_states):bs, source_length, hidden_size = encoder_states.shapedecoder_state_t = decoder_state_t.unsqueeze(1)decoder_state_t = torch.tile(decoder_state_t, dims = (1, source_length, 1))score = torch.sum(decoder_state_t * encoder_states, dim = -1) #[bs, source_length]attn_prob = F.softmax(score, dim = -1) #[bs, source_length]context = torch.sum(attn_prob.unsqueeze(-1) * encoder_states, 1) #[bs, hidden_size]return attn_prob, context