热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch学习笔记(八):图像增强、ResNet完成Cifar10分类

一.图像增强的方法一直以来,图像识别这一计算机视觉的核心问题都面临很多挑战,同一个物体在不同情况下都会得出不同的结论。对于一张图片,我们看到的是一些物体,而计算机看到的是一些像素点

一. 图像增强的方法

一直以来,图像识别这一计算机视觉的核心问题都面临很多挑战,同一个物体在不同情况下都会得出不同的结论。对于一张图片,我们看到的是一些物体,而计算机看到的是一些像素点。

如果拍摄照片的照相机位置发生了改变,那么拍得的图片对于我们而言,变化很小,但是对于计算机而言,图片得像素变化是很大得。拍照时得光照条件也是很重要的一个影响因素:光照太弱,照片里的物体会和背景融为一体,它们的像素点就会很接近,计算机就无法正确识别出物体。除此之外,物体本身的变形也会对计算机识别造成障碍,比如一只猫是趴着的,计算机能够识别它,但如果猫换个姿势,变成躺着的状态,那么计算机就无法识别了。最后,物体本身会隐藏在一些遮蔽物中,这样物体只呈现出局部的信息,计算机也难以识别。

针对这些问题,我们希望可以对原始图片进行增强,在一定程度上解决部分问题。在PyTorch中已经内置了一些图像增强的方法,不需要再繁琐地去实现,只需要简单的调用。

torchvision.transforms包括所有图像增强的方法:

  • 第一个函数是 Scale,对图片的尺寸进行缩小或者放大;
  • 第二个函数是 CenterCrop,对图像正中心进行给定大小的裁剪;
  • 第三个函数是 RandomCrop,对图片进行给定大小的随机裁剪;
  • 第四个函数是 RandomHorizaontalFlip,对图片进行概率为0.5的随机水平翻转;
  • 第五个函数是 RandomSizedCrop,首先对图片进行随机尺寸的裁剪,然后再对裁剪的图片进行一个随机比例的缩放,最后将图片变成给定的大小,这在InceptionNet中比较流行;
  • 最后一个是 pad,对图片进行边界零填充;

上面介绍了PyTorch内置的一些图像增强的方法,还有更多的增强方法,可以使用OpenCV或者PIL等第三方图形库实现。在网络的训练的过程中图形增强是一种常见、默认的做法,对多任务进行图像增强之后能够在一定程度上提升任务的准确率。

二. 实现 CIFAR-10 分类

cifar 10数据集有60000张图片,每张图片都是 32×32 的三通道的彩色图,一共是10个类别,每种类别有6000张图片。下面实现ResNet来处理cifar 10数据集,完成图像分类。

注意的是下面的代码只对训练图片进行图像增强,提高其泛化能力,对于测试集,仅对其中心化,不做其他的图像增强。

