热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Paddle:手写字符识别模型(1)

快速构建字符识别实验开始之前,先导入本实验所需要的函数库:importpaddleimportpaddle.fluidasfluidfrompa

快速构建字符识别

实验开始之前,先导入本实验所需要的函数库:

import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
import numpy as np
import os
from PIL import Image

数据集的预处理

paddle.dataset 中提供了很多的数据集集合,如下:


  • mnist
  • cifar
  • Conll05
  • imdb
  • imikolov
  • movielens
  • sentiment
  • uci_housing
  • wmt14
  • wmt16

我们可以通过 paddle.dataset.mnist.train() 加载训练数据,通过 paddle.batch 来分批:

# 如果~/.cache/paddle/dataset/mnist/目录下没有MNIST数据,API会自动将MINST数据下载到该文件夹下
# 设置数据读取器,读取MNIST数据训练集
trainset = paddle.dataset.mnist.train()
# 包装数据读取器,每次读取的数据数量设置为batch_size=8
train_reader = paddle.batch(trainset, batch_size=8)

接下来,将数据中的 img 和 label 分开。这里使用迭代的方式,一批次一批次的分开:

# 以迭代的形式读取数据
for batch_id, data in enumerate(train_reader()):# 获得图像数据,并转为float32类型的数组img_data = np.array([x[0] for x in data]).astype('float32')# 获得图像标签数据,并转为float32类型的数组label_data = np.array([x[1] for x in data]).astype('float32')# 打印数据形状print("图像数据形状和对应数据为:", img_data.shape)print("图像标签形状和对应数据为:", label_data.shape)break

结果:
在这里插入图片描述
我们可以展示一下训练数据:

print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(label_data[0]))
# 显示第一batch的第一个图像
import matplotlib.pyplot as plt
# 原始数据是归一化后的数据,因此这里需要反归一化
img = np.array(img_data[0]+1)*127.5
img = np.reshape(img, [28, 28]).astype(np.uint8)plt.figure("Image") # 图像窗口名称
plt.imshow(img)
plt.axis('on') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()

结果:
在这里插入图片描述
从打印结果看,从数据加载器train_loader()中读取一次数据,可以得到形状为(8, 784)的图像数据和形状为(8,)的标签数据。其中,形状中的数字8与设置的batch_size大小对应,784为MINIST数据集中每个图像的像素大小(28*28)。


识别模型


模型的建立

作为入门课程,这里使用最简单的线性网络。和 PyTorch 定义模型的方法类似,不过这里需要继承的是 fluid.dygraph.Layer

class minist_model(fluid.dygraph.Layer):def __init__(self):super(minist_model,self).__init__()# 定义一层全连接层,输出维度是1,激活函数为None,即不使用激活函数self.fc = Linear(input_dim=28*28,output_dim=1,act=None)def forward(self,inputs):outputs = self.fc(inputs)return outputs
model = minist_model()
model

结果:
在这里插入图片描述


模型的配置

# 开启所有变量的梯度计算
with fluid.dygraph.guard():model = minist_model()model.train()#开启模型训练的模式#定义数据加载器train_loader = paddle.batch(paddle.dataset.mnist.train(),batch_size=16)# 定义优化器opt = fluid.optimizer.SGDOptimizer(learning_rate=0.001,parameter_list=model.parameters())
opt

结果:
在这里插入图片描述


模型的训练

## 模型训练
# 开启所有变量的梯度计算
with fluid.dygraph.guard():model = minist_model()model.train()#开启模型训练的模式#定义数据加载器train_loader = paddle.batch(paddle.dataset.mnist.train(),batch_size=16)# 定义优化器opt = fluid.optimizer.SGDOptimizer(learning_rate=0.001,parameter_list=model.parameters())for epoch_id in range(100):for batch_id,data in enumerate(train_loader()):#划分input 和 outputimage_data = np.array([x[0] for x in data]).astype('float32')label_data = np.array([x[1] for x in data]).astype('float32').reshape(-1, 1)# 将数据转为 飞浆动态图格式(Tensor)image = fluid.dygraph.to_variable(image_data)label = fluid.dygraph.to_variable(label_data)pre = model(image)#平均平方差损失loss = fluid.layers.square_error_cost(pre,label)avg_loss = fluid.layers.mean(loss)if batch_id !=0 and batch_id %1000==0:print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))avg_loss.backward()opt.minimize(avg_loss)model.clear_gradients()
# 保存模型
fluid.save_dygraph(model.state_dict(),'mnist')

