热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch学习(三)卷积神经网络

目录卷积运算特征提取同一层的某些神经元是权值共享注意代码卷积过程补充多卷积核RNN神经网络卷积运算卷积运算是一种特征提取方式,卷积运算结果即为提取的特征࿰

目录

  • 卷积运算
    • 特征提取
    • 同一层的某些神经元是权值共享
    • 注意
    • 代码
    • 卷积过程
    • 补充
      • 多卷积核
    • RNN神经网络


卷积运算

卷积运算是一种特征提取方式,卷积运算结果即为提取的特征,卷积核像一个筛子,将图像中符合条件(激活值越大越符合条件)的特征筛选出来。
卷积神经网络由卷积核而得名。
在这里插入图片描述
卷积核的作用就是把前一层输入的数据特征提取出来,

特征提取

通过一个线性计算映射和一个激励函数(可选),最后在最后一层成为一个数值形式的激励值。
如图输入一个55的图形,图形中的每个像素都是一个具体的数字,每个像素值为RGB具体表现。为了表示方便,用0代表黑色,用1代表白色,表示一个黑白图。用一个33的卷积核对其y=wx+b进行卷积,其中w=[1,0,1,0,1,0,1,0,1],就是11+01+11+01+11+01+10+00+1*1+0=4
这个4会被投射到后方的Feature Map中,就是卷积特征中左上角的点。

同一层的某些神经元是权值共享


注意

这不是矩阵乘法,是两个向量做点积
在这个例子,w=[1,0,1,0,1,0,1,0,1],b=0
在一次扫描后,卷积核会从左到右,从上到下扫描整幅图片,并将每次计算的结果放到convolved feature
如果一个Feature map要与前面的图片的尺寸保持一致,通常需要在边缘补0,这种操作叫做padding,
在这里插入图片描述
扫描过程是跳跃的,可以一次移动一个像素进行扫描。striding =1 如图,如果padding =2,得到的feature map与输入图片的尺寸是相同的。

代码

在代码中,Kernel Size是卷积核的尺寸,卷积核的数目叫做channel,就是使用多少个这样的卷积核对图片进行扫描。在训练的时候使用梯度下降找出各个卷积核的w取什么值的时候模型有更好的表现,一般来说,卷积网络需要的数量比全连接网络更少,训练速度更快,泛化能力更强。

卷积过程

在这里插入图片描述

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# Hyper parameters
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001# MNIST dataset
# 读取训练集
train_dataset = torchvision.datasets.MNIST(root='../../data/',train=True, transform=transforms.ToTensor(),download=True)
#读取测试集
test_dataset = torchvision.datasets.MNIST(root='../../data/',train=False, transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size, shuffle=False)# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):def __init__(self, num_classes=10):super(ConvNet, self).__init__()#定义第一个卷积层,卷积核的尺寸是5*5,输入通道为1,卷积核(输出通道)个数是16,步长是1像素,,padding=2self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),#对结果进行归一化,防止训练过程中各网络层的输入落入饱和区导致梯度消失,nn.BatchNorm2d(16),#使用激活函数对输出结果进行处理nn.ReLU(),#对图像输出采用最大池化的方法进行降采样nn.MaxPool2d(kernel_size=2, stride=2))#输入为16通道,输出为32通道,卷积核为5*5,padding=2self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),#进行批归一化nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))#一个全连接层,输入是7*7*32,输出是10个节点,就是将输出[1,7,*7*32]尺寸的向量作为x[1,1568]和一个[1568,10]的w相乘,再加上b,b是一个[1,10]的张量#7*7*32是如何得到的?如图self.fc = nn.Linear(7*7*32, num_classes)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = out.reshape(out.size(0), -1)out = self.fc(out)return outmodel = ConvNet(num_classes).to(device)# Loss and optimizer
#损失函数就是交叉熵损失函数
criterion = nn.CrossEntropyLoss()
#优化器是Adam
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))# Test the model
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

在这里插入图片描述
我们根据训练的结果可以看到比全连接层训练的结果准确率高。

补充

在卷积层是否需要padding ?根据输入图像边缘携带信息特征是否足以影响分类的准确性来决定 。降采样的作用是一种典型的降维操作,可以使用最大池化和平均池化。表现了一种以一定的信息损失为代价来换取空间和运算时间的取舍态度。
关于是用最大池化和平均池化根据自己对输出的要求。卷积层可以在任何网络结构中出现,是作为特征提取,池化层的作用是降维。降低模型的训练难度。
关于卷积核的数量,有几个通道的输出,就定义多少个卷积核。

