热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

手写数字识别卷积神经网络

内容都是百度AIstudio的内容,我只是在这里做个笔记,不是原创。#合并后代码#数据处理部分之前的代码,保持不变importos

内容都是百度AIstudio的内容,我只是在这里做个笔记,不是原创。 

#合并后代码
#数据处理部分之前的代码,保持不变
import os
import random
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
import numpy as np
import matplotlib.pyplot as plt
from PIL import Imageimport gzip
import json#数据处理部分的展开代码
# 定义数据集读取器
def load_data(mode='train'):# 数据文件datafile = './work/mnist.json.gz'print('loading mnist dataset from {} ......'.format(datafile))data = json.load(gzip.open(datafile))# 读取到的数据可以直接区分训练集,验证集,测试集train_set, val_set, eval_set = data# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLSIMG_ROWS = 28IMG_COLS = 28# 获得数据if mode == 'train':imgs = train_set[0]labels = train_set[1]elif mode == 'valid':imgs = val_set[0]labels = val_set[1]elif mode == 'eval':imgs = eval_set[0]labels = eval_set[1]else:raise Exception("mode can only be one of ['train', 'valid', 'eval']")imgs_length = len(imgs)assert len(imgs) == len(labels), \"length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(labels))index_list = list(range(imgs_length))# 读入数据时用到的batchsizeBATCHSIZE = 100# 定义数据生成器def data_generator():if mode == 'train':# 训练模式下,将训练数据打乱random.shuffle(index_list)imgs_list = []labels_list = []for i in index_list:img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')label = np.reshape(labels[i], [1]).astype('float32')imgs_list.append(img) labels_list.append(label)if len(imgs_list) == BATCHSIZE:# 产生一个batch的数据并返回yield np.array(imgs_list), np.array(labels_list)# 清空数据读取列表imgs_list = []labels_list = []# 如果剩余数据的数目小于BATCHSIZE,# 则剩余数据一起构成一个大小为len(imgs_list)的mini-batchif len(imgs_list) > 0:yield np.array(imgs_list), np.array(labels_list)return data_generator#数据处理部分之后的代码,数据读取的部分调用Load_data函数
# 定义网络结构,同上一节所使用的网络结构
class MNIST(fluid.dygraph.Layer):def __init__(self, name_scope):super(MNIST, self).__init__(name_scope)name_scope = self.full_name()# 定义卷积层,输出特征通道num_filters设置为20,卷积核的大小filter_size为5,卷积步长stride=1,padding=2# 激活函数使用reluself.conv1 = Conv2D(name_scope, num_filters=20, filter_size=5, stride=1, padding=2, act='relu')# 定义池化层,池化核pool_size=2,池化步长为2,选择最大池化方式self.pool1 = Pool2D(name_scope, pool_size=2, pool_stride=2, pool_type='max')# 定义卷积层,输出特征通道num_filters设置为20,卷积核的大小filter_size为5,卷积步长stride=1,padding=2self.conv2 = Conv2D(name_scope, num_filters=20, filter_size=5, stride=1, padding=2, act='relu')# 定义池化层,池化核pool_size=2,池化步长为2,选择最大池化方式self.pool2 = Pool2D(name_scope, pool_size=2, pool_stride=2, pool_type='max')# 定义一层全连接层,输出维度是1,不使用激活函数self.fc = FC(name_scope, size=1, act=None)def forward(self, inputs):# outputs = self.fc(inputs)# return outputsx = self.conv1(inputs)x = self.pool1(x)x = self.conv2(x)x = self.pool2(x)x = self.fc(x)return x# 训练配置,并启动训练过程
with fluid.dygraph.guard():model = MNIST("mnist")model.train()#调用加载数据的函数train_loader = load_data('train')# 创建异步数据读取器place = fluid.CPUPlace()data_loader = fluid.io.DataLoader.from_generator(capacity=5, return_list=True)data_loader.set_batch_generator(train_loader, places=place)optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001)EPOCH_NUM = 6for epoch_id in range(EPOCH_NUM):for batch_id, data in enumerate(data_loader):image_data, label_data = dataimage = fluid.dygraph.to_variable(image_data)label = fluid.dygraph.to_variable(label_data)predict = model(image)loss = fluid.layers.square_error_cost(predict, label)avg_loss = fluid.layers.mean(loss)if batch_id % 200 == 0:print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))avg_loss.backward()optimizer.minimize(avg_loss)model.clear_gradients()fluid.save_dygraph(model.state_dict(), 'mnist')

 