结果:
在这里插入图片描述


模型的测试

首先加载一张新的图像:

# 导入图像读取第三方库```python
import matplotlib.image as Img
import matplotlib.pyplot as plt
# 读取图像
example = Img.imread('./work/example_0.png')
# 显示图像
plt.imshow(example)
plt.show()

结果:
在这里插入图片描述
相对图像进行预处理,然后进行预测:

# 读取一张本地的样例图片,转变成模型输入的格式
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')# print(np.array(im))im = im.resize((28, 28), Image.ANTIALIAS)im = np.array(im).reshape(1, -1).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 127.5return im# 定义预测过程
with fluid.dygraph.guard():model = MNIST()params_file_path = 'mnist'img_path = './work/example_0.png'
# 加载模型参数model_dict, _ = fluid.load_dygraph("mnist")model.load_dict(model_dict)
# 灌入数据model.eval()#启动模型评价# 将一张图片转为一行向量tensor_img = load_image(img_path)print("数据集的大小为:",tensor_img.shape)result = model(fluid.dygraph.to_variable(tensor_img))
# 预测输出取整,即为预测的数字,打印结果print("本次预测的数字是", result.numpy().astype('int32'))

结果:

在这里插入图片描述
由于上面用的是线性网络,所以得到结果不尽人意。我们可以查看一下该模型的模型准确率:

# 通过with语句创建一个dygraph运行的context,
# 动态图下的一些操作需要在guard下进行
correct = 0
count = 0
with fluid.dygraph.guard():model = MNIST()# 加载模型参数model_dict, _ = fluid.load_dygraph("mnist")model.load_dict(model_dict)test_loader = paddle.batch(paddle.dataset.mnist.test(), batch_size=16)for batch_id, data in enumerate(test_loader()):#准备数据,格式需要转换成符合框架要求的image_data = np.array([x[0] for x in data]).astype('float32')label_data = np.array([x[1] for x in data]).astype('float32').reshape(-1, 1)# 将数据转为飞桨动态图格式image = fluid.dygraph.to_variable(image_data)label = fluid.dygraph.to_variable(label_data)model.eval()#启动模型评价#前向计算的过程predict = model(image)pre = predict.numpy().astype('int32') correct=correct+np.sum(pre==label_data)count = count+len(image_data)
print(f"正确率为:{correct/count*100:.2f}%")

结果:
在这里插入图片描述


