热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

手写数字识别卷积神经网络

内容都是百度AIstudio的内容,我只是在这里做个笔记,不是原创。#合并后代码#数据处理部分之前的代码,保持不变importos

内容都是百度AIstudio的内容,我只是在这里做个笔记,不是原创。 

#合并后代码
#数据处理部分之前的代码,保持不变
import os
import random
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
import numpy as np
import matplotlib.pyplot as plt
from PIL import Imageimport gzip
import json#数据处理部分的展开代码
# 定义数据集读取器
def load_data(mode='train'):# 数据文件datafile = './work/mnist.json.gz'print('loading mnist dataset from {} ......'.format(datafile))data = json.load(gzip.open(datafile))# 读取到的数据可以直接区分训练集,验证集,测试集train_set, val_set, eval_set = data# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLSIMG_ROWS = 28IMG_COLS = 28# 获得数据if mode == 'train':imgs = train_set[0]labels = train_set[1]elif mode == 'valid':imgs = val_set[0]labels = val_set[1]elif mode == 'eval':imgs = eval_set[0]labels = eval_set[1]else:raise Exception("mode can only be one of ['train', 'valid', 'eval']")imgs_length = len(imgs)assert len(imgs) == len(labels), \"length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(labels))index_list = list(range(imgs_length))# 读入数据时用到的batchsizeBATCHSIZE = 100# 定义数据生成器def data_generator():if mode == 'train':# 训练模式下,将训练数据打乱random.shuffle(index_list)imgs_list = []labels_list = []for i in index_list:img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')label = np.reshape(labels[i], [1]).astype('float32')imgs_list.append(img) labels_list.append(label)if len(imgs_list) == BATCHSIZE:# 产生一个batch的数据并返回yield np.array(imgs_list), np.array(labels_list)# 清空数据读取列表imgs_list = []labels_list = []# 如果剩余数据的数目小于BATCHSIZE,# 则剩余数据一起构成一个大小为len(imgs_list)的mini-batchif len(imgs_list) > 0:yield np.array(imgs_list), np.array(labels_list)return data_generator#数据处理部分之后的代码,数据读取的部分调用Load_data函数
# 定义网络结构,同上一节所使用的网络结构
class MNIST(fluid.dygraph.Layer):def __init__(self, name_scope):super(MNIST, self).__init__(name_scope)name_scope = self.full_name()# 定义卷积层,输出特征通道num_filters设置为20,卷积核的大小filter_size为5,卷积步长stride=1,padding=2# 激活函数使用reluself.conv1 = Conv2D(name_scope, num_filters=20, filter_size=5, stride=1, padding=2, act='relu')# 定义池化层,池化核pool_size=2,池化步长为2,选择最大池化方式self.pool1 = Pool2D(name_scope, pool_size=2, pool_stride=2, pool_type='max')# 定义卷积层,输出特征通道num_filters设置为20,卷积核的大小filter_size为5,卷积步长stride=1,padding=2self.conv2 = Conv2D(name_scope, num_filters=20, filter_size=5, stride=1, padding=2, act='relu')# 定义池化层,池化核pool_size=2,池化步长为2,选择最大池化方式self.pool2 = Pool2D(name_scope, pool_size=2, pool_stride=2, pool_type='max')# 定义一层全连接层,输出维度是1,不使用激活函数self.fc = FC(name_scope, size=1, act=None)def forward(self, inputs):# outputs = self.fc(inputs)# return outputsx = self.conv1(inputs)x = self.pool1(x)x = self.conv2(x)x = self.pool2(x)x = self.fc(x)return x# 训练配置,并启动训练过程
with fluid.dygraph.guard():model = MNIST("mnist")model.train()#调用加载数据的函数train_loader = load_data('train')# 创建异步数据读取器place = fluid.CPUPlace()data_loader = fluid.io.DataLoader.from_generator(capacity=5, return_list=True)data_loader.set_batch_generator(train_loader, places=place)optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001)EPOCH_NUM = 6for epoch_id in range(EPOCH_NUM):for batch_id, data in enumerate(data_loader):image_data, label_data = dataimage = fluid.dygraph.to_variable(image_data)label = fluid.dygraph.to_variable(label_data)predict = model(image)loss = fluid.layers.square_error_cost(predict, label)avg_loss = fluid.layers.mean(loss)if batch_id % 200 == 0:print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))avg_loss.backward()optimizer.minimize(avg_loss)model.clear_gradients()fluid.save_dygraph(model.state_dict(), 'mnist')

 

