热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

R︱SoftmaxRegression建模(MNIST手写体识别和文档多分类应用)

本文转载自经管之家论坛,R语言中的SoftmaxRegression建模(MNIST手写体识别和文档多分类应用)R中的softmaxreg包,发自20

本文转载自经管之家论坛, R语言中的Softmax Regression建模 (MNIST 手写体识别和文档多分类应用)

R中的softmaxreg包,发自2016-09-09,链接:https://cran.r-project.org/web/packages/softmaxreg/index.html


——————————————————————————————————————————————————————————————————


一、介绍

       Softmax Regression模型本质还是一个多分类模型,对Logistic Regression 逻辑回归的拓展。如果将Softmax Regression模型和神经网络隐含层结合起来,可以进一步提升模型的性能,构成包含多个隐含层和最后一个Softmax层的多层神经网络模型。之前发现R里面没有特别适合的方法支持多层的Softmax 模型,于是就想直接用R语言写一个softmaxreg 包。可以支持大部分的多分类问题,其中的两个示例:MNIST手写体识别和多文档分类(Multi-Class DocumentClassification) 的文档如下



二、示例文档

2.1 MNIST手写体识别数据集

MNIST手写体识别的数据集是图像识别领域一个基本数据集,很多模型诸如CNN卷积神经网络等模型都经常在这个数据集上测试都能够达到97%以上的准确率。 这里想比较一下包含隐含层的softmaxreg模型,测试结果显示模型的准确率能达到93% 左右。


Part1、下载和Load数据

      MNIST手写体识别的数据集可以直接从网站下载http://yann.lecun.com/exdb/mnist/,一共四个文件,分别下载下来并解压。文件格式比较特殊,可以用softmaxreg 包中的load_image_file 和load_label_file 两个函数读取。

训练集有60000幅图片,每个图片都是由16*16个像素构成,代表了0-9中的某一个数字,比如下图。


       利用softmaxreg 包训练一个10分类的MNIST手写体识别的模型,用load_image_file 和load_label_file 来分别读取训练集的图像数据和标签的数据 (Reference: brendano'connor - gist.github.com/39760的读取方法)


  1. library(softmaxreg)
  2. path= "D: \\DeepLearning\\MNIST\\"
  3. #10-classclassification, Digit 0-9
  4. x= load_image_file(paste(path,'train-images-idx3-ubyte', sep=""))
  5. y= load_label_file(paste(path,'train-labels-idx1-ubyte', sep=""))
  6. xTest= load_image_file(paste(path,'t10k-images-idx3-ubyte',sep=""))
  7. yTest= load_label_file(paste(path,'t10k-labels-idx1-ubyte', sep=""))



可以用show_digit函数来看一个数字的图像,比如查看某一个图片,比如第2副


  1. show_digit(x[2,])




Part2、训练模型

利用softmaxReg函数,训练集输入和标签分别为为x和y,maxit 设置最多多少个Epoch, algorithm为优化的算法,rate为学习率,batch参数为SGD随机梯度下降每个Mini-Batch的样本个数。 收敛后用predict方法来看看测试集Test的准确率怎么样


  1. ## Normalize Input Data
  2. x = x/255
  3. xTest = xTest/255
  4. model1= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "sgd", rate = 0.01, batch = 1000)
  5. loss1= model1$loss
  6. #Test Accuracy
  7. yFit= predict(model1, newdata = x)
  8. table(y,yFit)


