热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TensorflowEstimator做迁移学习(TransferLearning)

TensorflowEstimator做迁移学习(TransferLearning)-在TensorFlow官方ResNet模型实现分析中我们分析了基于Estimator的模型实现

TensorFlow官方ResNet模型实现分析中我们分析了基于Estimator的模型实现与运行的基本方法。除此之外,这份源码还提供了神经网络中常用的一种手段——迁移学习(Transfer Learning)的实现。

迁移学习

取决于具体的任务,从零开始训练一个深度神经网络有时需要海量的数据才能得到较好的效果。如果你手头的数据有限,又想采用神经网络作为解决方案,可以尝试一下迁移学习。

举一个例子:你负责维护工厂的一条自动化生产线,在传送带上有10种不同的零件随机经过。工业照相机可以逐一捕捉完整的零件图像,但是需要你来根据零件类型调整后续的机械手动作。现在可用于训练的零件图像非常有限,而你手头正好有一个使用大量数据训练好的ImageNet图像分类神经网络模型。如何充分利用这两点是一个典型的迁移学习应用场景。

迁移学习迁移了什么

深度神经网络的结构存在层级。对于卷积神经网络CNN来说,不同层级的卷积层所表现出的特征提取也呈现层级性。具体来说,底层的卷积层对于低阶特征较为敏感,例如边缘、团块等;随着层级的升高,提取的特征越来越抽象。这种随层级变化的特征提取能力是迁移学习的基础。它保证了当任务具备相似性时,例如分类1024种不同的自然物体与分类10种不同的零件,已经训练好的神经网络的特征提取层可以“迁移”到新的分类任务中来继续承担特征提取的功能。

迁移学习的具体的做法

常用的做法包括:

  1. “冻结”特征提取部分。
  2. 使用新数据训练末端负责输出分类的若干全连接层。

TensorFlow如何实现

官方的ResNet模型实现提供了迁移学习的功能。只需要指定--pretrained_model_checkpoint_path--fine_tune 这两个flag就可以实现。

具体到代码中,首先在载入模型时要跳过最终的dense层。

if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')

参数vars_to_warm_start采用正则表达式的方式过滤掉了最后的全连接层。

然后在根据梯度更新参数时,过滤掉不需要更新的部分。

grad_vars = optimizer.compute_gradients(loss)
      if fine_tune:
        grad_vars = _dense_grad_filter(grad_vars)
      minimize_op = optimizer.apply_gradients(grad_vars, global_step)

这里的_dense_grad_filter的实现如下:

def _dense_grad_filter(gvs):
      """Only apply gradient updates to the final layer.

      This function is used for fine tuning.

      Args:
        gvs: list of tuples with gradients and variable info
      Returns:
        filtered gradients so that only the dense layer remains
      """
      return [(g, v) for g, v in gvs if 'dense' in v.name]

这种实现方法是根据node的name属性来实现的。所以在改造网络的时候,注意自己添加的node name不要与之冲突。

参考

迁移学习用于图像识别的Tensorflow实现

https://yinguobing.com/tensorflow-transfer-learning/

tensorflow estimator 使用hook实现finetune

https://github.com/tensorflow/tensorflow/issues/10155

https://medium.com/@utsumuki_neko/using-inception-v3-from-tensorflow-hub-for-transfer-learning-a931ff884526

https://github.com/tensorflow/tensorflow/issues/14713

https://stackoverflow.com/questions/46423956/load-checkpoint-and-finetuning-using-tf-estimator-estimator

TensorFlow如何实现Transfer Learning

TensorFlow 迁移学习识花实战案例

基于Tensorflow高阶API构建大规模分布式深度学习模型系列之自定义Estimator(以文本分类CNN模型为例)


