热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

如何利用深度学习模型实现多任务学习?这里有三点经验

 

 

文章转载自机器之心,有做多任务学习相关的同学可以加群709165514讨论学习下,

 

Taboola 算法开发者 Zohar Komarovsky 介绍了他们在利用深度学习模型实现多任务学习(MTL)时遇到的几个典型问题及解决方案。

在过去的一年里,我和我的团队一直致力于为 Taboola feed 提供个性化用户体验。我们运用多任务学习(Multi-Task Learning,MTL),在相同的输入特征集上预测多个关键性能指标(Key Performance Indicator,KPI),然后使用 TensorFlow 实现深度学习模型。回想最初的时候,我们感觉(上手)MTL 比现在要困难很多,所以我希望在这里分享一些经验总结。

现在已经有很多关于利用深度学习模型实现 MTL 的文章。在本文中,我准备分享一些利用神经网络实现 MTL 时需要考虑的具体问题,同时也会展示一些基于 TensorFlow 的简单解决方案。

共享即关怀

我们准备从参数硬共享(hard parameter sharing)的基础方法开始。硬共享表示我们使用一个共享的子网络,下接各个任务特定的子网络。

如何利用深度学习模型实现多任务学习?这里有三点经验

在 TensorFlow 中,实现这样一个模型的简单方法是使用带有 multi_head 的 Estimator。这个模型和其他神经网络架构相比没什么不同,所以你可以自己想想,有哪些可能出错的地方?

第一点:整合损失

我们的 MTL 模型所遇到的第一个挑战是为多个任务定义一个损失函数。既然每个任务都有一个定义良好的损失函数,那么多任务就会有多个损失。

我们尝试的第一个方法是将不同损失简单相加。很快我们就发现,虽然某一个任务会收敛得到不错的结果,其他的却表现很差。进一步研究后,可以很容易地明白原因:不同任务损失的尺度差异非常大,导致整体损失被某一个任务所主导,最终导致其他任务的损失无法影响网络共享层的学习过程。

一个简单的解决方案是,将损失简单相加替换为加权和,以使所有任务损失的尺度接近。但是,这引入了另一个可能需要不时进行调节的超参数。

幸运的是,我们发现了一篇非常棒的论文《Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics》,提出引入不确定性来确定 MTL 中损失的权重:在每个任务的损失函数中学习另一个噪声参数(noise parameter)。此方法可以接受多任务(可以是回归和分类),并统一所有损失的尺度。这样,我们就能像一开始那样,直接相加得到总损失了。

与损失加权求和相比,该方法不仅得到了更好的结果,而且还可以不再理会额外的权重超参数。论文作者提供的 Keras 实现参见:https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example.ipynb。

第二点:调节学习速率

调节神经网络有一个通用约定:学习速率是最重要的超参数之一。所以我们尝试调节学习速率。我们发现,对于某一个任务 A 而言,存在一个特别适合的学习速率,而对于另一个任务 B,则有不同的适合学习速率。如果选择较高的学习速率,可能在某个任务上出现神经元死亡(由于大的负梯度,导致 Relu 函数永久关闭,即 dying ReLU),而使用较低的学习速率,则会导致其他任务收敛缓慢。应该怎么做呢?我们可以在各个「头部」(见上图,即各任务的子网络)分别调节各自的学习速率,而在共享网络部分,使用另一个学习速率。

虽然听上去很复杂,但其实非常简单。通常,在利用 TensorFlow 训练神经网络时,使用的是:

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)

AdamOptimizer 定义如何应用梯度,而 minimize 则完成具体的计算和应用。我们可以将 minimize 替换为我们自己的实现方案,在应用梯度时,为计算图中的各变量使用各自适合的学习速率。

all_variables = shared_vars + a_vars + b_vars
all_gradients = tf.gradients(loss, all_variables)

shared_subnet_gradients = all_gradients[:len(shared_vars)]
a_gradients = all_gradients[len(shared_vars):len(shared_vars + a_vars)]
b_gradients = all_gradients[len(shared_vars + a_vars):]

shared_subnet_optimizer = tf.train.AdamOptimizer(shared_learning_rate)
a_optimizer = tf.train.AdamOptimizer(a_learning_rate)
b_optimizer = tf.train.AdamOptimizer(b_learning_rate)

train_shared_op = shared_subnet_optimizer.apply_gradients(zip(shared_subnet_gradients, shared_vars))
train_a_op = a_optimizer.apply_gradients(zip(a_gradients, a_vars))
train_b_op = b_optimizer.apply_gradients(zip(b_gradients, b_vars))

train_op = tf.group(train_shared_op, train_a_op, train_b_op)    

友情提醒:这个技巧其实在单任务网络中也很实用。

第三点:将估计作为特征

当我们完成了第一阶段的工作,为预测多任务创建好神经网络后,我们可能希望将某一个任务得到的估计(estimate)作为另一个任务的特征。在前向传递(forward-pass)中,这非常简单。估计是一个张量,可以像任意一个神经层的输出一样进行传递。但在反向传播中呢?

假设将任务 A 的估计作为特征输入给 B,我们可能并不希望将梯度从任务 B 传回任务 A,因为我们已经有了任务 A 的标签。对此不用担心,TensorFlow 的 API 所提供的 tf.stop_gradient 会有所帮助。在计算梯度时,它允许你传入一个希望作为常数的张量列表,这正是我们所需要的。