with fluid.dygraph.guard():print('start evaluation .......')#加载模型参数model = MNIST("mnist")model_state_dict, _ = fluid.load_dygraph('mnist')model.load_dict(model_state_dict)model.eval()eval_loader = load_data('eval')acc_set = []avg_loss_set = []cnt=0sum=0for batch_id, data in enumerate(eval_loader()):x_data, y_data = dataimg = fluid.dygraph.to_variable(x_data)label = fluid.dygraph.to_variable(y_data)prediction= model(img)prediction=prediction.numpy().astype('int32')label=label.numpy().astype('int32')for i in range(len(label)):# print(prediction[i],label[i])if(prediction[i]==label[i]):cnt+=1sum+=100# print(len(prediction))# print("hello:",prediction.numpy().astype('int32'),label.numpy().astype('int32'))# loss = fluid.layers.square_error_cost(input=prediction, label=label)# avg_loss = fluid.layers.mean(loss)# acc_set.append(float(acc.numpy()))# avg_loss_set.append(float(avg_loss.numpy()))#计算多个batch的平均损失和准确率# acc_val_mean = np.array(acc_set).mean()# avg_loss_val_mean = np.array(avg_loss_set).mean()print("acc:",cnt/sum)# print('loss={}, acc={}'.format(avg_loss_val_mean, acc_val_mean))

单张图片测试

def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')print(im)im = im.resize((28, 28), Image.ANTIALIAS)im = np.reshape(im,[1,1,28,28]).astype(np.float32)# [batchsize, channels, rows, cols]是这样的形式。1是因为mnist是灰度图只有一个通道,# 一般训练和预测的时候都是指定batchsize大小的,单张测试就取了batchsize=1# 图像归一化,保持和数据集的数据范围一致# print(im)im = 1- im / 127.5 # print(im)img=np.array(im).reshape(28,28).astype(np.float32)plt.imshow(img)return im# 定义预测过程
with fluid.dygraph.guard():model = MNIST("mnist")params_file_path = 'mnist'img_path = './work/example_8.png'# 加载模型参数model_dict, _ = fluid.load_dygraph("mnist")model.load_dict(model_dict)model.eval()tensor_img = load_image(img_path)result = model(fluid.dygraph.to_variable(tensor_img))# #预测输出取整,即为预测的数字print("本次预测的数字是", result.numpy().astype('int32'))

 


推荐阅读
  • 本文详细介绍了如何使用 Yii2 的 GridView 组件在列表页面实现数据的直接编辑功能。通过具体的代码示例和步骤,帮助开发者快速掌握这一实用技巧。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 本文详细介绍如何使用Python进行配置文件的读写操作,涵盖常见的配置文件格式(如INI、JSON、TOML和YAML),并提供具体的代码示例。 ... [详细]
  • 本文介绍了Java并发库中的阻塞队列(BlockingQueue)及其典型应用场景。通过具体实例,展示了如何利用LinkedBlockingQueue实现线程间高效、安全的数据传递,并结合线程池和原子类优化性能。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 主要用了2个类来实现的,话不多说,直接看运行结果,然后在奉上源代码1.Index.javaimportjava.awt.Color;im ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本文详细介绍了Akka中的BackoffSupervisor机制,探讨其在处理持久化失败和Actor重启时的应用。通过具体示例,展示了如何配置和使用BackoffSupervisor以实现更细粒度的异常处理。 ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • 本文详细介绍了如何构建一个高效的UI管理系统,集中处理UI页面的打开、关闭、层级管理和页面跳转等问题。通过UIManager统一管理外部切换逻辑,实现功能逻辑分散化和代码复用,支持多人协作开发。 ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
  • 将Web服务部署到Tomcat
    本文介绍了如何在JDeveloper 12c中创建一个Java项目,并将其打包为Web服务,然后部署到Tomcat服务器。内容涵盖从项目创建、编写Web服务代码、配置相关XML文件到最终的本地部署和验证。 ... [详细]
author-avatar
luhd88112010_254
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有