热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch学习(三)卷积神经网络

目录卷积运算特征提取同一层的某些神经元是权值共享注意代码卷积过程补充多卷积核RNN神经网络卷积运算卷积运算是一种特征提取方式,卷积运算结果即为提取的特征࿰

目录

  • 卷积运算
    • 特征提取
    • 同一层的某些神经元是权值共享
    • 注意
    • 代码
    • 卷积过程
    • 补充
      • 多卷积核
    • RNN神经网络


卷积运算

卷积运算是一种特征提取方式,卷积运算结果即为提取的特征,卷积核像一个筛子,将图像中符合条件(激活值越大越符合条件)的特征筛选出来。
卷积神经网络由卷积核而得名。
在这里插入图片描述
卷积核的作用就是把前一层输入的数据特征提取出来,

特征提取

通过一个线性计算映射和一个激励函数(可选),最后在最后一层成为一个数值形式的激励值。
如图输入一个55的图形,图形中的每个像素都是一个具体的数字,每个像素值为RGB具体表现。为了表示方便,用0代表黑色,用1代表白色,表示一个黑白图。用一个33的卷积核对其y=wx+b进行卷积,其中w=[1,0,1,0,1,0,1,0,1],就是11+01+11+01+11+01+10+00+1*1+0=4
这个4会被投射到后方的Feature Map中,就是卷积特征中左上角的点。

同一层的某些神经元是权值共享


注意

这不是矩阵乘法,是两个向量做点积
在这个例子,w=[1,0,1,0,1,0,1,0,1],b=0
在一次扫描后,卷积核会从左到右,从上到下扫描整幅图片,并将每次计算的结果放到convolved feature
如果一个Feature map要与前面的图片的尺寸保持一致,通常需要在边缘补0,这种操作叫做padding,
在这里插入图片描述
扫描过程是跳跃的,可以一次移动一个像素进行扫描。striding =1 如图,如果padding =2,得到的feature map与输入图片的尺寸是相同的。

代码

在代码中,Kernel Size是卷积核的尺寸,卷积核的数目叫做channel,就是使用多少个这样的卷积核对图片进行扫描。在训练的时候使用梯度下降找出各个卷积核的w取什么值的时候模型有更好的表现,一般来说,卷积网络需要的数量比全连接网络更少,训练速度更快,泛化能力更强。

卷积过程

在这里插入图片描述

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# Hyper parameters
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001# MNIST dataset
# 读取训练集
train_dataset = torchvision.datasets.MNIST(root='../../data/',train=True, transform=transforms.ToTensor(),download=True)
#读取测试集
test_dataset = torchvision.datasets.MNIST(root='../../data/',train=False, transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size, shuffle=False)# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):def __init__(self, num_classes=10):super(ConvNet, self).__init__()#定义第一个卷积层,卷积核的尺寸是5*5,输入通道为1,卷积核(输出通道)个数是16,步长是1像素,,padding=2self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),#对结果进行归一化,防止训练过程中各网络层的输入落入饱和区导致梯度消失,nn.BatchNorm2d(16),#使用激活函数对输出结果进行处理nn.ReLU(),#对图像输出采用最大池化的方法进行降采样nn.MaxPool2d(kernel_size=2, stride=2))#输入为16通道,输出为32通道,卷积核为5*5,padding=2self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),#进行批归一化nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))#一个全连接层,输入是7*7*32,输出是10个节点,就是将输出[1,7,*7*32]尺寸的向量作为x[1,1568]和一个[1568,10]的w相乘,再加上b,b是一个[1,10]的张量#7*7*32是如何得到的?如图self.fc = nn.Linear(7*7*32, num_classes)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = out.reshape(out.size(0), -1)out = self.fc(out)return outmodel = ConvNet(num_classes).to(device)# Loss and optimizer
#损失函数就是交叉熵损失函数
criterion = nn.CrossEntropyLoss()
#优化器是Adam
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))# Test the model
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

在这里插入图片描述
我们根据训练的结果可以看到比全连接层训练的结果准确率高。

补充

