热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于Tensorflow高阶API构建大规模分布式深度学习模型系列:开篇

Tensorflow高阶API简介在tensorflow高阶API(Estimator、Dataset、Layer、FeatureColumn等)问世之前,用tensorflow开

Tensorflow高阶API简介

在tensorflow高阶API(Estimator、Dataset、Layer、FeatureColumn等)问世之前,用tensorflow开发、训练、评估、部署深度学习模型,并没有统一的规范和高效的标准流程。Tensorflow的实践者们基于低阶API开发的代码在可移植性方面可能会遇到各种困难。例如,单机可以运行的模型希望改成能够分布式环境下运行需要对代码做额外的改动,如果在一个异构的环境中训练模型,则还需要额外花精力处理哪些部分跑在CPU上,哪些部分跑在GPU上。当不同的机器有不同数量的GPU数量时,问题更加复杂。

为了能够快速支持新的网络架构的实验测试,深度学习框架都很重视网络架构搭建的灵活性需求,因此能让用户随心所欲地自定义代码实现是很重要的一块功能。

模型构建的灵活性与简洁性需求看似是矛盾的。从开发者的视角,简洁性意味着当模型架构确定时实现不应该需要太多额外的技能要求,不必对深度学习框架有很深刻的洞察,就能够实验不同的模型特性。在内置简洁性属性的框架下开发者能够较轻松地开发出高质量的鲁棒性较好的模型软件,不会一不小心就踩到坑里。另一方面,灵活性意味着开发者能够实现任意的想要的模型结构,这需要框架能够提供一些相对低价的API。类似于Caffe这样的深度学习框架提供了DSL(domain specific language)来描述模型的结构,虽然搭建已知的成熟的模型架构比较方便,但却不能轻松搭建任意想要的模型结构。这就好比用积木搭建房子,如果现在需要一个特殊的以前没有出现过的积木块以便搭建一个特殊的房子,那就无计可施了。

Tensorflow高阶API正是为了同时满足模型构建的灵活性与简洁性需求应运而生的,它能够让开发者快速搭建出高质量的模型,又能够使用结合低阶API实现不受限制的模型结构。

下面就来看看tensorflow中有哪些常用的高阶API吧。

《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》
《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》 高阶API在tensorflow架构中的位置

1. Estimator(估算器)

Estimator类是机器学习模型的抽象,其设计灵感来自于典典大名的Python机器学习库Scikit-learn。Estimator允许开发者自定义任意的模型结构、损失函数、优化方法以及如何对这个模型进行训练、评估和导出等内容,同时屏蔽了与底层硬件设备、分布式网络数据传输等相关的细节。

《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》
《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》 Estimator接口

tf.estimator.Estimator(
model_fn=model_fn, # First-class function
params=params, # HParams
cOnfig=run_config # RunConfig
)

要创建Estimator,需要传入一个模型函数、一组参数和一些配置。

  • 传入的参数应该是模型超参数的一个集合,可以是一个dictionary。
  • 传入的配置用于指定模型如何运行训练和评估,以及在哪里存储结果。这个配置是一个RunConfig对象,该对象会把模型运行环境相关的信息告诉Estimator。
  • 模型函数是一个Python函数,它根据给定的输入构建模型。

Estimator类有三个主要的方法:train/fit、evaluate、predict,分别表示模型的训练、评估和预测。三个方法都接受一个用户自定义的输入函数input_fn,执行input_fn获取输入数据。Estimator的这三个方法最终都会调用模型函数(model_fn)执行具体的操作,不同方法被调用时,传递给model_fn的mode参数也是不同的,如下一小节中描述的那样,mode参数是让用户在编写模型函数时知道当前定义的操作是用在模型生命周期的哪一个阶段。

Tensorflow本身还提供了很多内置的开箱即用的Estimator,内置的 Estimator 是 tf.estimator.Estimator 基类的子类,而自定义 Estimator 是 tf.estimator.Estimator 的实例,如下图所示。

《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》
《基于Tensorflow高阶API构建大规模分布式深度学习模型系列: 开篇》 预创建的 Estimator 和自定义 Estimator 都是 Estimator

