热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Deeplearning4j手写体数字识别

2019独角兽企业重金招聘Python工程师标准最近这几年,深度学习很火,包括自己在内的很多对机器学习还是一知半解的小白也开始用深度学习做些应用。

2019独角兽企业重金招聘Python工程师标准>>> hot3.png

最近这几年,深度学习很火,包括自己在内的很多对机器学习还是一知半解的小白也开始用深度学习做些应用。由于小白的等级不高,算法自己写不出来,所以就用了开源库。Deep Learning的开源库有多,如果以语言来划分的话,就有Python系列的tensowflow,theano,keras,C/C++系列的Caffe,还有Lua系列的torch等等。但咱们公司是用Java为主,大部分项目最终也是做成一个Java Web的服务,所以我最终选择了Deeplearning4j。

    Deeplearning4j是国外创业公司Skymind的产品。目前最新的版本更新到了0.7.2。源码全部公开并托管在github上(https://github.com/deeplearning4j/deeplearning4j)。从这个库的名字上可以看出,它就是转为Java程序员写的Deep Learning库。其实这个库吸引人的地方不仅仅在于它支持Java,更为重要的是它可以支持Spark。由于Deep Learning模型的训练需要大量的内存,而且原始数据的存储有时候也需要很大的外存空间,所以如果可以利用集群来处理便是最好不过了。当然,除了Deeplearning4j以外,还有一些Deep Learning的库可以支持Spark,比如yahoo/CaffeOnSpark,AMPLab/SparkNet以及Intel最近开源的BigDL。这些库我自己都没怎么用过,所以就不多说了,这里重点说说Deeplearning4j的使用。

    一般开始使用别人的代码库,都会先跑一些demo,或者说Hello World的例子,就好像学习一门编程语言一样,第一行代码都是打印Hello World。Deep Learning的Hello World的例子一般是两个,一个是Mnist数据集的分类,另一个就是Word2Vec找相似词。由于Word2Vec并不是严格意义上的深度神经网络,因此这里就用Lenet网络处理Mnist数据集来作为Deep Learning的Hello World。Mnist是开源的28x28的黑白手写体数字图片集(http://yann.lecun.com/exdb/mnist/),其中包含6W张训练图片和1W张测试图片。至于Lenet的相关结构描述,可以参考这个链接:http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf。下面就详细讲述下,利用Deeplearning4j如何进行建模、训练和预测评估。

    首先,我们建立一个maven项目。然后在pom文件里加入Deeplearning4j的一些相关依赖。最主要的有三个:deeplearning4j-core,datavec,nd4j。deeplearning4j-core是神经网络结构实现的代码,nd4j是用于做张量运算的库,通过JavaCPP来调用编译好的C++库(可选:ATAL, MKL, 和OpenBLAS),datavec则主要负责数据的ETL。具体可见代码:

  UTF-8  0.7.1  0.7.1  0.7.1  2.10  
  
  
  org.nd4j  nd4j-native   ${nd4j.version}  
  
  org.deeplearning4j  dl4j-spark_2.11  ${dl4j.version}  
    org.datavec  datavec-spark_${scala.binary.version}  ${datavec.version}      org.deeplearning4j  deeplearning4j-core  ${dl4j.version}  
  
  

  1.     这些依赖里面有和Spark相关的,主要是跑Spark要用到。不过没有关系,先引进来即可。

    接着,我们解释下面的代码。我们先要定义一些具体的参数,比如分类的个数(outputNum),mini-batch的数量(batchSize)等等,具体在图中已经做了注释。需要说明的是MnistDataSetIterator这个迭代器类。这个类其实是一个读取二进制Mnist数据集的high-level的封装。通过debug我们可以发现,其中包括从网络中下载Mnist数据集,读取数据和标注,再构建迭代器的过程。在源码中,默认将下载的文件放在系统的user.home目录下,具体每个人不同会有所不同。由于我自己所处的环境网络不咋的,所以很有可能在利用这种high-level的接口的时候,因为下载Mnist数据失败而抛出异常,最终无法训练。所以,大家可以先自行下载好这些数据,然后按照源码的要求,放到相应的目录下并根据源码正确命名文件,那这样就依然可以利用这种high-level的接口。具体需要参考的是MnistDataFetcher类中相关代码。

int nChannels = 1; //black & white picture, 3 if color imageint outputNum = 10; //number of classificationint batchSize = 64; //mini batch size for sgdint nEpochs = 10; //total rounds of trainingint iterations = 1; //number of iteration in each traning roundint seed = 123; //random seed for initialize weightslog.info("Load data....");DataSetIterator mnistTrain = null;DataSetIterator mnistTest = null;mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

 

当我们正确读取数据后,我们需要定义具体的神经网络结构,这里我用的是Lenet,Deeplearning4j的实现参考了官网(https://github.com/deeplearning4j/dl4j-examples)的例子。具体代码如下:

MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations).regularization(true).l2(0.0005).learningRate(0.01)//.biasLearningRate(0.02)//.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75).weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS).momentum(0.9).list().layer(0, new ConvolutionLayer.Builder(5, 5)//nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied.nIn(nChannels).stride(1, 1).nOut(20).activation("identity").build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2).stride(2,2).build()).layer(2, new ConvolutionLayer.Builder(5, 5)//Note that nIn need not be specified in later layers.stride(1, 1).nOut(50).activation("identity").build()).layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2).stride(2,2).build()).layer(4, new DenseLayer.Builder().activation("relu").nOut(500).build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation("softmax").build()).backprop(true).pretrain(false).cnnInputSize(28, 28, 1);// The builder needs the dimensions of the image along with the number of channels. these are 28x28 images in one channel//new ConvolutionLayerSetup(builder,28,28,1);MultiLayerConfiguration conf = builder.build();MultiLayerNetwork model = new MultiLayerNetwork(conf);model.init(); model.setListeners(new ScoreIterationListener(1)); // a listener which can print loss function score after each iteration

可以发现,神经网络需要定义很多的超参数,学习率、正则化系数、卷积核的大小、激励函数等都是需要人为设定的。不同的超参数,对结果的影响很大,其实后来发现,很多时间都花在数据处理和调参方面。毕竟自己设计网络的能力有限,一般都是参考大牛的论文,然后自己照葫芦画瓢地实现。这里实现的Lenet的结构是:卷积-->下采样-->卷积-->下采样-->全连接。和原论文的结构基本一致。卷积核的大小也是参考的原论文。具体细节可参考之前发的论文链接。这里我们设置了一个Score的监听事件,主要是可以在训练的时候获取每一次权重更新后损失函数的收敛情况。后面一会有截图。

定义完网络结构之后,我们就可以对之前读取的数据进行训练和分类准确性评估。先看下代码:

for( int i &#61; 0; i < nEpochs; &#43;&#43;i ) {  model.fit(mnistTrain);  log.info("*** Completed epoch " &#43; i &#43; "***");  log.info("Evaluate model....");  Evaluation eval &#61; new Evaluation(outputNum);  while(mnistTest.hasNext()){  DataSet ds &#61; mnistTest.next();            INDArray output &#61; model.output(ds.getFeatureMatrix(), false);  eval.eval(ds.getLabels(), output);  }  log.info(eval.stats());  mnistTest.reset();  

    相信这部分是比较容易理解的。每训练完一轮后&#xff0c;我们会对测试集合进行评估&#xff0c;然后打印出类似下面的结果。图中的上半部分是具体分类的统计&#xff0c;包括分对的和分错的图片数量都可以看得到。然后&#xff0c;我们耐心等待一段时间&#xff0c;可以看到经过10轮训练的Lenet对于Mnist数据集的分类准确率达到99%如下&#xff1a;

Examples labeled as 0 classified by model as 0: 974 times
Examples labeled as 0 classified by model as 6: 2 times
Examples labeled as 0 classified by model as 7: 2 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 0 classified by model as 9: 1 times
Examples labeled as 1 classified by model as 0: 1 times
Examples labeled as 1 classified by model as 1: 1128 times
Examples labeled as 1 classified by model as 2: 1 times
Examples labeled as 1 classified by model as 3: 2 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 2 times
Examples labeled as 2 classified by model as 2: 1026 times
Examples labeled as 2 classified by model as 4: 1 times
Examples labeled as 2 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 7: 3 times
Examples labeled as 2 classified by model as 8: 1 times
Examples labeled as 3 classified by model as 0: 1 times
Examples labeled as 3 classified by model as 1: 1 times
Examples labeled as 3 classified by model as 2: 1 times
Examples labeled as 3 classified by model as 3: 998 times
Examples labeled as 3 classified by model as 5: 3 times
Examples labeled as 3 classified by model as 7: 1 times
Examples labeled as 3 classified by model as 8: 4 times
Examples labeled as 3 classified by model as 9: 1 times
Examples labeled as 4 classified by model as 2: 1 times
Examples labeled as 4 classified by model as 4: 973 times
Examples labeled as 4 classified by model as 6: 2 times
Examples labeled as 4 classified by model as 7: 1 times
Examples labeled as 4 classified by model as 9: 5 times
Examples labeled as 5 classified by model as 0: 2 times
Examples labeled as 5 classified by model as 3: 4 times
Examples labeled as 5 classified by model as 5: 882 times
Examples labeled as 5 classified by model as 6: 1 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 8: 2 times
Examples labeled as 6 classified by model as 0: 4 times
Examples labeled as 6 classified by model as 1: 2 times
Examples labeled as 6 classified by model as 4: 1 times
Examples labeled as 6 classified by model as 5: 4 times
Examples labeled as 6 classified by model as 6: 945 times
Examples labeled as 6 classified by model as 8: 2 times
Examples labeled as 7 classified by model as 1: 5 times
Examples labeled as 7 classified by model as 2: 3 times
Examples labeled as 7 classified by model as 3: 1 times
Examples labeled as 7 classified by model as 7: 1016 times
Examples labeled as 7 classified by model as 8: 1 times
Examples labeled as 7 classified by model as 9: 2 times
Examples labeled as 8 classified by model as 0: 1 times
Examples labeled as 8 classified by model as 3: 1 times
Examples labeled as 8 classified by model as 5: 2 times
Examples labeled as 8 classified by model as 7: 2 times
Examples labeled as 8 classified by model as 8: 966 times
Examples labeled as 8 classified by model as 9: 2 times
Examples labeled as 9 classified by model as 3: 1 times
Examples labeled as 9 classified by model as 4: 2 times
Examples labeled as 9 classified by model as 5: 4 times
Examples labeled as 9 classified by model as 6: 1 times
Examples labeled as 9 classified by model as 7: 5 times
Examples labeled as 9 classified by model as 8: 3 times
Examples labeled as 9 classified by model as 9: 993 times&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;Scores&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;Accuracy: 0.9901Precision: 0.99Recall: 0.99F1 Score: 0.99
&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;&#61;
[main] INFO cv.LenetMnistExample - ****************Example finished********************

    因为图传不上去&#xff0c;我就直接粘帖了结果。从中我们看到最终的一个准确率&#xff0c;还有就是哪些图片是分类正确的&#xff0c;哪些是分类错误的。当然我们可以通过增加训练的轮次还有调超参数来进一步优化&#xff0c;不过实际上这样的结果已经可以拿到生产上去用了。

    总结一下。其实包括我自己在内的很多人都对深度学习不了解&#xff0c;记得当时看csdn上写的有关深度学习的博客的时候&#xff0c;都觉得自己不可能达到那种水平。但其实&#xff0c;我们都忽略了一点&#xff0c;深度学习自身再复杂&#xff0c;它也是一个算法模型&#xff0c;也是一种机器学习。虽然它比感知机、逻辑回归等模型复杂很多&#xff08;其实逻辑回归可看作神经网络中的一个神经元&#xff0c;充当的是激励函数的作用&#xff0c;类似的激励函数很多&#xff0c;如tanh&#xff0c;relu等&#xff09;&#xff0c;但终究用它的目的依然是完成回归、分类、压缩数据等任务。所以第一步尝试还是挺重要的。当然&#xff0c;我们不可能从复杂的模型开始&#xff0c;一开始就跟上当下最流行的模型&#xff0c;所以就从Mnist识别的例子开始&#xff0c;找找感觉。以后会写一些用Deeplearning4j在Spark的案例&#xff0c;也还是从Mnist开始。分享的同时自己也复习一下。。。


转载于:https://my.oschina.net/u/2391658/blog/1507243


推荐阅读
  • Python入门后,想要从事自由职业可以做哪方面工作?1.爬虫很多人入门Python的必修课之一就是web开发和爬虫。但是这两项想要赚钱的话 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 本人学习笔记,知识点均摘自于网络,用于学习和交流(如未注明出处,请提醒,将及时更正,谢谢)OS:我学习是为了上 ... [详细]
  • 《Spark核心技术与高级应用》——1.2节Spark的重要扩展
    本节书摘来自华章社区《Spark核心技术与高级应用》一书中的第1章,第1.2节Spark的重要扩展,作者于俊向海代其锋马海平,更多章节内容可以访问云栖社区“华章社区”公众号查看1. ... [详细]
  • 2018年人工智能大数据的爆发,学Java还是Python?
    本文介绍了2018年人工智能大数据的爆发以及学习Java和Python的相关知识。在人工智能和大数据时代,Java和Python这两门编程语言都很优秀且火爆。选择学习哪门语言要根据个人兴趣爱好来决定。Python是一门拥有简洁语法的高级编程语言,容易上手。其特色之一是强制使用空白符作为语句缩进,使得新手可以快速上手。目前,Python在人工智能领域有着广泛的应用。如果对Java、Python或大数据感兴趣,欢迎加入qq群458345782。 ... [详细]
  • 本文介绍了Python高级网络编程及TCP/IP协议簇的OSI七层模型。首先简单介绍了七层模型的各层及其封装解封装过程。然后讨论了程序开发中涉及到的网络通信内容,主要包括TCP协议、UDP协议和IPV4协议。最后还介绍了socket编程、聊天socket实现、远程执行命令、上传文件、socketserver及其源码分析等相关内容。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • cs231n Lecture 3 线性分类笔记(一)
    内容列表线性分类器简介线性评分函数阐明线性分类器损失函数多类SVMSoftmax分类器SVM和Softmax的比较基于Web的可交互线性分类器原型小结注:中文翻译 ... [详细]
  • 人工智能推理能力与假设检验
    最近Google的Deepmind开始研究如何让AI做数学题。这个问题的提出非常有启发,逻辑推理,发现新知识的能力应该是强人工智能出现自我意识之前最需要发展的能力。深度学习目前可以 ... [详细]
  • 都会|可能会_###haohaohao###图神经网络之神器——PyTorch Geometric 上手 & 实战
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了###haohaohao###图神经网络之神器——PyTorchGeometric上手&实战相关的知识,希望对你有一定的参考价值。 ... [详细]
  • 图解redis的持久化存储机制RDB和AOF的原理和优缺点
    本文通过图解的方式介绍了redis的持久化存储机制RDB和AOF的原理和优缺点。RDB是将redis内存中的数据保存为快照文件,恢复速度较快但不支持拉链式快照。AOF是将操作日志保存到磁盘,实时存储数据但恢复速度较慢。文章详细分析了两种机制的优缺点,帮助读者更好地理解redis的持久化存储策略。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 数据结构与算法的重要性及基本概念、存储结构和算法分析
    数据结构与算法在编程领域中的重要性不可忽视,无论从事何种岗位,都需要掌握数据结构和算法。本文介绍了数据结构与算法的基本概念、存储结构和算法分析。其中包括线性结构、树结构、图结构、栈、队列、串、查找、排序等内容。此外,还介绍了图论算法、贪婪算法、分治算法、动态规划、随机化算法和回溯算法等高级数据结构和算法。掌握这些知识对于提高编程能力、解决问题具有重要意义。 ... [详细]
  • 在本教程中,我们将看到如何使用FLASK制作第一个用于机器学习模型的RESTAPI。我们将从创建机器学习模型开始。然后,我们将看到使用Flask创建AP ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
author-avatar
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有