Part3、比较不同优化算法的收敛速度


  1. model2= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "adagrad", rate = 0.01, batch =1000)
  2. loss2= model2$loss
  3. model3= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "rmsprop", rate = 0.01, batch =1000)
  4. loss3= model3$loss
  5. model4= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "momentum", rate = 0.01, batch= 1000)
  6. loss4= model4$loss
  7. model5= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "nag", rate = 0.01, batch = 1000)
  8. loss5= model5$loss
  9. #plot the loss convergence
  10. iteration= c(1:length(loss1))
  11. myplot= plot(iteration, loss1, xlab = "iteration", ylab = "loss",ylim = c(0, max(loss1,loss2,loss3,loss4,loss5) + 0.01),
  12.     type = "p", col ="black", cex = 0.7)
  13. title("ConvergenceComparision Between Learning Algorithms")
  14. points(iteration,loss2, col = "red", pch = 2, cex = 0.7)
  15. points(iteration,loss3, col = "blue", pch = 3, cex = 0.7)
  16. points(iteration,loss4, col = "green", pch = 4, cex = 0.7)
  17. points(iteration,loss5, col = "magenta", pch = 5, cex = 0.7)
  18. legend("topright",c("SGD", "Adagrad", "RMSprop","Momentum", "NAG"),
  19. col = c("black", "red","blue", "green", "magenta"),pch = c(1,2,3,4,5))
  20. save.image()






       如果maxit 迭代次数过大,模型运行时间较长,可以保存图像,最后可以看到AdaGrad, rmsprop,momentum, nag 和标准SGD这几种优化算法的收敛速度的比较效果。关于优化算法这个帖子有很好的总结:



http://cs231n.github.io/neural-networks-3/



2.2 多类别的文档分类

        Softmax regression模型的每个输入为一个文档,用一个字符串表示。其中每个词word都可以用一个word2vec模型训练的word Embedding低维度的实数词向量表示。在softmaxreg包中有一个预先训练好的模型:长度为20维的英文词向量的字典,直接用data(word2vec) 调用就可以了。

假设我们需要对UCI的C50新闻数据集进行分类,数据集包含多个作者写的新闻报道,每个作者的新闻文件都在一个单独的文件夹中。 我们假设挑选5个作者的文章进行训练softmax regression 模型,然后在测试集中预测任意文档属于哪一个作者,这就构成了一个5分类的问题。


Part1, 载入预先训练好的 英文word2vec 字典表


  1. library(softmaxreg)
  2. data(word2vec) # default 20 dimension word2vec dataset
  3. #### Reuter 50 DataSet UCI Archived Dataset from



Part2,利用loadURLData函数从网址下载数据并且解压到folder目录


  1. ## URL: "http://archive.ics.uci.edu/ml/machine-learning-databases/00217/C50.zip"
  2. URL = "http://archive.ics.uci.edu/ml/machine-learning-databases/00217/C50.zip"
  3. folder = getwd()
  4. loadURLData(URL, folder, unzip = TRUE)


Part3,利用wordEmbed() 函数作为lookup table,从默认的word2vec数据集中查找每个单词的向量表示,默认20维度,可以自己训练自己的字典数据集来替换。


  1. ##Training Data
  2. subFoler = c('AaronPressman', 'AlanCrosby', 'AlexanderSmith', 'BenjaminKangLim', 'BernardHickey')

  3. docTrain = document(path = paste(folder, "/C50train/",subFoler, sep = ""), pattern = 'txt')

  4. xTrain = wordEmbed(docTrain, dictionary = word2vec)
  5. yTrain = c(rep(1,50), rep(2,50), rep(3,50), rep(4,50), rep(5,50))
  6. # Assign labels to 5 different authors

  7. ##Testing Data
  8. docTest = document(path = paste(folder, "/C50test/",subFoler, sep = ""), pattern = 'txt')
  9. xTest = wordEmbed(docTest, dictionary = word2vec)
  10. yTest = c(rep(1,50), rep(2,50), rep(3,50), rep(4,50), rep(5,50))
  11. samp = sample(250, 50)
  12. xTest = xTest[samp,]
  13. yTest = yTest[samp]



Part4,训练模型,构建一个结构为20-10-5的模型,输入层为20维,即词向量的维度,隐含层的节点数为10,最后softmax层输出节点个数为5.


  1. ## Train Softmax Classification Model, 20-10-5
  2. softmax_model = softmaxReg(xTrain, yTrain, hidden = c(10), maxit = 500, type = "class",
  3. algorithm = "nag", rate = 0.05, batch = 10, L2 = TRUE)
  4. summary(softmax_model)
  5. yFit = predict(softmax_model, newdata = xTrain)
  6. table(yTrain, yFit)
  7. ## Testing
  8. yPred = predict(softmax_model, newdata = xTest)
  9. table(yTest, yPred)





