热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【序列到序列学习】使用ScheduledSampling改善翻译质量

导语PaddlePaddle提供了丰富的运算单元,帮助大家以模块化的方式构建起千变万化的深度学习模型来解决不同的应用问题。这里,我们针对常见的机器学习任

导语

PaddlePaddle提供了丰富的运算单元,帮助大家以模块化的方式构建起千变万化的深度学习模型来解决不同的应用问题。这里,我们针对常见的机器学习任务,提供了不同的神经网络模型供大家学习和使用。本周推文目录如下:

3.12:【命名实体识别】

训练端到端的序列标注模型

3.13:【序列到序列学习】

无注意力机制的神经机器翻译

3.14:【序列到序列学习】

使用Scheduled Sampling改善翻译质量

3.15:【序列到序列学习】

带外部记忆机制的神经机器翻译

3.16:【序列到序列学习】

 生成古诗词

序列到序列学习实现两个甚至是多个不定长模型之间的映射,有着广泛的应用,包括:机器翻译、智能对话与问答、广告创意语料生成、自动编码(如金融画像编码)、判断多个文本串之间的语义相关性等。

在序列到序列学习任务中,我们首先以机器翻译任务为例,提供了多种改进模型供大家学习和使用。包括:不带注意力机制的序列到序列映射模型,这一模型是所有序列到序列学习模型的基础;使用Scheduled Sampling改善RNN模型在生成任务中的错误累积问题;带外部记忆机制的神经机器翻译,通过增强神经网络的记忆能力,来完成复杂的序列到序列学习任务。除机器翻译任务之外,我们也提供了一个基于深层LSTM网络生成古诗词,实现同语言生成的模型。

 【序列到序列学习】 

02

使用Scheduled Sampling

改善翻译质量

|1. 概述

序列生成任务的生成目标是在给定源输入的条件下,最大化目标序列的概率。训练时该模型将目标序列中的真实元素作为解码器每一步的输入,然后最大化下一个元素的概率。生成时上一步解码得到的元素被用作当前的输入,然后生成下一个元素。可见这种情况下训练阶段和生成阶段的解码器输入数据的概率分布并不一致。

Scheduled Sampling [1]是一种解决训练和生成时输入数据分布不一致的方法。在训练早期该方法主要使用目标序列中的真实元素作为解码器输入,可以将模型从随机初始化的状态快速引导至一个合理的状态。随着训练的进行,该方法会逐渐更多地使用生成的元素作为解码器输入,以解决数据分布不一致的问题。

标准的序列到序列模型中,如果序列前面生成了错误的元素,后面的输入状态将会收到影响,而该误差会随着生成过程不断向后累积。Scheduled Sampling以一定概率将生成的元素作为解码器输入,这样即使前面生成错误,其训练目标仍然是最大化真实目标序列的概率,模型会朝着正确的方向进行训练。因此这种方式增加了模型的容错能力

|2. 算法简介

Scheduled Sampling主要应用在序列到序列模型的训练阶段,而生成阶段则不需要使用。

训练阶段解码器在最大化第t个元素概率时,标准序列到序列模型使用上一时刻的真实元素yt−1作为输入。设上一时刻生成的元素为gt−1,Scheduled Sampling算法会以一定概率使用gt−1作为解码器输入。

设当前已经训练到了第i个mini-batch,Scheduled Sampling定义了一个概率ϵi控制解码器的输入。ϵi是一个随着i增大而衰减的变量,常见的定义方式有:

  • 线性衰减:ϵi=max(ϵ,k−c∗i),其中ϵ限制ϵi的最小值,k和c控制线性衰减的幅度。

  • 指数衰减:ϵi=ki,其中0

  • 反向Sigmoid衰减:ϵi=k/(k+exp(i/k)),其中k>1,k同样控制衰减的幅度。

图1给出了这三种方式的衰减曲线,

图1. 线性衰减、指数衰减和

反向Sigmoid衰减的衰减曲线

如图2所示,在解码器的t时刻Scheduled Sampling以概率ϵi使用上一时刻的真实元素yt−1作为解码器输入,以概率1−ϵi使用上一时刻生成的元素gt−1作为解码器输入。从图1可知随着i的增大ϵi会不断减小,解码器将不断倾向于使用生成的元素作为输入,训练阶段和生成阶段的数据分布将变得越来越一致。

图2. Scheduled Sampling选择不同元素作为解码器输入示意图

|3. 模型实现

由于Scheduled Sampling是对序列到序列模型的改进,其整体实现框架与序列到序列模型较为相似。为突出本文重点,这里仅介绍与Scheduled Sampling相关的部分,完整的代码见network_conf.py。