import torch
from torch import nn, optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchstat import stat
from torch.autograd import Variable
# 读数据
def get_data():
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=train_transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform, download=True)
print(len(train_dataset))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, drop_last=True)
return train_loader, test_loader
def conv3x3(in_channels, out_channels, stride=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
# Residual Block
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super(ResidualBlock, self).__init__()
self.conv1 = conv3x3(in_channels, out_channels, stride)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(True)
self.conv2 = conv3x3(out_channels, out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
# 构建网络
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=10):
super(ResNet, self).__init__()
self.in_channels = 16 # 64, 3, 32, 32
self.conv = conv3x3(3, 16) # 64, 16, 32, 32
self.bn = nn.BatchNorm2d(16)
self.relu = nn.ReLU(True)
self.layer1 = self.make_layer(block, 16, layers[0]) # 64, 16, 32, 32
self.layer2 = self.make_layer(block, 32, layers[0], 2) # 64, 32, 16, 16
self.layer3 = self.make_layer(block, 64, layers[1], 2) # 64, 64, 8, 8
self.avg_pool = nn.AvgPool2d(8) # 64, 64, 1, 1
self.fc = nn.Linear(64, num_classes)
def make_layer(self, block, out_channles, blocks, stride=1):
downsample = None
if out_channles != self.in_channels or stride != 1:
downsample = nn.Sequential(conv3x3(self.in_channels, out_channles, stride=stride), nn.BatchNorm2d(out_channles))
layers = []
layers.append(block(self.in_channels, out_channles, stride, downsample))
self.in_channels = out_channles
for i in range(1, blocks):
layers.append(block(out_channles, out_channles))
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv(x)
out = self.bn(out)
out = self.relu(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.avg_pool(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
if __name__ == "__main__":
# 超参数配置
batch_size = 64
learning_rate = 1e-2
num_epoches = 100
# 训练图片的预处理方式
train_transform = transforms.Compose([transforms.Scale(40), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32),
transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# 加载数据集
train_dataset, test_dataset = get_data()
# 构建模型
# model = ResNet(ResidualBlock, [3, 4])
model = torch.load('resnet_model.pth')
stat(model, (3, 32, 32))
if torch.cuda.is_available():
model = model.cuda()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
schedule_lr = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# 开始训练
for i in range(num_epoches):
j = 0
for img, label in train_dataset:
model.train()
schedule_lr.step()
img = Variable(img)
label = Variable(label)
# forward
out = model(img)
loss = criterion(out, label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print
print("epoch= {},j= {}, loss is {}".format(i, j, loss))
#print(list(model.children())[-1].weight)
j += 1
if j % 100 == 0:
torch.save(model, './resnet_model.pth')
# test
model.eval()
count = 0
print(len(test_dataset))
for img, label in test_dataset:
img = Variable(img)
out = model(img)
_, predict = torch.max(out, 1)
if predict == label:
count += 1
print(count / len(test_dataset))


推荐阅读
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • [echarts] 同指标对比柱状图相关的知识介绍及应用示例
    本文由编程笔记小编为大家整理,主要介绍了echarts同指标对比柱状图相关的知识,包括对比课程通过率最高的8个课程和最低的8个课程以及全校的平均通过率。文章提供了一个应用示例,展示了如何使用echarts制作同指标对比柱状图,并对代码进行了详细解释和说明。该示例可以帮助读者更好地理解和应用echarts。 ... [详细]
  • 本文介绍了利用ARMA模型对平稳非白噪声序列进行建模的步骤及代码实现。首先对观察值序列进行样本自相关系数和样本偏自相关系数的计算,然后根据这些系数的性质选择适当的ARMA模型进行拟合,并估计模型中的位置参数。接着进行模型的有效性检验,如果不通过则重新选择模型再拟合,如果通过则进行模型优化。最后利用拟合模型预测序列的未来走势。文章还介绍了绘制时序图、平稳性检验、白噪声检验、确定ARMA阶数和预测未来走势的代码实现。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • Day2列表、字典、集合操作详解
    本文详细介绍了列表、字典、集合的操作方法,包括定义列表、访问列表元素、字符串操作、字典操作、集合操作、文件操作、字符编码与转码等内容。内容详实,适合初学者参考。 ... [详细]
  • IjustinheritedsomewebpageswhichusesMooTools.IneverusedMooTools.NowIneedtoaddsomef ... [详细]
  • 本文介绍了在MFC下利用C++和MFC的特性动态创建窗口的方法,包括继承现有的MFC类并加以改造、插入工具栏和状态栏对象的声明等。同时还提到了窗口销毁的处理方法。本文详细介绍了实现方法并给出了相关注意事项。 ... [详细]
  • 十大经典排序算法动图演示+Python实现
    本文介绍了十大经典排序算法的原理、演示和Python实现。排序算法分为内部排序和外部排序,常见的内部排序算法有插入排序、希尔排序、选择排序、冒泡排序、归并排序、快速排序、堆排序、基数排序等。文章还解释了时间复杂度和稳定性的概念,并提供了相关的名词解释。 ... [详细]
  • 本文分享了一个关于在C#中使用异步代码的问题,作者在控制台中运行时代码正常工作,但在Windows窗体中却无法正常工作。作者尝试搜索局域网上的主机,但在窗体中计数器没有减少。文章提供了相关的代码和解决思路。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • Vagrant虚拟化工具的安装和使用教程
    本文介绍了Vagrant虚拟化工具的安装和使用教程。首先介绍了安装virtualBox和Vagrant的步骤。然后详细说明了Vagrant的安装和使用方法,包括如何检查安装是否成功。最后介绍了下载虚拟机镜像的步骤,以及Vagrant镜像网站的相关信息。 ... [详细]
  • EPPlus绘制刻度线的方法及示例代码
    本文介绍了使用EPPlus绘制刻度线的方法,并提供了示例代码。通过ExcelPackage类和List对象,可以实现在Excel中绘制刻度线的功能。具体的方法和示例代码在文章中进行了详细的介绍和演示。 ... [详细]
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
author-avatar
夏天电艹热毯
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有