热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

抑制过拟合的方法之Dropout(随机删除神经元)

在抑制过拟合的方法中,我们前面有讲到一个方法:抑制过拟合的方法之权值衰减,在某种程度上能够很好的抑制过拟合,如果神经网络的模型很复杂&#x

        在抑制过拟合的方法中,我们前面有讲到一个方法:抑制过拟合的方法之权值衰减 ,在某种程度上能够很好的抑制过拟合,如果神经网络的模型很复杂,只用权值衰减就难以应对了,这样的情况下,我们一般选择Dropout方法,也就是在训练的过程中,随机选出隐藏层的神经元,然后将其删除,被删除的神经元不再进行信号的传递。代码在权值衰减中有出现,layers.py里面,摘录出来

class Dropout:'''随机删除神经元self.mask:保存的是False和True的数组,False的值为0是删除的数据'''def __init__(self,dropout_ratio=0.5):self.dropout_ratio=dropout_ratioself.mask=Nonedef forward(self,x,train_flg=True):if train_flg:self.mask=np.random.rand(*x.shape)>self.dropout_ratioreturn x*self.maskelse:return x*(1.0-self.dropout_ratio)def backward(self,dout):return dout*self.mask

随机删除的意思是指每次正向传播时,self.mask中都会以False的形式保存要删除的神经元。


np.random.rand(2,3) 随机生成[0,1)形状为(2,3)的数组
np.random.rand(2,3)>0.5 把大于0.5的值设为True,其余为False(而不是删除一半的意思,因为数据是随机的)
x * self.mask 结果就是False为0,True还是x原来的值


        正向传播时传递了信号的神经元,反向传播时按照原样传递信号,正向传播时没有传递信号的神经元,反向传播时信号将停在那里。
        现在我们来比较使用Dropout和不使用Dropout的情况,还是基于MNIST数据集来测试
训练类(common.trainer.py)

import numpy as np
from common.optimizer import *class Trainer:'''把前面用来训练的代码做一个类'''def __init__(self,network,x_train,t_train,x_test,t_test,epochs=20,mini_batch_size=100,optimizer='SGD',optimizer_param={'lr':0.01},evaluate_sample_num_per_epoch=None,verbose=True):self.network=networkself.verbose=verbose#是否打印数据(调试或查看)self.x_train=x_trainself.t_train=t_trainself.x_test=x_testself.t_test=t_testself.epochs=epochsself.batch_size=mini_batch_sizeself.evaluate_sample_num_per_epoch=evaluate_sample_num_per_epochoptimizer_dict={'sgd':SGD,'momentum':Momentum,'nesterov':Nesterov,'adagrad':AdaGrad,'rmsprop':RMSprop,'adam':Adam}self.optimizer=optimizer_dict[optimizer.lower()](**optimizer_param)self.train_size=x_train.shape[0]self.iter_per_epoch=max(self.train_size/mini_batch_size,1)self.max_iter=int(epochs*self.iter_per_epoch)self.current_iter=0self.current_epoch=0self.train_loss_list=[]self.train_acc_list=[]self.test_acc_list=[]def train_step(self):batch_mask=np.random.choice(self.train_size,self.batch_size)x_batch=self.x_train[batch_mask]t_batch=self.t_train[batch_mask]grads=self.network.gradient(x_batch,t_batch)self.optimizer.update(self.network.params,grads)loss=self.network.loss(x_batch,t_batch)self.train_loss_list.append(loss)if self.verbose:print('训练损失值:'+str(loss))if self.current_iter%self.iter_per_epoch==0:self.current_epoch+=1x_train_sample,t_train_sample=self.x_train,self.t_trainx_test_sample,t_test_sample=self.x_test,self.t_testif not self.evaluate_sample_num_per_epoch is None:t=self.evaluate_sample_num_per_epochx_train_sample,t_train_sample=self.x_test[:t],self.t_test[:t] train_acc=self.network.accuracy(x_train_sample,t_train_sample)test_acc=self.network.accuracy(x_test_sample,t_test_sample) self.train_acc_list.append(train_acc)self.test_acc_list.append(test_acc)if self.verbose:print('epoch:'+str(self.current_epoch)+',train acc:'+str(train_acc)+' | test acc:'+str(test_acc))self.current_iter+=1def train(self):for i in range(self.max_iter):self.train_step()test_acc=self.network.accuracy(self.x_test,self.t_test)if self.verbose:print('最终测试的正确率:'+str(format(test_acc,'.2%')))

import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net_extend import MultiLayerNetExtend
from common.trainer import Trainer(x_train,t_train),(x_test,t_test)=load_mnist(normalize=True)
#截取少量数据,让它再现过拟合
x_train=x_train[:300]#(300,784)
t_train=t_train[:300]#构建7层神经网络(6个隐藏层)
epochsNum=300
network=MultiLayerNetExtend(inputSize=784,hiddenSizeList=[100,100,100,100,100,100],outputSize=10,use_dropout=True,dropout_ration=0.2)
trainer=Trainer(network,x_train,t_train,x_test,t_test,epochs=epochsNum,mini_batch_size=100,optimizer='sgd',optimizer_param={'lr':0.01},verbose=True)
trainer.train()#画图
train_acc_list,test_acc_list=trainer.train_acc_list,trainer.test_acc_list
x=np.arange(len(train_acc_list))
plt.plot(x,train_acc_list,marker='s',label='train',markevery=10)
plt.plot(x,test_acc_list,marker='d',label='test',markevery=10)
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.ylim(0,1.0)
plt.legend(loc='lower right')
plt.show()

        不使用Dropout的情况,修改成use_dropout=False,我们会发现下图中,train数据过拟合了,所以很多时候我们都会优选Dropout来抑制过拟合。