with fluid.dygraph.guard():print('start evaluation .......')#加载模型参数model = MNIST("mnist")model_state_dict, _ = fluid.load_dygraph('mnist')model.load_dict(model_state_dict)model.eval()eval_loader = load_data('eval')acc_set = []avg_loss_set = []cnt=0sum=0for batch_id, data in enumerate(eval_loader()):x_data, y_data = dataimg = fluid.dygraph.to_variable(x_data)label = fluid.dygraph.to_variable(y_data)prediction= model(img)prediction=prediction.numpy().astype('int32')label=label.numpy().astype('int32')for i in range(len(label)):# print(prediction[i],label[i])if(prediction[i]==label[i]):cnt+=1sum+=100# print(len(prediction))# print("hello:",prediction.numpy().astype('int32'),label.numpy().astype('int32'))# loss = fluid.layers.square_error_cost(input=prediction, label=label)# avg_loss = fluid.layers.mean(loss)# acc_set.append(float(acc.numpy()))# avg_loss_set.append(float(avg_loss.numpy()))#计算多个batch的平均损失和准确率# acc_val_mean = np.array(acc_set).mean()# avg_loss_val_mean = np.array(avg_loss_set).mean()print("acc:",cnt/sum)# print('loss={}, acc={}'.format(avg_loss_val_mean, acc_val_mean))

单张图片测试

def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')print(im)im = im.resize((28, 28), Image.ANTIALIAS)im = np.reshape(im,[1,1,28,28]).astype(np.float32)# [batchsize, channels, rows, cols]是这样的形式。1是因为mnist是灰度图只有一个通道,# 一般训练和预测的时候都是指定batchsize大小的,单张测试就取了batchsize=1# 图像归一化,保持和数据集的数据范围一致# print(im)im = 1- im / 127.5 # print(im)img=np.array(im).reshape(28,28).astype(np.float32)plt.imshow(img)return im# 定义预测过程
with fluid.dygraph.guard():model = MNIST("mnist")params_file_path = 'mnist'img_path = './work/example_8.png'# 加载模型参数model_dict, _ = fluid.load_dygraph("mnist")model.load_dict(model_dict)model.eval()tensor_img = load_image(img_path)result = model(fluid.dygraph.to_variable(tensor_img))# #预测输出取整,即为预测的数字print("本次预测的数字是", result.numpy().astype('int32'))

 


