热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

开发笔记:PaddlePaddle系列CIFAR10图像分类

本文由编程笔记#小编为大家整理,主要介绍了PaddlePaddle系列CIFAR-10图像分类相关的知识,希望对你有一定的参考价值。前言
本文由编程笔记#小编为大家整理,主要介绍了PaddlePaddle系列CIFAR-10图像分类相关的知识,希望对你有一定的参考价值。



前言

本文与前文对手写数字识别分类基本类似的,同样图像作为输入,类别作为输出。这里不同的是,不仅仅是使用简单的卷积神经网络加上全连接层的模型。卷积神经网络大火以来,发展出来许多经典的卷积神经网络模型,包括VGG、ResNet、AlexNet等等。下面将针对CIFAR-10数据集,对图像进行分类。

 


1、CIFAR-10数据集、Reader创建

CIFAR-10数据集分为5个batch的训练集和1个batch的测试集,每个batch包含10,000张图片。每张图像尺寸为32*32的RGB图像,且包含有标签。一共有10个标签:airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck十个类别。

技术分享图片

我在CIFAR-10网站中下载的是[CIFAR-10 python version](http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)。数据集完成后,解压得到上述六个文件。上述六个文件都是字典文件,使用cPickle模块即可读入。字典中‘data’需要重新定义维度为1000*32*32*3,维度分别代表[N H W C],即10,000张32*32尺寸的三通道(RGB)图像,再经过转换成为paddlepaddle读取的[N C H W ]维度形式;而字典‘labels’为10000个标签。如此一来,可以建立读取CIFAR-10的reader(与官方例程不同),如下:


技术分享图片技术分享图片

def reader_creator(ROOT,istrain=True,cycle=False):
def load_CIFAR_batch(filename):
""" load single batch of cifar """
with open(filename,
rb) as f:
datadict
= Pickle.load(f)
X
= datadict[data]
Y
= datadict[labels]
""" (N C H W) transpose to (N H W C) """
X
= X.reshape(10000,3,32,32).transpose(0,2,3,1).astype(float)
Y
= np.array(Y)
return X,Y
def reader():
while True:
if istrain:
for b in range(1,6):
f
= os.path.join(ROOT,data_batch_%d%(b))
X,Y
= load_CIFAR_batch(f)
length
= X.shape[0]
for i in range(length):
yield X[i],Y[i]
if not cycle:
break
else:
f
= os.path.join(ROOT,test_batch)
X,Y
= load_CIFAR_batch(f)
length
= X.shape[0]
for i in range(length):
yield X[i],Y[i]
if not cycle:
break
return reader


View Code

 


2、VGG网络

VGG网络采用“减小卷积核大小,增加卷积核数量”的思想改造而成,这里直接采用paddlepaddle例程中的VGG网络了,值得提醒的是paddlepaddle中直接有函数img_conv_group提供卷积、池化、dropout一组操作,所以根据VGG的模型,前面卷积层可以划分为5组,然后再经过3层的全连接层得到结果。

技术分享图片

PaddlePaddle例程中根据上图D网络,加入dorpout:


def vgg_bn_drop(input):
def conv_block(ipt, num_filter, groups, dropouts):
return fluid.nets.img_conv_group(
input
=ipt,
#一组的卷积层的卷积核总数,组成list[num_filter num_filter ...]
conv_num_filter=[num_filter] * groups,
conv_filter_size
=3,
conv_act
=relu,
conv_with_batchnorm
=True,
#每组卷积层各层的droput概率
conv_batchnorm_drop_rate=dropouts,
pool_size
=2,
pool_stride
=2,
pool_type
=max)
conv1
= conv_block(input, 64, 2, [0.3, 0]) #[0.3 0]即为第一组两层的dorpout概率,下同
conv2 = conv_block(conv1, 128, 2, [0.4, 0])
conv3
= conv_block(conv2, 256, 3, [0.4, 0.4, 0])
conv4
= conv_block(conv3, 512, 3, [0.4, 0.4, 0])
conv5
= conv_block(conv4, 512, 3, [0.4, 0.4, 0])
drop
= fluid.layers.dropout(x=conv5, dropout_prob=0.5)
fc1
= fluid.layers.fc(input=drop, size=512, act=None)
bn
= fluid.layers.batch_norm(input=fc1, act=relu)
drop2
= fluid.layers.dropout(x=bn, dropout_prob=0.5)
fc2
= fluid.layers.fc(input=drop2, size=512, act=None)
predict
= fluid.layers.fc(input=fc2, size=10, act=softmax)
return predict

 