推荐阅读
  • voc生成xml 代码
    目录 lxmlwindows安装 读取示例 可视化 生成示例 上面是代码,下面有调用示例 api调用代码,其实只有几行:这个生成代码也很简 ... [详细]
  • 在第七天的深度学习课程中,我们将重点探讨DGL框架的高级应用,特别是在官方文档指导下进行数据集的下载与预处理。通过详细的步骤说明和实用技巧,帮助读者高效地构建和优化图神经网络的数据管道。此外,我们还将介绍如何利用DGL提供的模块化工具,实现数据的快速加载和预处理,以提升模型训练的效率和准确性。 ... [详细]
  • Node.js 教程第五讲:深入解析 EventEmitter(事件监听与发射机制)
    本文将深入探讨 Node.js 中的 EventEmitter 模块,详细介绍其在事件监听与发射机制中的应用。内容涵盖事件驱动的基本概念、如何在 Node.js 中注册和触发自定义事件,以及 EventEmitter 的核心 API 和使用方法。通过本教程,读者将能够全面理解并熟练运用 EventEmitter 进行高效的事件处理。 ... [详细]
  • 深入解析 UIImageView 与 UIImage 的关键细节与应用技巧
    本文深入探讨了 UIImageView 和 UIImage 的核心特性及应用技巧。首先,详细介绍了如何在 UIImageView 中实现动画效果,包括创建和配置 UIImageView 实例的具体步骤。此外,还探讨了 UIImage 的加载方式及其对性能的影响,提供了优化图像显示和内存管理的有效方法。通过实例代码和实际应用场景,帮助开发者更好地理解和掌握这两个重要类的使用技巧。 ... [详细]
  • 本文介绍了Android动画的基本概念及其主要类型。Android动画主要包括三种形式:视图动画(也称为补间动画或Tween动画),主要通过改变视图的属性来实现动态效果;帧动画,通过顺序播放一系列预定义的图像来模拟动画效果;以及属性动画,通过对对象的属性进行平滑过渡来创建更加复杂的动画效果。每种类型的动画都有其独特的应用场景和实现方式,开发者可以根据具体需求选择合适的动画类型。 ... [详细]
  • 深入解析 Django 中用户模型的自定义方法与技巧 ... [详细]
  • 深入解析Tomcat:开发者的实用指南
    深入解析Tomcat:开发者的实用指南 ... [详细]
  • Go语言实现Redis客户端与服务器的交互机制深入解析
    在前文对Godis v1.0版本的基础功能进行了详细介绍后,本文将重点探讨如何实现客户端与服务器之间的交互机制。通过具体代码实现,使客户端与服务器能够顺利通信,赋予项目实际运行的能力。本文将详细解析Go语言在实现这一过程中的关键技术和实现细节,帮助读者深入了解Redis客户端与服务器的交互原理。 ... [详细]
  • 在Python编程中,掌握高级技巧对于提升代码效率和可读性至关重要。本文重点探讨了生成器和迭代器的应用,这两种工具不仅能够优化内存使用,还能简化复杂数据处理流程。生成器通过按需生成数据,避免了大量数据加载对内存的占用,而迭代器则提供了一种优雅的方式来遍历集合对象。此外,文章还深入解析了这些高级特性的实际应用场景,帮助读者更好地理解和运用这些技术。 ... [详细]
  • 深入学习 Python 中的 xlrd 模块:掌握 Excel 文件读取技巧
    本文深入探讨了 Python 中的 xlrd 模块,重点介绍了如何高效读取 Excel 文件(包括 xlsx 和 xls 格式)。同时,文章还详细讲解了 xlwt 模块在 Excel 文件写操作中的应用。此外,文中列举了常见单元格数据类型及其处理方法,为读者提供了全面的实践指导。 ... [详细]
  • 题目编号为547的“朋友圈”问题属于中等难度。该问题描述了班级中有N名学生,部分学生之间存在友谊关系,且这种友谊关系具有传递性。即如果A和B是朋友,B和C也是朋友,那么A和C同样被视为朋友。本文将通过Python语言提供一种高效的解决方案,详细探讨如何利用图论中的并查集算法来快速计算出班级中所有互为朋友的学生群体数量。 ... [详细]
  • 利用GDAL库在Python中高效读取与处理栅格数据的详细指南 ... [详细]
  • Python正则表达式详解:掌握数量词用法轻松上手
    Python正则表达式详解:掌握数量词用法轻松上手 ... [详细]
  • 在上一节中,我们完成了网络的前向传播实现。本节将重点探讨如何为检测输出设定目标置信度阈值,并应用非极大值抑制技术以提高检测精度。为了更好地理解和实践这些内容,建议读者已经完成本系列教程的前三部分,并具备一定的PyTorch基础知识。此外,我们将详细介绍这些技术的原理及其在实际应用中的重要性,帮助读者深入理解目标检测算法的核心机制。 ... [详细]
  • Jedis接口分类详解与应用指南
    本文详细解析了Jedis接口的分类及其应用指南,重点介绍了字符串数据类型(String)的接口功能。作为Redis中最基本的数据存储形式,字符串类型支持多种操作,如设置、获取和更新键值对等,适用于广泛的应用场景。 ... [详细]
author-avatar
apiaoapiao_622
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有