热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

为SparkDeepLearning添加NLP处理实现

前言前段时间研究了SDL项目,看到了Spark的宏大愿景,写了篇Spark新愿景:让深度学习变得更加易于使用。后面看了TFoS,感觉很是巧妙,写了一篇TensorFlowOnSpa
前言

前段时间研究了SDL项目,看到了Spark的宏大愿景,写了篇Spark新愿景:让深度学习变得更加易于使用。后面看了TFoS,感觉很是巧妙,写了一篇TensorFlowOnSpark 源码解析。这些项目都得益于Spark对python的支持,所以了解了下spark和python如何进行交互的,可参看此文PySpark如何设置worker的python命令。

虽然非常看好SDL,但是它存在几个明显的问题:

  1. 进度慢的让人难以忍受。截止到目前为止,已经有26天没有新commit了。
  2. 只做了图像相关的工作,没有任何NLP相关的工具使用。 参看其他人提的这个Issue: What would it take to generalize to non-image data?
  3. 现有的分布式调参功能,基本不可用。参看我提的这个Issue: To Avoid collecting trainning data to driver and broadcasting them
  4. 不支持分布式tranning. 参看我提的这个Issue: Is there any plan to port TensorframeOnSpark(From yahoo) ?

当然SDL的想法非常好:

  1. 相比K8s + TF只是完成了分布式训练, SDL 把data process ,data training,data inference 三者给完全衔接了。
  2. 提供了一个很好的编程模型,以sk-learn/Mllib的方式完成模型的训练,对于工作效率提升明显。
  3. 分布式模型训练,分布式模型超参数tunning, 分别解决了训练数据量大的问题,参数探索的问题。

因为我司以NLP为主,所以我提供了一个deep learning auto-encoder的一个demo,展现SDL的能力。顺带通过引入Kafka解决了
“分布式模型超参数tunning”在实际场景不可用的问题。有时间会完成和TFoS的集成。

演示代码

我这里写了一个单元测试(python/tests/transformers/tf_text_test.py):

class TFTextTransformerTest(SparkDLTestCase):
def test_loadText(self):
input_col = "text"
output_col = "sentence_matrix"
documentDF = self.session.createDataFrame([
("Hi I heard about Spark", 1),
("I wish Java could use case classes", 0),
("Logistic regression models are neat", 2)
], ["text", "preds"])
# transform text column to sentence_matrix column which contains 2-D array.
transformer = TFTextTransformer(
inputCol=input_col, outputCol=output_col)
df = transformer.transform(documentDF)
# create a estimator to training where map_fun contains tensorflow's code
estimator = TFTextFileEstimator(inputCol="sentence_matrix", outputCol="sentence_matrix", labelCol="preds",
kafkaParam={"host": "127.0.0.1", "topic": "test", "group_id": "sdl_1"},
fitParam=[{"epochs": 5, "batch_size": 64}, {"epochs": 5, "batch_size": 1}],
mapFnParam=map_fun)
estimator.fit(df).collect()

TFTextTransformer 主要是把任意文本转化为一个二维矩阵,一行代表一个词汇,每个词汇都是word embedding的形态。该Transformer本质是做featurize的工作,2-D array 是能够直接被包括CNN,LSTM等算法操作的格式。 我这里简要介绍下TFTextTransformer的处理流程:

  1. 获取输入列,然后使用word2vec对数据进行训练,得到每个词的word embedding,最后作为一个map(word, vector) 广播出去
  2. 将input_col列的句子转化为一个2-D array作为outputCol
  3. 添加一些常数列到新的DataFrame里,比如vocab_size(词汇数目),embedding_size(词向量大小)。
  4. 返回新DataFrame

TFTextFileEstimator 完成训练过程,具体流程为:

  1. TFTextFileEstimator 将TFTextTransformer的每一条数据序列化后写入Kafka
  2. 根据fitParams (也就是你设置的超参数组合)长度,启动对应个数的tensorflow实例
  3. 为tensorflow实例从kafka拉去数据,并且提供一个_read_data函数句柄给tensorflow程序。
  4. 调用你编写的tf程序,完成训练。

额外引入kafka的原因是因为,每个tensorflow实例都需要消费全量的数据,一个简单的做法是把数据collect到driver端然后broadcast出去,但是实际上行不通,所以将数据集中放在kafka。

map_fun 是一个函数,这里你完全可以使用keras/tensorflow 构建模型,并且调用_read_data获取数据,以及通过args获得必要的参数,具体代码(python/sparkdl/tf_fun.py):