3、训练

训练程序与上一节例程一样,同样是选取交叉熵作为损失函数,不多累赘讲述。


技术分享图片技术分享图片

def train_network():
predict
= inference_network()
label
= fluid.layers.data(name=label,shape=[1],dtype=int64)
cost
= fluid.layers.cross_entropy(input=predict,label=label)
avg_cost
= fluid.layers.mean(cost)
accuracy
= fluid.layers.accuracy(input=predict,label=label)
return [avg_cost,accuracy]
def optimizer_program():
return fluid.optimizer.Adam(learning_rate=0.001)
def train(data_path,save_path):
BATCH_SIZE
= 128
EPOCH_NUM
= 2
train_reader
= paddle.batch(
paddle.reader.shuffle(reader_creator(data_path),buf_size
=50000),
batch_size
= BATCH_SIZE)
test_reader
= paddle.batch(
reader_creator(data_path,False),
batch_size
=BATCH_SIZE)
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
if event.step % 100 == 0:
print("
Pass %d, Epoch %d, Cost %f, Acc %f
" %
(event.step, event.epoch, event.metrics[0],
event.metrics[
1]))
else:
sys.stdout.write(
.)
sys.stdout.flush()
if isinstance(event, fluid.EndEpochEvent):
avg_cost, accuracy
= trainer.test(
reader
=test_reader, feed_order=[image, label])
print(
Test with Pass {0}, Loss {1:2.2}, Acc {2:2.2}
.format(
event.epoch, avg_cost, accuracy))
if save_path is not None:
trainer.save_params(save_path)
place
= fluid.CUDAPlace(0)
trainer
= fluid.Trainer(
train_func
=train_network, optimizer_func=optimizer_program, place=place)
trainer.train(
reader
=train_reader,
num_epochs
=EPOCH_NUM,
event_handler
=event_handler,
feed_order
=[image, label])


View Code

4、测试接口

测试接口也类似,需要特别注意的是图像维度要改为[N C H W]的顺序!


技术分享图片技术分享图片

def infer(params_dir):
place
= fluid.CUDAPlace(0)
inferencer
= fluid.Inferencer(
infer_func
=inference_network, param_path=params_dir, place=place)
# Prepare testing data.
from PIL import Image
import numpy as np
import os
def load_image(file):
im
= Image.open(file)
im
= im.resize((32, 32), Image.ANTIALIAS)
im
= np.array(im).astype(np.float32)
"""transpose [H W C] to [C H W]"""
im
= im.transpose((2, 0, 1))
im
= im / 255.0
# Add one dimension, [N C H W] N=1
im = np.expand_dims(im, axis=0)
return im
cur_dir
= os.path.dirname(os.path.realpath(__file__))
img
= load_image(cur_dir + /dog.png)
# inference
results = inferencer.infer({image: img})
print(results)
lab
= np.argsort(results) # probs and lab are the results of one batch data
print("infer results: ", cifar_classes[lab[0][0][-1]])


View Code

5、运行结果

由于笔者没有GPU服务器,所以只迭代了50次,已经用了8个多小时,但是准确率只有15.6%,测试集方面准确率有17%,效果不理想,用于验证的结果也是错的!


Pass 300, Epoch 49, Cost 2.261115, Acc 0.156250
.........................................................................................
Test with Pass
49, Loss 2.2, Acc 0.17
Classify the cifar10 images...
[array([[
0.05997971, 0.13485196, 0.096842 , 0.09973737, 0.11053724,
0.08180068, 0.13847008, 0.08627985, 0.06851784, 0.12298328]],
dtype
=float32)]
infer results: frog

 


结语