all_gradients = tf.gradients(loss, all_variables, stop_gradients=stop_tensors)    

和上面一样,这个方法对 MTL 网络很有用,但不止如此。该技术可用在任何你希望利用 TensorFlow 计算某个值并将其作为常数的场景。例如,在训练生成对抗网络(Generative Adversarial Network,GAN)时,你不希望将对抗示例反向传播到生成过程中。

下一步

我们的模型已经上线运行,Taboola feed 也已经是个性化的了。但是,还有很多可以提升改进的空间,以及许多有趣的结构可以探索。在我们的应用场景中,预测多任务也意味着基于多个 KPI 完成决策。这可能比基于单个 KPI 的更复杂……不过这就是另一个全新的问题了。

 

原文链接:https://engineering.taboola.com/deep-multi-task-learning-3-lessons-learned/


推荐阅读
  • 解决Cydia数据库错误:could not open file /var/lib/dpkg/status 的方法
    本文介绍了解决iOS系统中Cydia数据库错误的方法。通过使用苹果电脑上的Impactor工具和NewTerm软件,以及ifunbox工具和终端命令,可以解决该问题。具体步骤包括下载所需工具、连接手机到电脑、安装NewTerm、下载ifunbox并注册Dropbox账号、下载并解压lib.zip文件、将lib文件夹拖入Books文件夹中,并将lib文件夹拷贝到/var/目录下。以上方法适用于已经越狱且出现Cydia数据库错误的iPhone手机。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • 2018年人工智能大数据的爆发,学Java还是Python?
    本文介绍了2018年人工智能大数据的爆发以及学习Java和Python的相关知识。在人工智能和大数据时代,Java和Python这两门编程语言都很优秀且火爆。选择学习哪门语言要根据个人兴趣爱好来决定。Python是一门拥有简洁语法的高级编程语言,容易上手。其特色之一是强制使用空白符作为语句缩进,使得新手可以快速上手。目前,Python在人工智能领域有着广泛的应用。如果对Java、Python或大数据感兴趣,欢迎加入qq群458345782。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • Android中高级面试必知必会,积累总结
    本文介绍了Android中高级面试的必知必会内容,并总结了相关经验。文章指出,如今的Android市场对开发人员的要求更高,需要更专业的人才。同时,文章还给出了针对Android岗位的职责和要求,并提供了简历突出的建议。 ... [详细]
  • GPT-3发布,动动手指就能自动生成代码的神器来了!
    近日,OpenAI发布了最新的NLP模型GPT-3,该模型在GitHub趋势榜上名列前茅。GPT-3使用的数据集容量达到45TB,参数个数高达1750亿,训练好的模型需要700G的硬盘空间来存储。一位开发者根据GPT-3模型上线了一个名为debuid的网站,用户只需用英语描述需求,前端代码就能自动生成。这个神奇的功能让许多程序员感到惊讶。去年,OpenAI在与世界冠军OG战队的表演赛中展示了他们的强化学习模型,在限定条件下以2:0完胜人类冠军。 ... [详细]
  • OCR:用字符识别方法将形状翻译成计算机文字的过程Matlab:商业数学软件;CUDA:CUDA™是一种由NVIDIA推 ... [详细]
  • 本文由编程笔记#小编为大家整理,主要介绍了logistic回归(线性和非线性)相关的知识,包括线性logistic回归的代码和数据集的分布情况。希望对你有一定的参考价值。 ... [详细]
  • 本文介绍了PhysioNet网站提供的生理信号处理工具箱WFDB Toolbox for Matlab的安装和使用方法。通过下载并添加到Matlab路径中或直接在Matlab中输入相关内容,即可完成安装。该工具箱提供了一系列函数,可以方便地处理生理信号数据。详细的安装和使用方法可以参考本文内容。 ... [详细]
  • 开发笔记:计网局域网:NAT 是如何工作的?
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了计网-局域网:NAT是如何工作的?相关的知识,希望对你有一定的参考价值。 ... [详细]
  • 本文介绍了RxJava在Android开发中的广泛应用以及其在事件总线(Event Bus)实现中的使用方法。RxJava是一种基于观察者模式的异步java库,可以提高开发效率、降低维护成本。通过RxJava,开发者可以实现事件的异步处理和链式操作。对于已经具备RxJava基础的开发者来说,本文将详细介绍如何利用RxJava实现事件总线,并提供了使用建议。 ... [详细]
  • macOS Big Sur全新设计大版本更新,10+个值得关注的新功能
    本文介绍了Apple发布的新一代操作系统macOS Big Sur,该系统采用全新的界面设计,包括图标、应用界面、程序坞和菜单栏等方面的变化。新系统还增加了通知中心、桌面小组件、强化的Safari浏览器以及隐私保护等多项功能。文章指出,macOS Big Sur的设计与iPadOS越来越接近,结合了去年iPadOS对鼠标的完善等功能。 ... [详细]
  • 安装Tensorflow-GPU文档第一步:通过Anaconda安装python从这个链接https:www.anaconda.comdownload#window ... [详细]
author-avatar
tengfei2008
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有