热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Caffe入门(训练mnist)

使用caffe训练模型只需要以下几个步骤:(1)准备好数据;(2)写好模型配置文件;(3)写好优化配置文件;(4)命令行执行;这样就可以得到训练的模型.caffemodel文件了1

使用caffe训练模型只需要以下几个步骤:

(1)准备好数据;

(2)写好模型配置文件;

(3)写好优化配置文件;

(4)命令行执行;

这样就可以得到训练的模型.caffemodel文件了


1.caffe的下载与安装:

(1)下载

(2)安装

(3)caffe的下载与安装以及一些基本的介绍官网已经描述地比较详细,这里不再重复;

2.caffe的使用接口有命令行,python跟matlab,个人觉得训练模型的时候使用命令行已经足够简单了,至于训练好的模型可以使用python与matlab的接口进行调用,本文先描述基于命令行的模型训练,以LeNet模型为例;

LeNet模型

Caffe 入门(训练mnist)

PS:LeNet是手写数字库MNIST上应用比较经典的模型,具有7层网络结构,分别是卷积-下采样-卷积-下采样-全连-全连-分类层,具体网络细节可以参考文章:

Gradient based learning applied to document recognition

1.安装编译完caffe后,其主目录下有:

Caffe 入门(训练mnist)

2.训练模型之前需要先准备好训练数据MNIST,执行以下命令可以下载MNIST数据库:

Caffe 入门(训练mnist)

3.由于caffe支持的数据类型不包括图像类型,所以常规做法需要将图像类型转为lmdb类型:

Caffe 入门(训练mnist)

4.准备好数据之后,我们需要定义我们的网络模型,在caffe中是通过.prototxt配置文件来定义的,执行以下命令:

Caffe 入门(训练mnist)

可以看到各个网络层是如何定义的:

(1)输入层(数据层):

[plain] view plain copy
  1. layer {  
  2.   name: "mnist"      //表示层名  
  3.   type: "Data"       //表示层的类型  
  4.   top: "data"          
  5.   top: "label"  
  6.   include {  
  7.     phase: TRAIN      //表示仅在训练阶段起作用  
  8.   }  
  9.   transform_param {  
  10.     scale: 0.00390625  //将图像像素值归一化  
  11.   }  
  12.   data_param {  
  13.     source: "examples/mnist/mnist_train_lmdb"   //数据来源  
  14.     batch_size: 64                              //训练时每个迭代的输入样本数量  
  15.     backend: LMDB                               //数据类型  
  16.   }  
  17. }  
(2)卷积层:

[plain] view plain copy
  1. layer {  
  2.   name: "conv1"  
  3.   type: "Convolution"  
  4.   bottom: "data"               //输入是data  
  5.   top: "conv1"                 //输出是卷积特征  
  6.   param {  
  7.     lr_mult: 1                 //权重参数w的学习率倍数  
  8.   }  
  9.   param {  
  10.     lr_mult: 2                 //偏置参数b的学习率倍数  
  11.   }  
  12.   convolution_param {  
  13.     num_output: 20  
  14.     kernel_size: 5  
  15.     stride: 1  
  16.     weight_filler {           //权重参数w的初始化方案,使用xavier算法  
  17.       type: "xavier"  
  18.     }  
  19.     bias_filler {  
  20.       type: "constant"       //偏置参数b初始化化为常数,一般为0  
  21.     }  
  22.   }  
  23. }  

(3)下采样层(pool):

[plain] view plain copy
  1. layer {  
  2.   name: "pool1"  
  3.   type: "Pooling"  
  4.   bottom: "conv1"  
  5.   top: "pool1"  
  6.   pooling_param {  
  7.     pool: MAX  
  8.     kernel_size: 2  
  9.     stride: 2  
  10.   }  
  11. }  

(4)全连层:

[plain] view plain copy
  1. layer {  
  2.   name: "ip1"  
  3.   type: "InnerProduct"  
  4.   bottom: "pool2"  
  5.   top: "ip1"  
  6.   param {  
  7.     lr_mult: 1  
  8.   }  
  9.   param {  
  10.     lr_mult: 2  
  11.   }  
  12.   inner_product_param {  
  13.     num_output: 500  
  14.     weight_filler {  
  15.       type: "xavier"  
  16.     }  
  17.     bias_filler {  
  18.       type: "constant"  
  19.     }  
  20.   }  
  21. }  
(5)非线性层:

[plain] view plain copy
  1. layer {  
  2.   name: "relu1"  
  3.   type: "ReLU"  
  4.   bottom: "ip1"  
  5.   top: "ip1"  
  6. }  
(6)准确率层(计算准确率):

[plain] view plain copy
  1. layer {  
  2.   name: "accuracy"  
  3.   type: "Accuracy"  
  4.   bottom: "ip2"  
  5.   bottom: "label"  
  6.   top: "accuracy"  
  7.   include {  
  8.     phase: TEST  
  9.   }  
  10. }  
(7)损失估计层:

[plain] view plain copy
  1. layer {  
  2.   name: "loss"  
  3.   type: "SoftmaxWithLoss"  
  4.   bottom: "ip2"  
  5.   bottom: "label"  
  6.   top: "loss"  
  7. }  

5.定义完网络模型,还需要配置关于模型优化的文件:

Caffe 入门(训练mnist)

配置文件如下:

[plain] view plain copy
  1. # The train/test net protocol buffer definition  
  2. net: "examples/mnist/lenet_train_test.prototxt"     //设定网络模型配置文件的路径  
  3. # test_iter specifies how many forward passes the test should carry out.  
  4. # In the case of MNIST, we have test batch size 100 and 100 test iterations,  
  5. # covering the full 10,000 testing images.  
  6. test_iter: 100  
  7. # Carry out testing every 500 training iterations.  
  8. test_interval: 500                               
  9. # The base learning rate, momentum and the weight decay of the network.  
  10. base_lr: 0.01  
  11. momentum: 0.9  
  12. weight_decay: 0.0005  
  13. # The learning rate policy  
  14. lr_policy: "inv"  
  15. gamma: 0.0001  
  16. power: 0.75  
  17. # Display every 100 iterations  
  18. display: 100  
  19. # The maximum number of iterations  
  20. max_iter: 10000  
  21. # snapshot intermediate results  
  22. snapshot: 5000  
  23. snapshot_prefix: "examples/mnist/lenet"  
  24. # solver mode: CPU or GPU  
  25. solver_mode: GPU  

6.接下来一步就是进行训练了,直接执行命令就可以:

Caffe 入门(训练mnist)

执行后可以看到:首先会读取配置文件初始化网络跟优化器:

Caffe 入门(训练mnist)

紧接着开始优化:

Caffe 入门(训练mnist)

可以看到训练过程中每100次迭代就会显示一个loss,每500次迭代就会计算一次test准确率,总共10000次迭代,这些都可以在配置文件中设置;

7.训练完之后的模型就保存在.caffemodel文件中,该文件可以被c,python,matlab等调用;



推荐阅读
  • 本文详细介绍了 PHP 中对象的生命周期、内存管理和魔术方法的使用,包括对象的自动销毁、析构函数的作用以及各种魔术方法的具体应用场景。 ... [详细]
  • malloc 是 C 语言中的一个标准库函数,全称为 memory allocation,即动态内存分配。它用于在程序运行时申请一块指定大小的连续内存区域,并返回该区域的起始地址。当无法预先确定内存的具体位置时,可以通过 malloc 动态分配内存。 ... [详细]
  • 开机自启动的几种方式
    0x01快速自启动目录快速启动目录自启动方式源于Windows中的一个目录,这个目录一般叫启动或者Startup。位于该目录下的PE文件会在开机后进行自启动 ... [详细]
  • 在Windows系统中安装TensorFlow GPU版的详细指南与常见问题解决
    在Windows系统中安装TensorFlow GPU版是许多深度学习初学者面临的挑战。本文详细介绍了安装过程中的每一个步骤,并针对常见的问题提供了有效的解决方案。通过本文的指导,读者可以顺利地完成安装并避免常见的陷阱。 ... [详细]
  • 在软件开发过程中,经常需要将多个项目或模块进行集成和调试,尤其是当项目依赖于第三方开源库(如Cordova、CocoaPods)时。本文介绍了如何在Xcode中高效地进行多项目联合调试,分享了一些实用的技巧和最佳实践,帮助开发者解决常见的调试难题,提高开发效率。 ... [详细]
  • 本文对SQL Server系统进行了基本概述,并深入解析了其核心功能。SQL Server不仅提供了强大的数据存储和管理能力,还支持复杂的查询操作和事务处理。通过MyEclipse、SQL Server和Tomcat的集成开发环境,可以高效地构建银行转账系统。在实现过程中,需要确保表单参数与后台代码中的属性值一致,同时在Servlet中处理用户登录验证,以确保系统的安全性和可靠性。 ... [详细]
  • 2020年9月15日,Oracle正式发布了最新的JDK 15版本。本次更新带来了许多新特性,包括隐藏类、EdDSA签名算法、模式匹配、记录类、封闭类和文本块等。 ... [详细]
  • 通过将常用的外部命令集成到VSCode中,可以提高开发效率。本文介绍如何在VSCode中配置和使用自定义的外部命令,从而简化命令执行过程。 ... [详细]
  • 本文介绍如何使用OpenCV和线性支持向量机(SVM)模型来开发一个简单的人脸识别系统,特别关注在只有一个用户数据集时的处理方法。 ... [详细]
  • WinMain 函数详解及示例
    本文详细介绍了 WinMain 函数的参数及其用途,并提供了一个具体的示例代码来解析 WinMain 函数的实现。 ... [详细]
  • Linux CentOS 7 安装PostgreSQL 9.5.17 (源码编译)
    近日需要将PostgreSQL数据库从Windows中迁移到Linux中,LinuxCentOS7安装PostgreSQL9.5.17安装过程特此记录。安装环境&#x ... [详细]
  • 字符串学习时间:1.5W(“W”周,下同)知识点checkliststrlen()函数的返回值是什么类型的?字 ... [详细]
  • [转]doc,ppt,xls文件格式转PDF格式http:blog.csdn.netlee353086articledetails7920355确实好用。需要注意的是#import ... [详细]
  • PTArchiver工作原理详解与应用分析
    PTArchiver工作原理及其应用分析本文详细解析了PTArchiver的工作机制,探讨了其在数据归档和管理中的应用。PTArchiver通过高效的压缩算法和灵活的存储策略,实现了对大规模数据的高效管理和长期保存。文章还介绍了其在企业级数据备份、历史数据迁移等场景中的实际应用案例,为用户提供了实用的操作建议和技术支持。 ... [详细]
  • Python 序列图分割与可视化编程入门教程
    本文介绍了如何使用 Python 进行序列图的快速分割与可视化。通过一个实际案例,详细展示了从需求分析到代码实现的全过程。具体包括如何读取序列图数据、应用分割算法以及利用可视化库生成直观的图表,帮助非编程背景的用户也能轻松上手。 ... [详细]
author-avatar
mobiledu2502869223
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有