在卷积层是否需要padding ?根据输入图像边缘携带信息特征是否足以影响分类的准确性来决定 。降采样的作用是一种典型的降维操作,可以使用最大池化和平均池化。表现了一种以一定的信息损失为代价来换取空间和运算时间的取舍态度。
关于是用最大池化和平均池化根据自己对输出的要求。卷积层可以在任何网络结构中出现,是作为特征提取,池化层的作用是降维。降低模型的训练难度。
关于卷积核的数量,有几个通道的输出,就定义多少个卷积核。

多卷积核

上面只有100个参数的时候,表明只有1个10*10的卷积核,显然特征提取是不充分的,可以添加n个卷积核,可以学习n种特征。

RNN神经网络

循环神经网络最流行的就是LSTM,RNN与FF和CNN是不同的,FF是全连接层实现,CNN是卷积层实现,LSTM是用循环层实现的,就是把前一次输入的内容以及中间的激励值和这一次的输入值一起作为网络的输入,这样如果上一次的输入对这一次的输入有影响,那么这种影响也可以被学习,特别是对于文本处理的时候,根据上下文来推断这个词语的意思,用LSTM是很好的。


推荐阅读
  • 基因组浏览器中的Wig格式解析
    本文详细介绍了Wiggle(Wig)格式及其在基因组浏览器中的应用,涵盖variableStep和fixedStep两种主要格式的特点、适用场景及具体使用方法。同时,还提供了关于数据值和自定义参数的补充信息。 ... [详细]
  • 基于KVM的SRIOV直通配置及性能测试
    SRIOV介绍、VF直通配置,以及包转发率性能测试小慢哥的原创文章,欢迎转载目录?1.SRIOV介绍?2.环境说明?3.开启SRIOV?4.生成VF?5.VF ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • 本文详细介绍了 Dockerfile 的编写方法及其在网络配置中的应用,涵盖基础指令、镜像构建与发布流程,并深入探讨了 Docker 的默认网络、容器互联及自定义网络的实现。 ... [详细]
  • 深入解析Spring Cloud Ribbon负载均衡机制
    本文详细介绍了Spring Cloud中的Ribbon组件如何实现服务调用的负载均衡。通过分析其工作原理、源码结构及配置方式,帮助读者理解Ribbon在分布式系统中的重要作用。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • 本文详细介绍了如何构建一个高效的UI管理系统,集中处理UI页面的打开、关闭、层级管理和页面跳转等问题。通过UIManager统一管理外部切换逻辑,实现功能逻辑分散化和代码复用,支持多人协作开发。 ... [详细]
  • andr ... [详细]
  • Scala 实现 UTF-8 编码属性文件读取与克隆
    本文介绍如何使用 Scala 以 UTF-8 编码方式读取属性文件,并实现属性文件的克隆功能。通过这种方式,可以确保配置文件在多线程环境下的一致性和高效性。 ... [详细]
  • MySQL索引详解与优化
    本文深入探讨了MySQL中的索引机制,包括索引的基本概念、优势与劣势、分类及其实现原理,并详细介绍了索引的使用场景和优化技巧。通过具体示例,帮助读者更好地理解和应用索引以提升数据库性能。 ... [详细]
  • 深入探讨CPU虚拟化与KVM内存管理
    本文详细介绍了现代服务器架构中的CPU虚拟化技术,包括SMP、NUMA和MPP三种多处理器结构,并深入探讨了KVM的内存虚拟化机制。通过对比不同架构的特点和应用场景,帮助读者理解如何选择最适合的架构以优化性能。 ... [详细]
  • 配置Windows操作系统以确保DAW(数字音频工作站)硬件和软件的高效运行可能是一个复杂且令人沮丧的过程。本文提供了一系列专业建议,帮助你优化Windows系统,确保录音和音频处理的流畅性。 ... [详细]
  • 本文详细介绍了网络存储技术的基本概念、分类及应用场景。通过分析直连式存储(DAS)、网络附加存储(NAS)和存储区域网络(SAN)的特点,帮助读者理解不同存储方式的优势与局限性。 ... [详细]
  • 深入解析Redis内存对象模型
    本文详细介绍了Redis内存对象模型的关键知识点,包括内存统计、内存分配、数据存储细节及优化策略。通过实际案例和专业分析,帮助读者全面理解Redis内存管理机制。 ... [详细]
author-avatar
kg9854997
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有