多卷积核

上面只有100个参数的时候,表明只有1个10*10的卷积核,显然特征提取是不充分的,可以添加n个卷积核,可以学习n种特征。

RNN神经网络

循环神经网络最流行的就是LSTM,RNN与FF和CNN是不同的,FF是全连接层实现,CNN是卷积层实现,LSTM是用循环层实现的,就是把前一次输入的内容以及中间的激励值和这一次的输入值一起作为网络的输入,这样如果上一次的输入对这一次的输入有影响,那么这种影响也可以被学习,特别是对于文本处理的时候,根据上下文来推断这个词语的意思,用LSTM是很好的。


推荐阅读
  • 探讨ChatGPT在法律和版权方面的潜在风险及影响,分析其作为内容创造工具的合法性和合规性。 ... [详细]
  • 本文介绍了如何在 C# 和 XNA 框架中实现一个自定义的 3x3 矩阵类(MMatrix33),旨在深入理解矩阵运算及其应用场景。该类参考了 AS3 Starling 和其他相关资源,以确保算法的准确性和高效性。 ... [详细]
  • 深入解析Java枚举及其高级特性
    本文详细介绍了Java枚举的概念、语法、使用规则和应用场景,并探讨了其在实际编程中的高级应用。所有相关内容已收录于GitHub仓库[JavaLearningmanual](https://github.com/Ziphtracks/JavaLearningmanual),欢迎Star并持续关注。 ... [详细]
  • 本文将详细探讨 Java 中提供的不可变集合(如 `Collections.unmodifiableXXX`)和同步集合(如 `Collections.synchronizedXXX`)的实现原理及使用方法,帮助开发者更好地理解和应用这些工具。 ... [详细]
  • 本文探讨了如何使用pg-promise库在PostgreSQL中高效地批量插入多条记录,包括通过事务和单一查询两种方法。 ... [详细]
  • 本教程详细介绍了如何使用 TensorFlow 2.0 构建和训练多层感知机(MLP)网络,涵盖回归和分类任务。通过具体示例和代码实现,帮助初学者快速掌握 TensorFlow 的核心概念和操作。 ... [详细]
  • 社交网络中的级联行为 ... [详细]
  • 本文深入探讨了SQL数据库中常见的面试问题,包括如何获取自增字段的当前值、防止SQL注入的方法、游标的作用与使用、索引的形式及其优缺点,以及事务和存储过程的概念。通过详细的解答和示例,帮助读者更好地理解和应对这些技术问题。 ... [详细]
  • 本文探讨了如何在 F# Interactive (FSI) 中通过 AddPrinter 和 AddPrintTransformer 方法自定义类型(尤其是集合类型)的输出格式,提供了详细的指南和示例代码。 ... [详细]
  • 本文探讨了如何通过预处理器开关选择不同的类实现,并解决在特定情况下遇到的链接器错误。 ... [详细]
  • 本文详细探讨了Android Activity中View的绘制流程和动画机制,包括Activity的生命周期、View的测量、布局和绘制过程以及动画对View的影响。通过实验验证,澄清了一些常见的误解,并提供了代码示例和执行结果。 ... [详细]
  • LeetCode 690:计算员工的重要性评分
    在解决LeetCode第690题时,我记录了详细的解题思路和方法。该问题要求根据员工的ID计算其重要性评分,包括直接和间接下属的重要性。本文将深入探讨如何使用哈希表(Map)来高效地实现这一目标。 ... [详细]
  • Python 工具推荐 | PyHubWeekly 第二十一期:提升命令行体验的五大工具
    本期 PyHubWeekly 为大家精选了 GitHub 上五个优秀的 Python 工具,涵盖金融数据可视化、终端美化、国际化支持、图像增强和远程 Shell 环境配置。欢迎关注并参与项目。 ... [详细]
  • 一个登陆界面
    预览截图html部分123456789101112用户登入1314邮箱名称邮箱为空15密码密码为空16登 ... [详细]
  • 理解与应用:独热编码(One-Hot Encoding)
    本文详细介绍了独热编码(One-Hot Encoding)与哑变量编码(Dummy Encoding)两种方法,用于将分类变量转换为数值形式,以便于机器学习算法处理。文章不仅解释了这两种编码方式的基本原理,还探讨了它们在实际应用中的差异及选择依据。 ... [详细]
author-avatar
kg9854997
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有