网络比较深,且数据集比较大,训练时间比较长,普通笔记本上面的GT840M聊以胜无吧。

 

本文代码:02_cifar

参考:book/03.image_classification/

 


推荐阅读
  • 深入解析轻量级数据库 SQL Server Express LocalDB
    本文详细介绍了 SQL Server Express LocalDB,这是一种轻量级的本地 T-SQL 数据库解决方案,特别适合开发环境使用。文章还探讨了 LocalDB 与其他轻量级数据库的对比,并提供了安装和连接 LocalDB 的步骤。 ... [详细]
  • 详解MyBatis二级缓存的启用与配置
    本文深入探讨了MyBatis二级缓存的启用方法及其配置细节,通过具体的代码实例进行说明,有助于开发者更好地理解和应用这一特性,提升应用程序的性能。 ... [详细]
  • 为帮助编程爱好者更好地掌握Python和Go语言的核心技能,我们特别提供两本精选图书的免费赠阅机会。《易懂的Python算法指南》适合所有希望提高算法理解能力的读者,《Go语言编程从入门到精通》则面向对Go语言感兴趣的初学者及有一定基础的开发者。 ... [详细]
  • 解决远程桌面连接时的身份验证错误问题
    本文介绍了如何解决在尝试远程访问服务器时遇到的身份验证错误,特别是当系统提示‘要求的函数不受支持’时的具体解决步骤。通过调整Windows注册表设置,您可以轻松解决这一常见问题。 ... [详细]
  • iOS 小组件开发指南
    本文详细介绍了iOS小部件(Widget)的开发流程,从环境搭建、证书配置到业务逻辑实现,提供了一系列实用的技术指导与代码示例。 ... [详细]
  • 本文详细介绍了Oracle RMAN中的增量备份机制,重点解析了差异增量和累积增量备份的概念及其在不同Oracle版本中的实现。通过对比两种备份方式的特点,帮助读者选择合适的备份策略。 ... [详细]
  • SQL 数据恢复技巧:利用快照实现高效恢复
    本文详细介绍了如何在 SQL 中通过数据库快照实现数据恢复,包括快照的创建、使用及恢复过程,旨在帮助读者深入了解这一技术并有效应用于实际场景。 ... [详细]
  • 构建Python自助式数据查询系统
    在现代数据密集型环境中,业务团队频繁需要从数据库中提取特定信息。为了提高效率并减少IT部门的工作负担,本文探讨了一种利用Python语言实现的自助数据查询工具的设计与实现。 ... [详细]
  • 本文详细介绍了如何在本地环境中安装配置Frida及其服务器组件,以及如何通过Frida进行基本的应用程序动态分析,包括获取应用版本和加载的类信息。 ... [详细]
  • 使用Pandas DataFrame探索十大城市房价与薪资对比
    在本篇文章中,我们将通过Pandas库中的DataFrame工具,深入了解中国十大城市的房价与薪资水平,探讨哪些城市的生活成本更为合理。这是学习Python数据分析系列的第82篇原创文章,预计阅读时间约为6分钟。 ... [详细]
  • 本文详细解析 Skynet 的启动流程,包括配置文件的读取、环境变量的设置、主要线程的启动(如 timer、socket、monitor 和 worker 线程),以及消息队列的实现机制。 ... [详细]
  • 本文介绍了如何通过编辑特定配置文件来自定义Linux系统中Bash的登录界面以及登录成功后的显示信息,包括本地和远程连接时的提示。 ... [详细]
  • 本文介绍了基于Java的在线办公工作流系统的毕业设计方案,涵盖了MyBatis框架的应用、源代码分析、调试与部署流程、数据库设计以及相关论文撰写指导。 ... [详细]
  • 雨林木风 GHOST XP SP3 经典珍藏版 YN2014.04
    雨林木风 GHOST XP SP3 经典珍藏版 YN2014.04 ... [详细]
  • 探索OpenWrt中的LuCI框架
    本文深入探讨了OpenWrt系统中轻量级HTTP服务器uhttpd的工作原理及其配置,重点介绍了LuCI界面的实现机制。 ... [详细]
author-avatar
捕鱼达人2602914975
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有