推荐阅读
  • voc生成xml 代码
    目录 lxmlwindows安装 读取示例 可视化 生成示例 上面是代码,下面有调用示例 api调用代码,其实只有几行:这个生成代码也很简 ... [详细]
  • 在第七天的深度学习课程中,我们将重点探讨DGL框架的高级应用,特别是在官方文档指导下进行数据集的下载与预处理。通过详细的步骤说明和实用技巧,帮助读者高效地构建和优化图神经网络的数据管道。此外,我们还将介绍如何利用DGL提供的模块化工具,实现数据的快速加载和预处理,以提升模型训练的效率和准确性。 ... [详细]
  • 使用PyQt5与OpenCV实现电脑摄像头的图像捕捉功能
    本文介绍了如何使用Python中的PyQt5和OpenCV库来实现电脑摄像头的图像捕捉功能。通过结合这两个强大的工具,用户可以轻松地打开摄像头并进行实时图像采集和处理。代码示例展示了如何初始化摄像头、捕获图像并将其显示在PyQt5的图形界面中。此外,还提供了详细的步骤说明和代码注释,帮助开发者快速上手并实现相关功能。 ... [详细]
  • 本文提供了PyTorch框架中常用的预训练模型的下载链接及详细使用指南,涵盖ResNet、Inception、DenseNet、AlexNet、VGGNet等六大分类模型。每种模型的预训练参数均经过精心调优,适用于多种计算机视觉任务。文章不仅介绍了模型的下载方式,还详细说明了如何在实际项目中高效地加载和使用这些模型,为开发者提供全面的技术支持。 ... [详细]
  • Go语言实现Redis客户端与服务器的交互机制深入解析
    在前文对Godis v1.0版本的基础功能进行了详细介绍后,本文将重点探讨如何实现客户端与服务器之间的交互机制。通过具体代码实现,使客户端与服务器能够顺利通信,赋予项目实际运行的能力。本文将详细解析Go语言在实现这一过程中的关键技术和实现细节,帮助读者深入了解Redis客户端与服务器的交互原理。 ... [详细]
  • Android目录遍历工具 | AppCrawler自动化测试进阶(第二部分):个性化配置详解
    终于迎来了“足不出户也能为社会贡献力量”的时刻,但有追求的测试工程师绝不会让自己的生活变得乏味。与其在家消磨时光,不如利用这段时间深入研究和提升自己的技术能力,特别是对AppCrawler自动化测试工具的个性化配置进行详细探索。这不仅能够提高测试效率,还能为项目带来更多的价值。 ... [详细]
  • Python正则表达式详解:掌握数量词用法轻松上手
    Python正则表达式详解:掌握数量词用法轻松上手 ... [详细]
  • 本文探讨了如何利用Python的反射机制,高效地将Excel中的数据映射并转换为类对象属性。通过反射技术,可以动态地读取Excel文件中的数据,并将其加载到内存中,转换为相应的类对象,从而方便进行后续的数据处理和操作。该方法适用于需要频繁从Excel导入数据的场景,能够显著提高开发效率和代码可维护性。 ... [详细]
  • 在财务分析与金融数据处理中,利用Python的强大库如NumPy和SciPy可以高效地计算各种财务指标。例如,通过调用这些库中的函数,可以轻松计算货币的时间价值,包括终值(FV)等关键指标。此外,这些库还提供了丰富的统计和数学工具,有助于进行更深入的数据分析和模型构建。 ... [详细]
  • 开发技巧分享:利用套索与矩形选择工具高效选取绘图中的全部字形节点
    开发技巧分享:利用套索与矩形选择工具高效选取绘图中的全部字形节点 ... [详细]
  • 抠图前vsPython自动抠图后在日常的工作和生活中,我们经常会遇到需要抠图的场景,即便是只有一张图片需要抠,也会抠得我们不耐烦ÿ ... [详细]
  • 在Spring Boot项目中,通过YAML配置文件为静态变量设置值的方法与实践涉及以下几个步骤:首先,创建一个新的配置类。需要注意的是,自动生成的setter方法默认是非静态的,因此需要手动将其修改为静态方法,以确保静态变量能够正确初始化。此外,建议使用`@Value`注解或`@ConfigurationProperties`注解来注入配置属性,以提高代码的可读性和维护性。 ... [详细]
  • 在Hive中合理配置Map和Reduce任务的数量对于优化不同场景下的性能至关重要。本文探讨了如何控制Hive任务中的Map数量,分析了当输入数据超过128MB时是否会自动拆分,以及Map数量是否越多越好的问题。通过实际案例和实验数据,本文提供了具体的配置建议,帮助用户在不同场景下实现最佳性能。 ... [详细]
  • 本文详细介绍了图表图例的语法与配置方法,包括如何通过 `loc` 参数设置图例的位置。具体位置选项包括:'best'(自动选择最佳位置)、'upper right'、'upper left'、'lower left' 和 'lower right' 等。此外,还探讨了其他高级配置选项,如图例的字体大小、边框样式和透明度等,以帮助用户更好地定制图表图例。 ... [详细]
  • 本文深入探讨了 C# 中 `SqlCommand` 和 `SqlDataAdapter` 的核心差异及其应用场景。`SqlCommand` 主要用于执行单一的 SQL 命令,并通过 `DataReader` 获取结果,具有较高的执行效率,但灵活性较低。相比之下,`SqlDataAdapter` 则适用于复杂的数据操作,通过 `DataSet` 提供了更多的数据处理功能,如数据填充、更新和批量操作,更适合需要频繁数据交互的场景。 ... [详细]
author-avatar
luhd88112010_254
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有