增加
embedding
的维度到
50
或者
100
可以提升模型准确度;




相关资料:

关于Stanford的中英文

http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92



softmaxregR包的下载地址:

https://cran.r-project.org/web/packages/softmaxreg/index.html



推荐阅读
  • 本文分析了Wince程序内存和存储内存的分布及作用。Wince内存包括系统内存、对象存储和程序内存,其中系统内存占用了一部分SDRAM,而剩下的30M为程序内存和存储内存。对象存储是嵌入式wince操作系统中的一个新概念,常用于消费电子设备中。此外,文章还介绍了主电源和后备电池在操作系统中的作用。 ... [详细]
  • 本文介绍了数据库的存储结构及其重要性,强调了关系数据库范例中将逻辑存储与物理存储分开的必要性。通过逻辑结构和物理结构的分离,可以实现对物理存储的重新组织和数据库的迁移,而应用程序不会察觉到任何更改。文章还展示了Oracle数据库的逻辑结构和物理结构,并介绍了表空间的概念和作用。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • WebSocket与Socket.io的理解
    WebSocketprotocol是HTML5一种新的协议。它的最大特点就是,服务器可以主动向客户端推送信息,客户端也可以主动向服务器发送信息,是真正的双向平等对话,属于服务器推送 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • VScode格式化文档换行或不换行的设置方法
    本文介绍了在VScode中设置格式化文档换行或不换行的方法,包括使用插件和修改settings.json文件的内容。详细步骤为:找到settings.json文件,将其中的代码替换为指定的代码。 ... [详细]
  • Webpack5内置处理图片资源的配置方法
    本文介绍了在Webpack5中处理图片资源的配置方法。在Webpack4中,我们需要使用file-loader和url-loader来处理图片资源,但是在Webpack5中,这两个Loader的功能已经被内置到Webpack中,我们只需要简单配置即可实现图片资源的处理。本文还介绍了一些常用的配置方法,如匹配不同类型的图片文件、设置输出路径等。通过本文的学习,读者可以快速掌握Webpack5处理图片资源的方法。 ... [详细]
  • Java序列化对象传给PHP的方法及原理解析
    本文介绍了Java序列化对象传给PHP的方法及原理,包括Java对象传递的方式、序列化的方式、PHP中的序列化用法介绍、Java是否能反序列化PHP的数据、Java序列化的原理以及解决Java序列化中的问题。同时还解释了序列化的概念和作用,以及代码执行序列化所需要的权限。最后指出,序列化会将对象实例的所有字段都进行序列化,使得数据能够被表示为实例的序列化数据,但只有能够解释该格式的代码才能够确定数据的内容。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • 知识图谱——机器大脑中的知识库
    本文介绍了知识图谱在机器大脑中的应用,以及搜索引擎在知识图谱方面的发展。以谷歌知识图谱为例,说明了知识图谱的智能化特点。通过搜索引擎用户可以获取更加智能化的答案,如搜索关键词"Marie Curie",会得到居里夫人的详细信息以及与之相关的历史人物。知识图谱的出现引起了搜索引擎行业的变革,不仅美国的微软必应,中国的百度、搜狗等搜索引擎公司也纷纷推出了自己的知识图谱。 ... [详细]
  • 怎么在PHP项目中实现一个HTTP断点续传功能发布时间:2021-01-1916:26:06来源:亿速云阅读:96作者:Le ... [详细]
  • Redis底层数据结构之压缩列表的介绍及实现原理
    本文介绍了Redis底层数据结构之压缩列表的概念、实现原理以及使用场景。压缩列表是Redis为了节约内存而开发的一种顺序数据结构,由特殊编码的连续内存块组成。文章详细解释了压缩列表的构成和各个属性的含义,以及如何通过指针来计算表尾节点的地址。压缩列表适用于列表键和哈希键中只包含少量小整数值和短字符串的情况。通过使用压缩列表,可以有效减少内存占用,提升Redis的性能。 ... [详细]
  • 在Oracle11g以前版本中的的DataGuard物理备用数据库,可以以只读的方式打开数据库,但此时MediaRecovery利用日志进行数据同步的过 ... [详细]
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
author-avatar
zeng-abee
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有