def map_fun(_read_data, **args):
import tensorflow as tf
EMBEDDING_SIZE = args["embedding_size"]
feature = args['feature']
label = args['label']
params = args['params']['fitParam']
SEQUENCE_LENGTH = 64
def feed_dict(batch):
# Convert from dict of named arrays to two numpy arrays of the proper type
features = []
for i in batch:
features.append(i['sentence_matrix'])
# print("{} {}".format(feature, features))
return features
encoder_variables_dict = {
"encoder_w1": tf.Variable(
tf.random_normal([SEQUENCE_LENGTH * EMBEDDING_SIZE, 256]), name="encoder_w1"),
"encoder_b1": tf.Variable(tf.random_normal([256]), name="encoder_b1"),
"encoder_w2": tf.Variable(tf.random_normal([256, 128]), name="encoder_w2"),
"encoder_b2": tf.Variable(tf.random_normal([128]), name="encoder_b2")
}

_read_data 可以获取spark dataframe的数据,典型用法如下:

for i in range(params.epochs):
print("epoll {}".format(i))
for data in _read_data(max_records=params.batch_size):
batch_data = feed_dict(data)
sess.run(train_step, feed_dict={input_x: batch_data})
sess.close()

这里,你核心关注如何构建网络,数据处理的工作前面的transformer已经帮你完成。

详细代码参看: https://github.com/allwefantasy/spark-deep-learning/tree/nlp-support


推荐阅读
  • 「爆干7天7夜」入门AI人工智能学习路线一条龙,真的不能再透彻了
    前言应广大粉丝要求,今天迪迦来和大家讲解一下如何去入门人工智能,也算是迪迦对自己学习人工智能这么多年的一个总结吧,本条学习路线并不会那么 ... [详细]
  • 本文介绍了Python语言程序设计中文件和数据格式化的操作,包括使用np.savetext保存文本文件,对文本文件和二进制文件进行统一的操作步骤,以及使用Numpy模块进行数据可视化编程的指南。同时还提供了一些关于Python的测试题。 ... [详细]
  • bat大牛带你深度剖析android 十大开源框架_请收好!5大领域,21个必知的机器学习开源工具...
    全文共3744字,预计学习时长7分钟本文将介绍21个你可能没使用过的机器学习开源工具。每个开源工具都为数据科学家处理数据库提供了不同角度。本文将重点介绍五种机器学习的 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 使用圣杯布局模式实现网站首页的内容布局
    本文介绍了使用圣杯布局模式实现网站首页的内容布局的方法,包括HTML部分代码和实例。同时还提供了公司新闻、最新产品、关于我们、联系我们等页面的布局示例。商品展示区包括了车里子和农家生态土鸡蛋等产品的价格信息。 ... [详细]
  • 本博文基于《Amalgamationofproteinsequence,structureandtextualinformationforimprovingprote ... [详细]
  • 人工智能推理能力与假设检验
    最近Google的Deepmind开始研究如何让AI做数学题。这个问题的提出非常有启发,逻辑推理,发现新知识的能力应该是强人工智能出现自我意识之前最需要发展的能力。深度学习目前可以 ... [详细]
  • 干货 | 携程AI推理性能的自动化优化实践
    作者简介携程度假AI研发团队致力于为携程旅游事业部提供丰富的AI技术产品,其中性能优化组为AI模型提供全方位的优化方案,提升推理性能降低成本࿰ ... [详细]
  • 软件测试工程师,需要达到什么水平才能顺利拿到 20k+ 无压力?
    前言最近看到很多应届生晒offer,稍有名气点的公司给出的价格都是一年30多W或者月薪20几k,相比之下工作几年的自己薪资确实很寒酸.根据我自己找工作经历,二线城市一般小公司招聘 ... [详细]
  • SLAM优秀开源工程最全汇总
    https:zhuanlan.zhihu.comp145750808 1、CartographerCartographer是一个系统,可跨多个平台和传感器配置以2D和3D形式提供实 ... [详细]
  • qt学习(六)数据库注册用户的实现方法
    本文介绍了在qt学习中实现数据库注册用户的方法,包括登录按钮按下后出现注册页面、账号可用性判断、密码格式判断、邮箱格式判断等步骤。具体实现过程包括UI设计、数据库的创建和各个模块调用数据内容。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • 计算机存储系统的层次结构及其优势
    本文介绍了计算机存储系统的层次结构,包括高速缓存、主存储器和辅助存储器三个层次。通过分层存储数据可以提高程序的执行效率。计算机存储系统的层次结构将各种不同存储容量、存取速度和价格的存储器有机组合成整体,形成可寻址存储空间比主存储器空间大得多的存储整体。由于辅助存储器容量大、价格低,使得整体存储系统的平均价格降低。同时,高速缓存的存取速度可以和CPU的工作速度相匹配,进一步提高程序执行效率。 ... [详细]
author-avatar
i_Screw_Robots
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有