其中需要用到的多层神经网络扩展版本(支持Dropout)multi_layer_net_extend.py

基于MNIST数据集的Batch Normalization(批标准化层)https://blog.csdn.net/weixin_41896770/article/details/121557928


推荐阅读
  • 本地存储组件实现对IE低版本浏览器的兼容性支持 ... [详细]
  • 本项目通过Python编程实现了一个简单的汇率转换器v1.02。主要内容包括:1. Python的基本语法元素:(1)缩进:用于表示代码的层次结构,是Python中定义程序框架的唯一方式;(2)注释:提供开发者说明信息,不参与实际运行,通常每个代码块添加一个注释;(3)常量和变量:用于存储和操作数据,是程序执行过程中的重要组成部分。此外,项目还涉及了函数定义、用户输入处理和异常捕获等高级特性,以确保程序的健壮性和易用性。 ... [详细]
  • 2018 HDU 多校联合第五场 G题:Glad You Game(线段树优化解法)
    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6356在《Glad You Game》中,Steve 面临一个复杂的区间操作问题。该题可以通过线段树进行高效优化。具体来说,线段树能够快速处理区间更新和查询操作,从而大大提高了算法的效率。本文详细介绍了线段树的构建和维护方法,并给出了具体的代码实现,帮助读者更好地理解和应用这一数据结构。 ... [详细]
  • 在 Linux 环境下,多线程编程是实现高效并发处理的重要技术。本文通过具体的实战案例,详细分析了多线程编程的关键技术和常见问题。文章首先介绍了多线程的基本概念和创建方法,然后通过实例代码展示了如何使用 pthreads 库进行线程同步和通信。此外,还探讨了多线程程序中的性能优化技巧和调试方法,为开发者提供了宝贵的实践经验。 ... [详细]
  • 本文总结了JavaScript的核心知识点和实用技巧,涵盖了变量声明、DOM操作、事件处理等重要方面。例如,通过`event.srcElement`获取触发事件的元素,并使用`alert`显示其HTML结构;利用`innerText`和`innerHTML`属性分别设置和获取文本内容及HTML内容。此外,还介绍了如何在表单中动态生成和操作``元素,以便更好地处理用户输入。这些技巧对于提升前端开发效率和代码质量具有重要意义。 ... [详细]
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 浅层神经网络解析:本文详细探讨了两层神经网络(即一个输入层、一个隐藏层和一个输出层)的结构与工作原理。通过吴恩达教授的课程,读者将深入了解浅层神经网络的基本概念、参数初始化方法以及前向传播和反向传播的具体实现步骤。此外,文章还介绍了如何利用这些基础知识解决实际问题,并提供了丰富的实例和代码示例。 ... [详细]
  • 视觉图像的生成机制与英文术语解析
    近期,Google Brain、牛津大学和清华大学等多家研究机构相继发布了关于多层感知机(MLP)在视觉图像分类中的应用成果。这些研究深入探讨了MLP在视觉任务中的工作机制,并解析了相关技术术语,为理解视觉图像生成提供了新的视角和方法。 ... [详细]
  • [转]doc,ppt,xls文件格式转PDF格式http:blog.csdn.netlee353086articledetails7920355确实好用。需要注意的是#import ... [详细]
  • 字节流(InputStream和OutputStream),字节流读写文件,字节流的缓冲区,字节缓冲流
    字节流抽象类InputStream和OutputStream是字节流的顶级父类所有的字节输入流都继承自InputStream,所有的输出流都继承子OutputStreamInput ... [详细]
  • 思科IOS XE与ISE集成实现TACACS认证配置
    本文详细介绍了如何在思科IOS XE设备上配置TACACS认证,并通过ISE(Identity Services Engine)进行用户管理和授权。配置包括网络拓扑、设备设置和ISE端的具体步骤。 ... [详细]
  • 本文记录了 JavaScript 中正则表达式的使用方法和常见操作,包括匹配、替换、搜索等。 ... [详细]
  • 本文详细解析了客户端与服务器之间的交互过程,重点介绍了Socket通信机制。IP地址由32位的4个8位二进制数组成,分为网络地址和主机地址两部分。通过使用 `ipconfig /all` 命令,用户可以查看详细的IP配置信息。此外,文章还介绍了如何使用 `ping` 命令测试网络连通性,例如 `ping 127.0.0.1` 可以检测本机网络是否正常。这些技术细节对于理解网络通信的基本原理具有重要意义。 ... [详细]
  • 分享一款基于Java开发的经典贪吃蛇游戏实现
    本文介绍了一款使用Java语言开发的经典贪吃蛇游戏的实现。游戏主要由两个核心类组成:`GameFrame` 和 `GamePanel`。`GameFrame` 类负责设置游戏窗口的标题、关闭按钮以及是否允许调整窗口大小,并初始化数据模型以支持绘制操作。`GamePanel` 类则负责管理游戏中的蛇和苹果的逻辑与渲染,确保游戏的流畅运行和良好的用户体验。 ... [详细]
  • TensorFlow Lite在移动设备上的部署实践与优化笔记
    近期在探索如何将服务器端的模型迁移到移动设备上,并记录了一些关键问题和解决方案。本文假设读者具备以下基础知识:了解TensorFlow的计算图(Graph)、图定义(GraphDef)和元图定义(MetaGraphDef)。此外,文中还详细介绍了模型转换、性能优化和资源管理等方面的实践经验,为开发者提供有价值的参考。 ... [详细]
author-avatar
小辉0110_737
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有