2. 模型函数

模型函数是用户自定义的一个python函数,它定义了模型训练、评估和预测所需的计算图节点(op)。

模型函数接受输入特征和标签作为参数,同时用mode参数来告知用户模型是在训练、评估或是在执行推理。mode是tf.estimator.ModeKeys对象,它有三个可取的值:TRAIN、EVAL、PREDICT。模型函数的最后一个参数是超参数集合,它们与传递给Estimator的超参数集合相同。模型函数返回一个EstimatorSpec对象,该对象定义了一个完整的模型。EstimatorSpec对象用于对操作进行预测、损失、训练和评估,因此,它定义了一个用于训练、评估和推理的完整的模型图。

一个简单的模型函数示例如下:

def model_fn(features, target, mode, params)
predictiOns= tf.stack(tf.fully_connected, [50, 50, 1])
loss = tf.losses.mean_squared_error(target, predictions)
train_op = tf.train.create_train_op(
loss, tf.train.get_global_step(),
params[’learning_rate’], params[’optimizer’])
return EstimatorSpec(mode=mode,
predictiOns=predictions,
loss=loss,
train_op=train_op)

3. Dataset(数据集)

在tensorflow中,构建模型输入流水线的最佳实践就是使用Dataset API。Dataset API的性能很好,底层使用C++实现,能够绕过python的一些性能限制。

Dataset是对训练、评估、预测阶段所用的数据的抽象表示,其提供了数据读取、解析、打乱(shuffle)、过滤、分批(batch)等操作,是构建模型输入管道的利器,我将会在另外一篇文章《基于Tensorflow高阶API构建大规模分布式深度学习模型系列:基于Dataset API处理Input pipeline》中详细介绍。

4. Feature Columns(特征列)

Feature Columns是特征工程的利器,其能够方便地把原始数据转换为模型的输入数据,并提供了一些常用的数据变换操作,如特征交叉、one-hot编码、embedding编码等。关于Feature Column,也将会在另外一篇文章中详细介绍。

5. Layers

Layer是一组简单的可重复利用的代码,表示神经网络模型中的“层”这个概念。Tensorflow中的layer可以认为是一系列操作(op)的集合,与op一样也是输入tensor并输出tensor的(tensor-in-tensor-out)。Tensorflow中即内置了全连接这样的简单layer,也有像inception网络那样的复杂layer。使用layers来搭建网络模型会更加方便。

6. Head

Head API对网络最后一个隐藏层之后的部分进行了抽象,它的主要设计目标是简化模型函数(model_fn)的编写。Head知道如何计算损失(loss)、评估度量标准(metric)、预测结果(prediction)。为了支持不同的模型,Head接受logits和labels作为参数,并生成表示loss、metric和prediction的张量。有时为了避免计算完整的logit张量,Head也接受最后一个隐藏的激活值作为输入。

一个使用Head简化model_fn编写的例子如下:

def model_fn(features, target, mode, params):
last_layer = tf.stack(tf.fully_connected, [50, 50])
head = tf.multi_class_head(n_classes=10)
return head.create_estimator_spec(
features, mode, last_layer,
label=target,
train_op_fn=lambda loss: my_optimizer.minimize(loss, tf.train.get_global_step())

我们也可以用一个Heads列表来创建一个特殊类型的Head,来完成多目标学习的任务,如下面的例子那样。

def model_fn(features, target, mode, params):
last_layer = tf.stack(tf.fully_connected, [50, 50])
head1 = tf.multi_class_head(n_classes=2,label_name=’y’, head_name=’h1’)
head2 = tf.multi_class_head(n_classes=10,label_name=’z’, head_name=’h2’)
head = tf.multi_head([head1, head2])
return head.create_model_fn_ops(features,
features, mode, last_layer,
label=target,
train_op_fn=lambda loss: my_optimizer.minimize(loss, tf.train.get_global_step())

总结

Tensorflow高阶API简化了模型代码的编写过程,大大降价了新手的入门门槛,使我们能够用一种标准化的方法开发出实验与生产环境部署的代码。使用Tensorflow高阶API能够使我们避免走很多弯路,提高深度学习的实践效率,我们应该尽可能使用高阶API来开发模型。

参考资料

  • TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks
  • 自定义estimators
  • Feature columns

推荐阅读
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 能够感知你情绪状态的智能机器人即将问世 | 科技前沿观察
    本周科技前沿报道了多项重要进展,包括美国多所高校在机器人技术和自动驾驶领域的最新研究成果,以及硅谷大型企业在智能硬件和深度学习技术上的突破性进展。特别值得一提的是,一款能够感知用户情绪状态的智能机器人即将问世,为未来的人机交互带来了全新的可能性。 ... [详细]
  • 兆芯X86 CPU架构的演进与现状(国产CPU系列)
    本文详细介绍了兆芯X86 CPU架构的发展历程,从公司成立背景到关键技术授权,再到具体芯片架构的演进,全面解析了兆芯在国产CPU领域的贡献与挑战。 ... [详细]
  • Python 数据可视化实战指南
    本文详细介绍如何使用 Python 进行数据可视化,涵盖从环境搭建到具体实例的全过程。 ... [详细]
  • 从0到1搭建大数据平台
    从0到1搭建大数据平台 ... [详细]
  • 在Windows系统中安装TensorFlow GPU版的详细指南与常见问题解决
    在Windows系统中安装TensorFlow GPU版是许多深度学习初学者面临的挑战。本文详细介绍了安装过程中的每一个步骤,并针对常见的问题提供了有效的解决方案。通过本文的指导,读者可以顺利地完成安装并避免常见的陷阱。 ... [详细]
  • 独家解析:深度学习泛化理论的破解之道与应用前景
    本文深入探讨了深度学习泛化理论的关键问题,通过分析现有研究和实践经验,揭示了泛化性能背后的核心机制。文章详细解析了泛化能力的影响因素,并提出了改进模型泛化性能的有效策略。此外,还展望了这些理论在实际应用中的广阔前景,为未来的研究和开发提供了宝贵的参考。 ... [详细]
  • 利用ZFS和Gluster实现分布式存储系统的高效迁移与应用
    本文探讨了在Ubuntu 18.04系统中利用ZFS和Gluster文件系统实现分布式存储系统的高效迁移与应用。通过详细的技术分析和实践案例,展示了这两种文件系统在数据迁移、高可用性和性能优化方面的优势,为分布式存储系统的部署和管理提供了宝贵的参考。 ... [详细]
  • 本文介绍了如何使用 Google Colab 的免费 GPU 资源进行深度学习应用开发。Google Colab 是一个无需配置即可使用的云端 Jupyter 笔记本环境,支持多种深度学习框架,并且提供免费的 GPU 计算资源。 ... [详细]
  • 本文介绍如何使用OpenCV和线性支持向量机(SVM)模型来开发一个简单的人脸识别系统,特别关注在只有一个用户数据集时的处理方法。 ... [详细]
  • 最详尽的4K技术科普
    什么是4K?4K是一个分辨率的范畴,即40962160的像素分辨率,一般用于专业设备居多,目前家庭用的设备,如 ... [详细]
  • 在2019中国国际智能产业博览会上,百度董事长兼CEO李彦宏强调,人工智能应务实推进其在各行业的应用。随后,在“ABC SUMMIT 2019百度云智峰会”上,百度展示了通过“云+AI”推动AI工业化和产业智能化的最新成果。 ... [详细]
  • 在机器学习领域,深入探讨了概率论与数理统计的基础知识,特别是这些理论在数据挖掘中的应用。文章重点分析了偏差(Bias)与方差(Variance)之间的平衡问题,强调了方差反映了不同训练模型之间的差异,例如在K折交叉验证中,不同模型之间的性能差异显著。此外,还讨论了如何通过优化模型选择和参数调整来有效控制这一平衡,以提高模型的泛化能力。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • 摩尔线程新款国产显卡曝光:8GB显存,性能超越GTX 660,售价预计超千元 ... [详细]
author-avatar
静静敲代码
很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有