推荐阅读
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 目录实现效果:实现环境实现方法一:基本思路主要代码JavaScript代码总结方法二主要代码总结方法三基本思路主要代码JavaScriptHTML总结实 ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • JavaSE笔试题-接口、抽象类、多态等问题解答
    本文解答了JavaSE笔试题中关于接口、抽象类、多态等问题。包括Math类的取整数方法、接口是否可继承、抽象类是否可实现接口、抽象类是否可继承具体类、抽象类中是否可以有静态main方法等问题。同时介绍了面向对象的特征,以及Java中实现多态的机制。 ... [详细]
  • Android系统源码分析Zygote和SystemServer启动过程详解
    本文详细解析了Android系统源码中Zygote和SystemServer的启动过程。首先介绍了系统framework层启动的内容,帮助理解四大组件的启动和管理过程。接着介绍了AMS、PMS等系统服务的作用和调用方式。然后详细分析了Zygote的启动过程,解释了Zygote在Android启动过程中的决定作用。最后通过时序图展示了整个过程。 ... [详细]
  • 本文介绍了C++中省略号类型和参数个数不确定函数参数的使用方法,并提供了一个范例。通过宏定义的方式,可以方便地处理不定参数的情况。文章中给出了具体的代码实现,并对代码进行了解释和说明。这对于需要处理不定参数的情况的程序员来说,是一个很有用的参考资料。 ... [详细]
  • 闭包一直是Java社区中争论不断的话题,很多语言都支持闭包这个语言特性,闭包定义了一个依赖于外部环境的自由变量的函数,这个函数能够访问外部环境的变量。本文以JavaScript的一个闭包为例,介绍了闭包的定义和特性。 ... [详细]
  • JDK源码学习之HashTable(附带面试题)的学习笔记
    本文介绍了JDK源码学习之HashTable(附带面试题)的学习笔记,包括HashTable的定义、数据类型、与HashMap的关系和区别。文章提供了干货,并附带了其他相关主题的学习笔记。 ... [详细]
  • Week04面向对象设计与继承学习总结及作业要求
    本文总结了Week04面向对象设计与继承的重要知识点,包括对象、类、封装性、静态属性、静态方法、重载、继承和多态等。同时,还介绍了私有构造函数在类外部无法被调用、static不能访问非静态属性以及该类实例可以共享类里的static属性等内容。此外,还提到了作业要求,包括讲述一个在网上商城购物或在班级博客进行学习的故事,并使用Markdown的加粗标记和语句块标记标注关键名词和动词。最后,还提到了参考资料中关于UML类图如何绘制的范例。 ... [详细]
  • 基于Socket的多个客户端之间的聊天功能实现方法
    本文介绍了基于Socket的多个客户端之间实现聊天功能的方法,包括服务器端的实现和客户端的实现。服务器端通过每个用户的输出流向特定用户发送消息,而客户端通过输入流接收消息。同时,还介绍了相关的实体类和Socket的基本概念。 ... [详细]
  • 重入锁(ReentrantLock)学习及实现原理
    本文介绍了重入锁(ReentrantLock)的学习及实现原理。在学习synchronized的基础上,重入锁提供了更多的灵活性和功能。文章详细介绍了重入锁的特性、使用方法和实现原理,并提供了类图和测试代码供读者参考。重入锁支持重入和公平与非公平两种实现方式,通过对比和分析,读者可以更好地理解和应用重入锁。 ... [详细]
  • 本文概述了JNI的原理以及常用方法。JNI提供了一种Java字节码调用C/C++的解决方案,但引用类型不能直接在Native层使用,需要进行类型转化。多维数组(包括二维数组)都是引用类型,需要使用jobjectArray类型来存取其值。此外,由于Java支持函数重载,根据函数名无法找到对应的JNI函数,因此介绍了JNI函数签名信息的解决方案。 ... [详细]
  • 本文介绍了利用ARMA模型对平稳非白噪声序列进行建模的步骤及代码实现。首先对观察值序列进行样本自相关系数和样本偏自相关系数的计算,然后根据这些系数的性质选择适当的ARMA模型进行拟合,并估计模型中的位置参数。接着进行模型的有效性检验,如果不通过则重新选择模型再拟合,如果通过则进行模型优化。最后利用拟合模型预测序列的未来走势。文章还介绍了绘制时序图、平稳性检验、白噪声检验、确定ARMA阶数和预测未来走势的代码实现。 ... [详细]
  • x86 linux的进程调度,x86体系结构下Linux2.6.26的进程调度和切换
    进程调度相关数据结构task_structtask_struct是进程在内核中对应的数据结构,它标识了进程的状态等各项信息。其中有一项thread_struct结构的 ... [详细]
author-avatar
QEWERTGF_978
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有