热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于Tensorflow高阶API构建大规模分布式深度学习模型系列:开篇

Tensorflow高阶API简介在tensorflow高阶API(Estimator、Dataset、Layer、FeatureColumn等)问世之前,用tensorflow开

Tensorflow高阶API简介

在tensorflow高阶API(Estimator、Dataset、Layer、FeatureColumn等)问世之前,用tensorflow开发、训练、评估、部署深度学习模型,并没有统一的规范和高效的标准流程。Tensorflow的实践者们基于低阶API开发的代码在可移植性方面可能会遇到各种困难。例如,单机可以运行的模型希望改成能够分布式环境下运行需要对代码做额外的改动,如果在一个异构的环境中训练模型,则还需要额外花精力处理哪些部分跑在CPU上,哪些部分跑在GPU上。当不同的机器有不同数量的GPU数量时,问题更加复杂。

为了能够快速支持新的网络架构的实验测试,深度学习框架都很重视网络架构搭建的灵活性需求,因此能让用户随心所欲地自定义代码实现是很重要的一块功能。

模型构建的灵活性与简洁性需求看似是矛盾的。从开发者的视角,简洁性意味着当模型架构确定时实现不应该需要太多额外的技能要求,不必对深度学习框架有很深刻的洞察,就能够实验不同的模型特性。在内置简洁性属性的框架下开发者能够较轻松地开发出高质量的鲁棒性较好的模型软件,不会一不小心就踩到坑里。另一方面,灵活性意味着开发者能够实现任意的想要的模型结构,这需要框架能够提供一些相对低价的API。类似于Caffe这样的深度学习框架提供了DSL(domain specific language)来描述模型的结构,虽然搭建已知的成熟的模型架构比较方便,但却不能轻松搭建任意想要的模型结构。这就好比用积木搭建房子,如果现在需要一个特殊的以前没有出现过的积木块以便搭建一个特殊的房子,那就无计可施了。

Tensorflow高阶API正是为了同时满足模型构建的灵活性与简洁性需求应运而生的,它能够让开发者快速搭建出高质量的模型,又能够使用结合低阶API实现不受限制的模型结构。

下面就来看看tensorflow中有哪些常用的高阶API吧。

《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》
《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》 高阶API在tensorflow架构中的位置

1. Estimator(估算器)

Estimator类是机器学习模型的抽象,其设计灵感来自于典典大名的Python机器学习库Scikit-learn。Estimator允许开发者自定义任意的模型结构、损失函数、优化方法以及如何对这个模型进行训练、评估和导出等内容,同时屏蔽了与底层硬件设备、分布式网络数据传输等相关的细节。

《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》
《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》 Estimator接口

tf.estimator.Estimator(
model_fn=model_fn, # First-class function
params=params, # HParams
cOnfig=run_config # RunConfig
)

要创建Estimator,需要传入一个模型函数、一组参数和一些配置。

  • 传入的参数应该是模型超参数的一个集合,可以是一个dictionary。
  • 传入的配置用于指定模型如何运行训练和评估,以及在哪里存储结果。这个配置是一个RunConfig对象,该对象会把模型运行环境相关的信息告诉Estimator。
  • 模型函数是一个Python函数,它根据给定的输入构建模型。

Estimator类有三个主要的方法:train/fit、evaluate、predict,分别表示模型的训练、评估和预测。三个方法都接受一个用户自定义的输入函数input_fn,执行input_fn获取输入数据。Estimator的这三个方法最终都会调用模型函数(model_fn)执行具体的操作,不同方法被调用时,传递给model_fn的mode参数也是不同的,如下一小节中描述的那样,mode参数是让用户在编写模型函数时知道当前定义的操作是用在模型生命周期的哪一个阶段。

Tensorflow本身还提供了很多内置的开箱即用的Estimator,内置的 Estimator 是 tf.estimator.Estimator 基类的子类,而自定义 Estimator 是 tf.estimator.Estimator 的实例,如下图所示。

