热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch:如何从头开始训练一个CNN网络?

文章目录前言一、CNN?二、用单批量测试模型1.引入库2.读入数据集3.建造Module实例4.训练总结前言在刚开始学习DeepLearning时,一


文章目录

  • 前言
  • 一、CNN?
  • 二、用单批量测试模型
    • 1.引入库
    • 2.读入数据集
    • 3. 建造Module实例
    • 4. 训练
  • 总结




前言

在刚开始学习Deep Learning时,一件几乎不可能的事情就是知道每一个东西背后的原理和用法。但是,很多人又不得不在前期涉猎很多在前期不应该碰的东西。多摸索是好事,但考虑到性价比,最好的办法是有人带着你从头实现一下你所需要做的。在此,我希望本文是目标导向型的——即与你一同从头实现出一个属于你自己的CNN。




一、CNN?

CNN是卷积神经网络,在此,我们力图简洁(让各种trick xx去吧)。我们不会考虑太多的trick,只希望和你一同,实现一个最简单最简单的CNN。


二、用单批量测试模型

称之为”single-batch training“。 事实上,它只是一个用来测试用的、在模型正式训练之前的一个测试。这一步仅仅是为了测试你模型是否有毛病,如果你对你自己的模型有信心,大可不必进行这一步。

在这一步,我们仅需要用一个batch来进行训练,观察其Loss能否降到0(亦为过拟合overfitting),如能,则进行我们的正式训练。


1.引入库

import cv2
import mathimport numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.hub import load_state_dict_from_urlfrom PIL import Image
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets

2.读入数据集

至于数据集的读入,你可以用老师给的数据库,也可以利用torchvision自带的数据库(如果允许)。我在这篇1里讲到了三种具体的实现方法。但是在这里,我们用老师给的数据库。存放到相对目录的“./train_set/train_set”中。这里存放了30个类(划重点,下面要讲到)。

在这里插入图片描述

我们利用ImageFolder 来进行操作。非常方便。它会自动地读取有几类(几个子文件夹),编号等等工作。

BATCH_SIZE = 64
# normalization
train_transform = transforms.Compose([transforms.ColorJitter(hue=0.2, saturation=0.2, brightness=0.2),transforms.RandomAffine(degrees=10, translate=(0.1,0.1), scale=(0.9,1.1)),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor()
])# introduce the dataset
trainset = ImageFolder(root="./train_set/train_set", transform=train_transform)
train_loader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=False)

有几个坑的地方要注意一下:


  • transform是将图片进行预处理的一个参数。这个参数是需要依赖你自己的想法做的。具体可以自己去查一下,这里不再详细赘述。
  • 如果你决定用”single-batch training“来勘察你的模型维度设置等是否正确,你需要将shuffle设置为False(不随机抽取)。函数参数具体含义已经在pytorch:数据读取操作里。

3. 建造Module实例

所谓Module,我们即可以将其实例化为一个层,也可以将其实例化为一个网络(没错就是这么灵活,允许套娃)在此,我们继承nn.Module,设置一个CNN_model类。在 __init__方法中,我们使用将这个网络主要分为两个主要的模块:卷积模块(self.cnn)和全连接层模块(self.linear)。

我们利用 self.cnn = nn.Sequential()函数来整理一个大模块。它是有序的,模块将按照在传入构造器的顺序依次被添加到计算图中执行。注意,这里有30类,所以我们要保证最后一层的输出为30。

class CNN_model(nn.Module):def __init__(self):super(CNN_model, self).__init__()# conv1 self.cnn = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2), # 16x16x650nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), # 32x16x650nn.ReLU(),nn.Dropout2d(0.5),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), # 64x16x650nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2), # 64x8x325nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU() # 64x8x325)self.linear = nn.Sequential(nn.Linear(32768, 128),nn.ReLU(),nn.Linear(128, 256),nn.ReLU(),nn.Linear(256, 30)) def forward(self,x):x = self.cnn(x)fc_input=x.view(x.size(0),-1)fc_output = self.linear(fc_input)return fc_output# 将其实例化
cnn = CNN_model()
for param in cnn.parameters():print(param.shape)# 得到一个data,查看是否成功读取。
dataiter = iter(train_loader)
images, labels = dataiter.next()
print(labels)

4. 训练

在训练时,我们要确定我们需要什么损失函数。这又和具体的分类任务有关……如果想让本文篇幅稍微短点,恐怕还是绕过这个话题比较好。还记得我们所用的数据集吗?有30类对吧?我们用这CrossEntropyLoss()损失函数来完成分类的任务。

请仔细阅读代码与注释。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(cnn.parameters(), lr=0.001, momentum=0.9)for epoch in range(nepochs): # loop over the dataset multiple timescorrect = 0 # number of examples predicted correctly (for accuracy)total = 0 # number of examplesrunning_loss = 0.0 # accumulated loss (for mean loss)n = 0 # number of minibatches# 每个epoch,我们都取第一个batch来进行训练。dataiter = iter(train_loader) inputs, labels = dataiter.next()# ================== main body=========================# 将梯度清空(每轮epoch都清空)optimizer.zero_grad()# Forward, backward, and update parametersoutputs = cnn(inputs)loss = loss_fn(outputs, labels)loss.backward()optimizer.step()# =====================================================# accumulate lossrunning_loss += loss.item()n += 1# accumulate data for accuracy_, predicted = torch.max(outputs.data, 1)total += labels.size(0) # add in the number of labels in this minibatchcorrect += (predicted == labels).sum().item() # add in the number of correct labels# collect together statistics for this epochltrn = running_loss/natrn = correct/total print(f"epoch: {epoch} training loss: {ltrn: .3f} training accuracy: {atrn: .1%} ")

总结

可以看到,核心代码区区不过几行。你可能会问:这只是单个batch的呀?需要多个batch只需要一个循环就解决了!在下一篇《pytorch:如何从头开始训练一个CNN网络?(二)》中,我们还会学习到一些使你涨点的技巧。记住我们的学习哲学:

我们在这里只是工程导向,并不详细阐述背后原理。


推荐阅读
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 图像因存在错误而无法显示 ... [详细]
  • Android自定义控件绘图篇之Paint函数大汇总
    本文介绍了Android自定义控件绘图篇中的Paint函数大汇总,包括重置画笔、设置颜色、设置透明度、设置样式、设置宽度、设置抗锯齿等功能。通过学习这些函数,可以更好地掌握Paint的用法。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • vue使用
    关键词: ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 自动轮播,反转播放的ViewPagerAdapter的使用方法和效果展示
    本文介绍了如何使用自动轮播、反转播放的ViewPagerAdapter,并展示了其效果。该ViewPagerAdapter支持无限循环、触摸暂停、切换缩放等功能。同时提供了使用GIF.gif的示例和github地址。通过LoopFragmentPagerAdapter类的getActualCount、getActualItem和getActualPagerTitle方法可以实现自定义的循环效果和标题展示。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • Html5-Canvas实现简易的抽奖转盘效果
    本文介绍了如何使用Html5和Canvas标签来实现简易的抽奖转盘效果,同时使用了jQueryRotate.js旋转插件。文章中给出了主要的html和css代码,并展示了实现的基本效果。 ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • OpenMap教程4 – 图层概述
    本文介绍了OpenMap教程4中关于地图图层的内容,包括将ShapeLayer添加到MapBean中的方法,OpenMap支持的图层类型以及使用BufferedLayer创建图像的MapBean。此外,还介绍了Layer背景标志的作用和OMGraphicHandlerLayer的基础层类。 ... [详细]
author-avatar
pierce2502910693
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有