首先导入需要的包,并定义控制衰减概率的类RandomScheduleGenerator,如下:

import numpy as np

import math

class RandomScheduleGenerator:

    """

    The random sampling rate for scheduled sampling algoithm, which uses devcayed

    sampling rate.

    """

    ...

下面将分别定义类RandomScheduleGenerator的__init__、getScheduleRate和processBatch三个方法。

__init__方法对类进行初始化,其schedule_type参数指定了使用哪种衰减方式,可选的方式有constant、linear、exponential和inverse_sigmoid。constant指对所有的mini-batch使用固定的ϵi,linear指线性衰减方式,exponential表示指数衰减方式,inverse_sigmoid表示反向Sigmoid衰减。__init__方法的参数a和b表示衰减方法的参数,需要在验证集上调优。self.schedule_computers将衰减方式映射为计算ϵi的函数。最后一行根据schedule_type将选择的衰减函数赋给self.schedule_computer变量。

def __init__(self, schedule_type, a, b):

    """

    schduled_type: is the type of the decay. It supports constant, linear,

    exponential, and inverse_sigmoid right now.

    a: parameter of the decay (MUST BE DOUBLE)

    b: parameter of the decay (MUST BE DOUBLE)

    """

    self.schedule_type = schedule_type

    self.a = a

    self.b = b

    self.data_processed_ = 0

    self.schedule_computers = {

        "constant": lambda a, b, d: a,

        "linear": lambda a, b, d: max(a, 1 - d / b),

        "exponential": lambda a, b, d: pow(a, d / b),

        "inverse_sigmoid": lambda a, b, d: b / (b + math.exp(d * a / b)),

    }

    assert (self.schedule_type in self.schedule_computers)

    self.schedule_computer = self.schedule_computers[self.schedule_type]

getScheduleRate根据衰减函数和已经处理的数据量计算ϵi。

def getScheduleRate(self):
   """    Get the schedule sampling rate. Usually not needed to be called by the users    """
   return self.schedule_computer(self.a, self.b, self.data_processed_)

processBatch方法根据概率值ϵi进行采样,得到indexes,indexes中每个元素取值为0的概率为ϵi,取值为1的概率为1−ϵi。indexes决定了解码器的输入是真实元素还是生成的元素,取值为0表示使用真实元素,取值为1表示使用生成的元素。

def processBatch(self, batch_size):

    """

    Get a batch_size of sampled indexes. These indexes can be passed to a

    MultiplexLayer to select from the grouth truth and generated samples

    from the last time step.

    """

    rate = self.getScheduleRate()

    numbers = np.random.rand(batch_size)

    indexes = (numbers >= rate).astype('int32').tolist()

    self.data_processed_ += batch_size

    return indexes

Scheduled Sampling需要在序列到序列模型的基础上增加一个输入true_token_flag,以控制解码器输入。

true_token_flags = paddle.layer.data(

    name='true_token_flag',

type=paddle.data_type.integer_value_sequence(2))

这里还需要对原始reader进行封装,增加true_token_flag的数据生成器。下面以线性衰减为例说明如何调用上面定义的RandomScheduleGenerator产生true_token_flag的输入数据。

def gen_schedule_data(reader,

                      schedule_type="linear",

                      decay_a=0.75,

                      decay_b=1000000):

    """

    Creates a data reader for scheduled sampling.

    Output from the iterator that created by original reader will be

    appended with "true_token_flag" to indicate whether to use true token.

    :param reader: the original reader.

    :type reader: callable

    :param schedule_type: the type of sampling rate decay.

    :type schedule_type: str

    :param decay_a: the decay parameter a.

    :type decay_a: float

    :param decay_b: the decay parameter b.

    :type decay_b: float

    :return: the new reader with the field "true_token_flag".

    :rtype: callable

    """

    schedule_generator = RandomScheduleGenerator(schedule_type, decay_a, decay_b)

    def data_reader():

        for src_ids, trg_ids, trg_ids_next in reader():

            yield src_ids, trg_ids, trg_ids_next, \

                  [0] + schedule_generator.processBatch(len(trg_ids) - 1)

    return data_reader

这段代码在原始输入数据(即源序列元素src_ids、目标序列元素trg_ids和目标序列下一个元素trg_ids_next)后追加了控制解码器输入的数据。由于解码器第一个元素是序列开始符,因此将追加的数据第一个元素设置为0,表示解码器第一步始终使用真实目标序列的第一个元素(即序列开始符)。

训练时recurrent_group每一步调用的解码器函数如下:

def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word,

                                       true_token_flag):

      """

      The decoder step for training.

      :param enc_vec: the encoder vector for attention

      :type enc_vec: LayerOutput

      :param enc_proj: the encoder projection for attention

      :type enc_proj: LayerOutput

      :param true_word: the ground-truth target word

      :type true_word: LayerOutput

      :param true_token_flag: the flag of using the ground-truth target word

      :type true_token_flag: LayerOutput

      :return: the softmax output layer

      :rtype: LayerOutput

      """

      decoder_mem = paddle.layer.memory(

          name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)

      context = paddle.networks.simple_attention(

          encoded_sequence=enc_vec,

          encoded_proj=enc_proj,

          decoder_state=decoder_mem)

      gru_out_memory = paddle.layer.memory(

          name='gru_out', size=target_dict_dim)

      generated_word = paddle.layer.max_id(input=gru_out_memory)

      generated_word_emb = paddle.layer.embedding(

          input=generated_word,

          size=word_vector_dim,

          param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))

      current_word = paddle.layer.multiplex(

          input=[true_token_flag, true_word, generated_word_emb])

      decoder_inputs = paddle.layer.fc(

          input=[context, current_word],

          size=decoder_size * 3,

          act=paddle.activation.Linear(),

          bias_attr=False)

      gru_step = paddle.layer.gru_step(

          name='gru_decoder',

          input=decoder_inputs,

          output_mem=decoder_mem,

          size=decoder_size)

      out = paddle.layer.fc(

          name='gru_out',

          input=gru_step,

          size=target_dict_dim,

          act=paddle.activation.Softmax())

      return out

该函数使用memory层gru_out_memory记忆上一时刻生成的元素,根据gru_out_memory选择概率最大的词语generated_word作为生成的词语。multiplex层会在真实元素true_word和生成的元素generated_word之间做出选择,并将选择的结果作为解码器输入。multiplex层使用了三个输入,分别为true_token_flag、true_word和generated_word_emb。对于这三个输入中每个元素,若true_token_flag中的值为0,则multiplex层输出true_word中的相应元素;若true_token_flag中的值为1,则multiplex层输出generated_word_emb中的相应元素。

【参考文献】

  1. Bengio S, Vinyals O, Jaitly N, et al. Scheduled sampling for sequence prediction with recurrent neural networks//Advances in Neural Information Processing Systems. 2015: 1171-1179.

今 日 AI 资 讯 

(如欲了解详情,在后台回复当日日期数字,例如“314”!)

1.普华永道发布人工智能报告,对人工智能在2018年的发展做出了8项预测。(AI视点)

2.上交大发布知识图谱AceKG,超1亿实体,近100G数据。(新智元)

3.第四范式业界首推免费智能客服服务。(机器之心)

 end

*原创贴,版权所有,未经许可,禁止转载

*值班小Paddle:wangp

*欢迎在留言区分享您的观点

*为了方便大家问题的跟进解决,我们采用Github Issue来采集信息和追踪进度。大家遇到问题请搜索Github Issue,问题未解决请优先在Github Issue上提问,有助于问题的积累和沉淀

点击“阅读原文”,访问Github Issue。


