热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用mxnet实现卷积神经网络LeNet

LeNet是一个早期用来识别手写数字的卷积神经网络,这个名字来源于LeNet论文的第一作者YannLeCun。LeNet展示了通过梯度下降训练卷积神经网络可以达到手写数字识别在当时

LeNet是一个早期用来识别手写数字的卷积神经网络,这个名字来源于LeNet论文的第一作者Yann LeCun。LeNet展示了通过梯度下降训练卷积神经网络可以达到手写数字识别在当时最先进的成果,这个尊基性的工作第一次将卷积神经网络推上舞台

使用mxnet实现卷积神经网络LeNet

上图就是LeNet模型,下面将对每层参数进行说明

1.1 input输入层

假设输入层数据shape=(32,32)

1.2 C1卷积层

  • 卷积核大小: kernel_size=(5,5)
  • 步幅:stride = 1
  • 输出通道为6
  • 可训练参数为: (5 * 5 + 1) * 6
  • 激活函数:采用relu
    输入层数据经过C1卷积层后将得到feature maps形状(6 * 28 * 28),注:28 = 32 -5 + 1

1.3 S2池化层

池化层(Max Pooling)窗口形状均为2*2,步幅度为2,输出feature maps为(6 *14 * 14),6为feature map的数量

1.4 C3卷积层

  • 卷积核大小: kernel_size=(5,5)
  • 步幅:stride = 1
  • 输出通道为16
  • 激活函数:采用relu得到feature maps为(16 * 10 * 10),(10*10)为每个feature map形状,16为feature map数量

1.5 S4池化层

池化层(Max Pooling)窗口形状依然均为2*2,步幅度为2,输出feature maps为(16 *5 * 5),16为feature map的数量

1.6 C5全链接层

  • 输出120个神经元
  • 激活函数:relu

1.7 F6全连接层

  • 输出84个神经元
  • 激活函数:relu

1.8 output

  • 输出10个神经元
  • 激活函数:无

2.用Mxnet实现LeNet模型

import mxnet as mx
from mxnet import autograd,init,nd
from mxnet.gluon import nn,Trainer
from mxnet.gluon import data as gdata
from mxnet.gluon import loss as gloss
import time

class LeNet_mxnet:
    def __init__(self):
        self.net = nn.Sequential()
        self.net.add(nn.Conv2D(channels=6,kernel_size=5,activation='relu'),
                nn.MaxPool2D(pool_size =(2,2),strides=(2,2)),
                nn.Conv2D(channels=16,kernel_size=(5,5),strides=(1,1),padding=(0,0),activation='relu'),
                nn.MaxPool2D(pool_size =(2,2),strides=(2,2)),
                nn.Dense(units=120,activation='relu'),
                nn.Dense(units=84,activation='relu'),
                nn.Dense(units=10)  #最后一个全连接层激活函数取决于损失函数
               )
        
    def train(self,train_iter,test_iter,n_epochs,ctx):
        print('training on',ctx)
        self.net.initialize(force_reinit=True,ctx=ctx,init=init.Xavier())
        trainer_op = Trainer(self.net.collect_params(),'adam',{'learning_rate':0.01})
        loss = gloss.SoftmaxCrossEntropyLoss()
        
        accuracy_val = 0
        for epoch in range(n_epochs):
            
            train_loss_sum,train_acc_sum,n,start = 0.0,0.0,0,time.time()
            
            for x_batch,y_batch in train_iter:
                x_batch,y_batch = x_batch.as_in_context(ctx),y_batch.as_in_context(ctx)
                with autograd.record():
                    y_hat = self.net(x_batch)
                    loss_val = loss(y_hat,y_batch).sum()
                loss_val.backward()
                trainer_op.step(n_batches)
                y_batch = y_batch.astype('float32')
                train_loss_sum += loss_val.asscalar()
                train_acc_sum += (y_hat.argmax(axis=1) == y_batch).sum().asscalar()
                n += y_batch.size
            test_acc = self.accuracy_score(test_iter,ctx)
            accuracy_val += self.accuracy_score(test_iter,ctx)
            print('epoch:%d,train_loss:%.4f,train_acc:%.3f,test_acc:%.3f,time:%.1f sec' 
                  %(epoch+1, train_loss_sum / n, train_acc_sum/ n,test_acc,time.time() - start))
    
    def accuracy_score(self,data_iter,ctx):
        acc_sum,n = nd.array([0],ctx=ctx),0
        for x,y in data_iter:
            x,y = x.as_in_context(ctx),y.as_in_context(ctx)
            y = y.astype('float32')
            acc_sum += (self.net(x).argmax(axis=1) == y).sum()
            n += y.size
        return acc_sum.asscalar() / n
    
    def __call__(self,x):
        return self.net(x)
    
    def predict(self,x,ctx):
        x = x.as_in_context(ctx)
        return self.net(x).argmax(axis=1)
    
    def print_info(self):
        print(self.net[4].params)

