热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Tensorflow模型优化训练思路

问题现状随着深度学习模型越来越大,数据集越来越大,模型的训练变得越来越慢。这对于想要快速验证算法的研究人员来说,是个比较麻烦的问题。那

问题现状

随着深度学习模型越来越大,数据集越来越大,模型的训练变得越来越慢。这对于想要快速验证算法的研究人员来说,是个比较麻烦的问题。

那么一般来说,我们会想要优化模型训练,以期更快验证模型效果。

无论是使用Tensorflow还是Pytorch来搭建模型,基本的训练优化思路都是一致的,只是由于框架的不同,某些优化细节有些差别。

这里探讨的是模型在GPU上的训练优化。


基本的训练优化思路


  • Step 1 优化模型在单机单卡上的训练
  • Step 2 优化模型在单机多卡上的训练
  • Step 3 优化模型在多级多卡上的训练

Tensorflow模型的训练优化

针对基本思路,我们发掘一些优化细节。


Step 1 优化模型在单机单卡上的训练

首先,我们需要先将模型在单卡上的训练进行足够的优化,再去谈论扩展GPU数量来提升训练速度。

将数据预处理放在CPU上可以显著提高性能,这样可以让GPU专注训练,使用nvidia-smi来查看GPU的利用率是否达到80%~100%。

对于tensorflow来说,做到以下几点基本可以立马加速模型训练:


  1. 使用Pinned memory
  2. 打开AMP 
  3. 打开XLA (input size需要是固定的)
  4. 使用LAMB作为optmizer(而不用ADAM)
  5. 使用TF32 
  6. 在没有显存溢出的情况下,尝试更大的batch size
  7. 融合op,以减少D2D/H2D/D2H的数据传输 (使用nsight system来profile模型,查看训练瓶颈)
  8. 在GPU 上使用 cuDNN 时,NCHW 数据格式是最优选择。最佳实践是构建同时支持:NCHW/NHWC。
  9. prefetch预取数据 (tf.data API 通过 tf.data.Dataset.prefetch 转换提供了一种软件流水线机制)
  10. Parallel data extraction:(tf.Dataset.interleave(cycle_length=, num_parallel_calls=), cycle_length=多个文件的重合的长度,num_parallel_calls=并行读取的文件数量)

另外,如果是BERT等以Transformer为基础结构的模型,则将模型以Fast transformer来搭建,可以得到更好的训练性能。


Step 2 优化模型在单机多卡上的训练

Horovod是Multi_GPU/Multi_Node训练的首选。


  • 利用Horovod进行Multi_GPU训练。
  • 多卡训练中,使用tf.data API来提供流数据,而不是使用feed_dict.

Step 3 优化模型在多机多卡上的训练

利用Horovod进行Multi_node训练。这部分其实更多的是结合实际问题来进行有针对性的优化。笔者在这方面暂时没有经验,以后希望能更新这部分内容。

 

参考资料

https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT#model-overview

https://developer.nvidia.com/blog/fast-multi-gpu-collectives-nccl/

https://tensorflow.juejin.im/performance/performance_guide.html

NVTX:https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvtx_analysis.htm

XLA:https://tensorflow.juejin.im/performance/xla/index.html

Nsight System:https://developer.nvidia.com/nsight-systems

https://www.cnblogs.com/huangyc/p/10340766.html

https://zhuanlan.zhihu.com/p/163656225

 

 


推荐阅读
author-avatar
曾巧红--------
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有