推荐阅读
  • 本文详细介绍了如何在Linux系统中搭建51单片机的开发与编程环境,重点讲解了使用Makefile进行项目管理的方法。首先,文章指导读者安装SDCC(Small Device C Compiler),这是一个专为小型设备设计的C语言编译器,适合用于51单片机的开发。随后,通过具体的实例演示了如何配置Makefile文件,以实现代码的自动化编译与链接过程,从而提高开发效率。此外,还提供了常见问题的解决方案及优化建议,帮助开发者快速上手并解决实际开发中可能遇到的技术难题。 ... [详细]
  • 在Unity中进行3D建模的全面指南,详细介绍了市场上三种主要的3D建模工具:Blender 3D、Maya和3ds Max。每种工具的特点、优势及其在Unity开发中的应用将被深入探讨,帮助开发者选择最适合自己的建模软件。 ... [详细]
  • MySQL性能优化与调参指南【数据库管理】
    本文详细探讨了MySQL数据库的性能优化与参数调整技巧,旨在帮助数据库管理员和开发人员提升系统的运行效率。内容涵盖索引优化、查询优化、配置参数调整等方面,结合实际案例进行深入分析,提供实用的操作建议。此外,还介绍了常见的性能监控工具和方法,助力读者全面掌握MySQL性能优化的核心技能。 ... [详细]
  • Go语言实现Redis客户端与服务器的交互机制深入解析
    在前文对Godis v1.0版本的基础功能进行了详细介绍后,本文将重点探讨如何实现客户端与服务器之间的交互机制。通过具体代码实现,使客户端与服务器能够顺利通信,赋予项目实际运行的能力。本文将详细解析Go语言在实现这一过程中的关键技术和实现细节,帮助读者深入了解Redis客户端与服务器的交互原理。 ... [详细]
  • Prim算法在处理稠密图时表现出色,尤其适用于边数远多于顶点数的情形。传统实现的时间复杂度为 \(O(n^2)\),但通过引入优先队列进行优化,可以在点数为 \(m\)、边数为 \(n\) 的情况下显著降低时间复杂度,提高算法效率。这种优化方法不仅能够加速最小生成树的构建过程,还能在大规模数据集上保持良好的性能表现。 ... [详细]
  • 在进行网络编程时,准确获取本地主机的IP地址是一项基本但重要的任务。Winsock作为20世纪90年代初由Microsoft与多家公司共同制定的Windows平台网络编程接口,为开发者提供了一套高效且易用的工具。通过Winsock,开发者可以轻松实现网络通信功能,并准确获取本地主机的IP地址,从而确保应用程序在网络环境中的稳定运行。此外,了解Winsock的工作原理及其API函数的使用方法,有助于提高开发效率和代码质量。 ... [详细]
  • 本文详细介绍了使用响应文件在静默模式下安装和配置Oracle 11g的方法。硬件要求包括:内存至少1GB,具体可通过命令`grep -i memtotal /proc/meminfo`进行检查。此外,还提供了详细的步骤和注意事项,确保安装过程顺利进行。 ... [详细]
  • 在稀疏直接法视觉里程计中,通过优化特征点并采用基于光度误差最小化的灰度图像线性插值技术,提高了定位精度。该方法通过对空间点的非齐次和齐次表示进行处理,利用RGB-D传感器获取的3D坐标信息,在两帧图像之间实现精确匹配,有效减少了光度误差,提升了系统的鲁棒性和稳定性。 ... [详细]
  • 本文探讨了在Android应用中实现动态滚动文本显示控件的优化方法。通过详细分析焦点管理机制,特别是通过设置返回值为`true`来确保焦点不会被其他控件抢占,从而提升滚动文本的流畅性和用户体验。具体实现中,对`MarqueeText.java`进行了代码层面的优化,增强了控件的稳定性和兼容性。 ... [详细]
  • 本文详细解析了 MySQL 5.7.20 版本中二进制日志(binlog)崩溃恢复机制的工作流程。假设使用 InnoDB 存储引擎,并且启用了 `sync_binlog=1` 配置,文章深入探讨了在系统崩溃后如何通过 binlog 进行数据恢复,确保数据的一致性和完整性。 ... [详细]
  • 深入解析 Vue.js 的设计与实现:第三章详解
    在《深入解析 Vue.js 的设计与实现》第三章中,详细探讨了 Vue.js 渲染器与虚拟 DOM 的机制。通过 JavaScript 对象来模拟实际的 DOM 结构,例如,`const vNode = { tag: 'div', props: { ... } }`,这种方式不仅提高了性能,还增强了组件的可维护性和灵活性。本章进一步分析了虚拟 DOM 的创建、更新及优化策略,为开发者提供了深入了解 Vue.js 内核工作的视角。 ... [详细]
  • 根据不同环境需求,利用 Vue CLI 的 `npm run build` 命令对项目进行定制化打包,如测试、预发布和生产环境。通过配置 `process.env` 变量,实现不同环境下接口和服务的动态切换,确保应用在各阶段都能高效运行和调试。 ... [详细]
  • 开发心得:利用 Redis 构建分布式系统的轻量级协调机制
    开发心得:利用 Redis 构建分布式系统的轻量级协调机制 ... [详细]
  • 在Windows命令行中,通过Conda工具可以高效地管理和操作虚拟环境。具体步骤包括:1. 列出现有虚拟环境:`conda env list`;2. 创建新虚拟环境:`conda create --name 环境名`;3. 删除虚拟环境:`conda env remove --name 环境名`。这些命令不仅简化了环境管理流程,还提高了开发效率。此外,Conda还支持环境文件导出和导入,方便在不同机器间迁移配置。 ... [详细]
  • 在Python 3环境中,当无法连接互联网时,可以通过下载离线模块包来实现模块的安装。具体步骤包括:首先从PyPI网站下载所需的模块包,然后将其传输到目标环境,并使用`pip install`命令进行本地安装。此方法不仅适用于单个模块,还支持依赖项的批量安装,确保开发环境的完整性和一致性。 ... [详细]
author-avatar
洛特大人_382
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有