《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》
《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》 预创建的 Estimator 和自定义 Estimator 都是 Estimator

2. 模型函数

模型函数是用户自定义的一个python函数,它定义了模型训练、评估和预测所需的计算图节点(op)。

模型函数接受输入特征和标签作为参数,同时用mode参数来告知用户模型是在训练、评估或是在执行推理。mode是tf.estimator.ModeKeys对象,它有三个可取的值:TRAIN、EVAL、PREDICT。模型函数的最后一个参数是超参数集合,它们与传递给Estimator的超参数集合相同。模型函数返回一个EstimatorSpec对象,该对象定义了一个完整的模型。EstimatorSpec对象用于对操作进行预测、损失、训练和评估,因此,它定义了一个用于训练、评估和推理的完整的模型图。

一个简单的模型函数示例如下:

def model_fn(features, target, mode, params)
predictiOns= tf.stack(tf.fully_connected, [50, 50, 1])
loss = tf.losses.mean_squared_error(target, predictions)
train_op = tf.train.create_train_op(
loss, tf.train.get_global_step(),
params[’learning_rate’], params[’optimizer’])
return EstimatorSpec(mode=mode,
predictiOns=predictions,
loss=loss,
train_op=train_op)

3. Dataset(数据集)

在tensorflow中,构建模型输入流水线的最佳实践就是使用Dataset API。Dataset API的性能很好,底层使用C++实现,能够绕过python的一些性能限制。

Dataset是对训练、评估、预测阶段所用的数据的抽象表示,其提供了数据读取、解析、打乱(shuffle)、过滤、分批(batch)等操作,是构建模型输入管道的利器,我将会在另外一篇文章《基于Tensorflow高阶API构建大规模分布式深度学习模型系列:基于Dataset API处理Input pipeline》中详细介绍。

4. Feature Columns(特征列)

Feature Columns是特征工程的利器,其能够方便地把原始数据转换为模型的输入数据,并提供了一些常用的数据变换操作,如特征交叉、one-hot编码、embedding编码等。关于Feature Column,也将会在另外一篇文章中详细介绍。

5. Layers

Layer是一组简单的可重复利用的代码,表示神经网络模型中的“层”这个概念。Tensorflow中的layer可以认为是一系列操作(op)的集合,与op一样也是输入tensor并输出tensor的(tensor-in-tensor-out)。Tensorflow中即内置了全连接这样的简单layer,也有像inception网络那样的复杂layer。使用layers来搭建网络模型会更加方便。

6. Head

Head API对网络最后一个隐藏层之后的部分进行了抽象,它的主要设计目标是简化模型函数(model_fn)的编写。Head知道如何计算损失(loss)、评估度量标准(metric)、预测结果(prediction)。为了支持不同的模型,Head接受logits和labels作为参数,并生成表示loss、metric和prediction的张量。有时为了避免计算完整的logit张量,Head也接受最后一个隐藏的激活值作为输入。

一个使用Head简化model_fn编写的例子如下:

def model_fn(features, target, mode, params):
last_layer = tf.stack(tf.fully_connected, [50, 50])
head = tf.multi_class_head(n_classes=10)
return head.create_estimator_spec(
features, mode, last_layer,
label=target,
train_op_fn=lambda loss: my_optimizer.minimize(loss, tf.train.get_global_step())

我们也可以用一个Heads列表来创建一个特殊类型的Head,来完成多目标学习的任务,如下面的例子那样。

def model_fn(features, target, mode, params):
last_layer = tf.stack(tf.fully_connected, [50, 50])
head1 = tf.multi_class_head(n_classes=2,label_name=’y’, head_name=’h1’)
head2 = tf.multi_class_head(n_classes=10,label_name=’z’, head_name=’h2’)
head = tf.multi_head([head1, head2])
return head.create_model_fn_ops(features,
features, mode, last_layer,
label=target,
train_op_fn=lambda loss: my_optimizer.minimize(loss, tf.train.get_global_step())

总结

Tensorflow高阶API简化了模型代码的编写过程,大大降价了新手的入门门槛,使我们能够用一种标准化的方法开发出实验与生产环境部署的代码。使用Tensorflow高阶API能够使我们避免走很多弯路,提高深度学习的实践效率,我们应该尽可能使用高阶API来开发模型。

参考资料

  • TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks
  • 自定义estimators
  • Feature columns

推荐阅读
  • 本文探讨了MariaDB在当前数据库市场中的地位和挑战,分析其可能面临的困境,并提出了对未来发展的几点看法。 ... [详细]
  • 探讨如何真正掌握Java EE,包括所需技能、工具和实践经验。资深软件教学总监李刚分享了对毕业生简历中常见问题的看法,并提供了详尽的标准。 ... [详细]
  • 2018年3月31日,CSDN、火星财经联合中关村区块链产业联盟等机构举办的2018区块链技术及应用峰会(BTA)核心分会场圆满举行。多位业内顶尖专家深入探讨了区块链的核心技术原理及其在实际业务中的应用。 ... [详细]
  • 本文详细介绍了网络存储技术的基本概念、分类及应用场景。通过分析直连式存储(DAS)、网络附加存储(NAS)和存储区域网络(SAN)的特点,帮助读者理解不同存储方式的优势与局限性。 ... [详细]
  • 本文作者分享了在阿里巴巴获得实习offer的经历,包括五轮面试的详细内容和经验总结。其中四轮为技术面试,一轮为HR面试,涵盖了大量的Java技术和项目实践经验。 ... [详细]
  • 数据库内核开发入门 | 搭建研发环境的初步指南
    本课程将带你从零开始,逐步掌握数据库内核开发的基础知识和实践技能,重点介绍如何搭建OceanBase的开发环境。 ... [详细]
  • 探讨一个显示数字的故障计算器,它支持两种操作:将当前数字乘以2或减去1。本文将详细介绍如何用最少的操作次数将初始值X转换为目标值Y。 ... [详细]
  • 本文详细介绍了Java编程语言中的核心概念和常见面试问题,包括集合类、数据结构、线程处理、Java虚拟机(JVM)、HTTP协议以及Git操作等方面的内容。通过深入分析每个主题,帮助读者更好地理解Java的关键特性和最佳实践。 ... [详细]
  • 本文探讨了如何在给定整数N的情况下,找到两个不同的整数a和b,使得它们的和最大,并且满足特定的数学条件。 ... [详细]
  • MySQL缓存机制深度解析
    本文详细探讨了MySQL的缓存机制,包括主从复制、读写分离以及缓存同步策略等内容。通过理解这些概念和技术,读者可以更好地优化数据库性能。 ... [详细]
  • 使用Python在SAE上开发新浪微博应用的初步探索
    最近重新审视了新浪云平台(SAE)提供的服务,发现其已支持Python开发。本文将详细介绍如何利用Django框架构建一个简单的新浪微博应用,并分享开发过程中的关键步骤。 ... [详细]
  • 本文详细探讨了VxWorks操作系统中双向链表和环形缓冲区的实现原理及使用方法,通过具体示例代码加深理解。 ... [详细]
  • Hadoop入门与核心组件详解
    本文详细介绍了Hadoop的基础知识及其核心组件,包括HDFS、MapReduce和YARN。通过本文,读者可以全面了解Hadoop的生态系统及应用场景。 ... [详细]
  • 本文探讨了如何在日常工作中通过优化效率和深入研究核心技术,将技术和知识转化为实际收益。文章结合个人经验,分享了提高工作效率、掌握高价值技能以及选择合适工作环境的方法,帮助读者更好地实现技术变现。 ... [详细]
  • 本文探讨了2012年4月期间,淘宝在技术架构上的关键数据和发展历程。涵盖了从早期PHP到Java的转型,以及在分布式计算、存储和网络流量管理方面的创新。 ... [详细]
author-avatar
静静敲代码
很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有