3.使用mnist手写数字数据集进行测试

from tensorflow.keras.datasets import mnist

(x_train,y_train),(x_test,y_test) = mnist.load_data()
print(x_train.shape,y_train.shape)
print(x_test.shape,y_test.shape)
x_train = x_train.reshape(60000,1,28,28).astype('float32')
x_test = x_test.reshape(10000,1,28,28).astype('float32')
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
lenet_mxnet = LeNet_mxnet()
epochs = 10
n_batches = 500
train_iter = gdata.DataLoader(gdata.ArrayDataset(x_train,y_train),batch_size=n_batches)
test_iter = gdata.DataLoader(gdata.ArrayDataset(x_test,y_test),batch_size=n_batches)
lenet_mxnet.train(train_iter,test_iter,epochs,ctx=mx.gpu())
training on gpu(0)
epoch:1,train_loss:1.8267,train_acc:0.571,test_acc:0.896,time:3.0 sec
epoch:2,train_loss:0.2449,train_acc:0.924,test_acc:0.948,time:2.6 sec
epoch:3,train_loss:0.1563,train_acc:0.952,test_acc:0.954,time:2.6 sec
epoch:4,train_loss:0.1302,train_acc:0.961,test_acc:0.962,time:2.5 sec
epoch:5,train_loss:0.1169,train_acc:0.964,test_acc:0.958,time:2.5 sec
epoch:6,train_loss:0.1017,train_acc:0.969,test_acc:0.967,time:2.5 sec
epoch:7,train_loss:0.0855,train_acc:0.973,test_acc:0.964,time:3.3 sec
epoch:8,train_loss:0.0848,train_acc:0.973,test_acc:0.964,time:3.6 sec
epoch:9,train_loss:0.0767,train_acc:0.976,test_acc:0.963,time:3.5 sec
epoch:10,train_loss:0.0771,train_acc:0.977,test_acc:0.970,time:3.5 sec
# 将预测结果可视化
import matplotlib.pyplot as plt

def plt_image(image):
    n = 20
    plt.figure(figsize=(20,4))
    for i in range(n):
        ax = plt.subplot(2,10,i+1)
        plt.imshow(x_test[i].reshape(28,28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
    
plt_image(x_test)
print('predict result:',lenet_mxnet.predict(nd.array(x_test[0:20]),ctx=mx.gpu()))

使用mxnet实现卷积神经网络LeNet

predict result: 
[7. 2. 1. 0. 4. 1. 4. 9. 5. 9. 0. 6. 9. 0. 1. 5. 9. 7. 3. 4.]

4. 附:需要注意的知识点

  • (1) 注意SoftmaxCrossEntropyLoss的使用,hybrid_forward源码说明,若from_logits为False时(默认为Flase),会先通过log_softmax计算各分类的概率,再计算loss,同样SigmoidBinaryCrossEntropyLoss也提供了from_sigmoid参数决定是否在hybrid_forward函数中要计算sigmoid函数,所以在创建模型最后一层的时候要特别注意是否要给激活函数

  • (2) 注意权重初始化选择

  • (3) 注意(y_hat.argmax(axis=1) == y_batch)操作时y_batch数据类型转换

  • (4) 上面的模型没有对数据集进行归一化处理,可以添加该步骤


推荐阅读
  • 代码如下:#coding:utf-8importstring,os,sysimportnumpyasnpimportmatplotlib.py ... [详细]
  • 人工智能推理能力与假设检验
    最近Google的Deepmind开始研究如何让AI做数学题。这个问题的提出非常有启发,逻辑推理,发现新知识的能力应该是强人工智能出现自我意识之前最需要发展的能力。深度学习目前可以 ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • 本文介绍了在iOS开发中使用UITextField实现字符限制的方法,包括利用代理方法和使用BNTextField-Limit库的实现策略。通过这些方法,开发者可以方便地限制UITextField的字符个数和输入规则。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
  • EPPlus绘制刻度线的方法及示例代码
    本文介绍了使用EPPlus绘制刻度线的方法,并提供了示例代码。通过ExcelPackage类和List对象,可以实现在Excel中绘制刻度线的功能。具体的方法和示例代码在文章中进行了详细的介绍和演示。 ... [详细]
  • 程序分析与优化9附录XLA的缓冲区指派
    本章是系列文章的案例学习,不属于正篇,主要介绍了TensorFlow引入的XLA的优化算法。XLA也有很多局限性,XLA更多的是进行合并,但有时候如果参数特别多的场景下,也需要进行 ... [详细]
  • 干货 | 携程AI推理性能的自动化优化实践
    作者简介携程度假AI研发团队致力于为携程旅游事业部提供丰富的AI技术产品,其中性能优化组为AI模型提供全方位的优化方案,提升推理性能降低成本࿰ ... [详细]
author-avatar
婷寶Avrow
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有