热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于TensorFlow的Keras高级API实现手写体数字识别

前言这个项目的话我也是偶然在B站看到一个阿婆主(SvePana)在讲解这个,跟着他的视频敲的代码并学习起来的。并写在自己这里做个笔记也为

前言

这个项目的话我也是偶然在B站看到一个阿婆主(SvePana)在讲解这个,跟着他的视频敲的代码并学习起来的。并写在自己这里做个笔记也为大家提供代码哈哈哈哈


一、Keras?


1.Keras简介

Keras是由纯python编写的基于theano/tensorflow的深度学习框架。 Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结果,如果有如下需求,可以优先选择Keras。


2.为什么

目前Keras已经被TensorFlow收录,添加到TensorFlow 中,成为其默认的框架,成为TensorFlow官方的高级API。Keras简易和快速的原型设计(keras具有高度模块化,极简,和可扩充特性),用户友好:Keras是为人类而不是天顶星人设计的API。用户的使用体验始终是我们考虑的首要和中心内容。Keras遵循减少认知困难的最佳实践:Keras提供一致而简洁的API, 能够极大减少一般应用下用户的工作量,同时,Keras提供清晰和具有实践意义的bug反馈。


二、全连接神经网络实现


1.思路

导入数据-------> 选择模型------>设计神经网络------->编译------->训练权重参数------->预测


2.实现代码

定义函数 train() 实现(导入数据———>训练权重参数)。
定义函数 text() 实现 预测及输出结果。

导入数据:mnist = tf.keras.datasets.mnist #导入mnist
选择模型:model = tf.keras.models.Sequential()
有两种类型的模型,序贯模型(Sequential)和函数式模型(Model),函数式模型应用更为广泛,序贯模型是函数式模型的一种特殊情况。
序贯模型(Sequential) :单输入单输出,一条路通到底,层与层之间只有相邻关系,没有跨层连接。这种模型编译速度快,操作也比较简单;

设计神经网络:

tf.keras.layers.Flatten(input_shape=(28,28)),tf.keras.layers.Dense(512,activation='relu'),tf.keras.layers.Dense(128,activation='relu'),tf.keras.layers.Dense(10,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())

编译:

model.compile(optimizer = 优化器,loss = 损失函数,metrics = ["准确率”]')

训练权重参数:

history = model.fit(x_train,y_train,batch_size=每次训练图片数量,epochs=训练次数,
validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])
model.summary()

train函数全部代码

def train():mnist = tf.keras.datasets.mnist #导入mnist(x_train,y_train),(x_test,y_test) = mnist.load_data() #分割x_train,x_test =x_train/255.0, x_test/255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28,28)),tf.keras.layers.Dense(512,activation='relu'),tf.keras.layers.Dense(128,activation='relu'),tf.keras.layers.Dense(10,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())])model.compile(optimizer= 'adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])#评价指标 categorical_accuracy和 sparse_categorical_accuracy#注意修改路径checkpoint_save_path="C:/Users/VULCAN/sxti/TEST/Disconnect_detection/mnist.ckpt"if os.path.exists(checkpoint_save_path + '.index'):print('------load the model--------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)#断点续训history = model.fit(x_train,y_train,batch_size=25,epochs=30,validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])model.summary()#以下为打印训练准确率及损失率等acc = history.history['sparse_categorical_accuracy']val_acc = history.history['val_sparse_categorical_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']f = Figure(figsize=(6,6),dpi=60)a = f.add_subplot(1,2,1)a.plot(acc,label = 'Training Accuracy')a.plot(val_acc,label = 'Validation Accuracy')#验证精度a.legend() b = f.add_subplot(1,2,2)b.plot(loss,label = 'Training Loss')b.plot(val_loss,label = 'Validation Loss')b.legend() canvas = FigureCanvasTkAgg(f,master=root)canvas.draw()canvas.get_tk_widget().place(x=60,y=100)

test函数全部代码

#预测结果打印
def text():#注意修改路径与函数train上面保存的路径一致model_save_path &#61; "C:/Users/VULCAN/sxti/TEST/Disconnect_detection/mnist.ckpt"model &#61; tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape&#61;(28,28)),tf.keras.layers.Dense(512,activation&#61;&#39;relu&#39;),tf.keras.layers.Dense(128,activation&#61;&#39;relu&#39;),tf.keras.layers.Dense(10,activation&#61;&#39;softmax&#39;,kernel_regularizer&#61;tf.keras.regularizers.l2())])model.load_weights(model_save_path)for i in range(1):img &#61; Image.open("tem2.png")#强制压缩为28&#xff0c;28img &#61; img.resize((28,28),Image.ANTIALIAS)#将原有图像转换为灰度图img_arr &#61; np.array(img.convert("L"))#图片反相for i in range(28):for j in range(28):if img_arr[i][j]<100:img_arr[i][j]&#61;255else:img_arr[i][j]&#61; 0img_arr &#61; img_arr/255.0x_predict &#61; img_arr[tf.newaxis,...]result &#61; model.predict(x_predict)pred &#61; np.argmax(result , axis &#61; 1)#在GUI界面显示结果e4 &#61; l &#61; tk.Label(root,text &#61; pred, bg&#61;"white",font&#61;("Arial,12"),width&#61;8)e4.place(x&#61;990,y&#61;440)

三、GUI设计

这部分我直接附上代码并在代码中作必要的注释。

全部所需的库函数&#xff1a;

#使用Tkinter前需要先导入
import tkinter as tk
#导入对话框模块
import tkinter.filedialog
#创建画布需要的库
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
#创建工具栏需要的库
from matplotlib.backends.backend_tkagg import NavigationToolbar2Tk
#快捷键需要的模块2
from matplotlib.backend_bases import key_press_handler
#导入绘图需要的模块
from matplotlib.figure import Figure
import cv2
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image,ImageTk

其他关于图片文件的导入及摄像头调用的函数定义代码&#xff1a;

#调取摄像头并拍摄图片
def buttonl():capture &#61; cv2.VideoCapture(0) #cv2模块调取摄像头while(capture.isOpened()):ret,frame &#61; capture.read() #ret表示捕获是否成功frame &#61; frame[:,80:560] #拍照默认为640*480cv2.imwrite("tem1.png",frame)dig_Gray &#61; cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY)ref2,dig_Gray &#61; cv2.threshold(dig_Gray,100,255,cv2.THRESH_BINARY)cv2.imwrite("tem2.png",dig_Gray)breakglobal photo1,photo2#将图片显示到界面上img1 &#61; Image.open("tem1.png")img1 &#61; img1.resize((128,128))photo1 &#61; ImageTk.PhotoImage(img1)l1 &#61; tk.Label(root,bg&#61;"red",image &#61; photo1).place(x&#61;950,y&#61;100)img2 &#61; Image.open("tem2.png")img2 &#61; img2.resize((128,128))photo2 &#61; ImageTk.PhotoImage(img2)l2 &#61; tk.Label(root,bg&#61;"red",image &#61; photo2).place(x&#61;950,y&#61;250)#保存当前摄像头画面
def frame():capture &#61; cv2.VideoCapture(0)#控件定义while(capture.isOpened()):ref,frame &#61; capture.read()frame &#61; frame[:,80:560]cvimage &#61; cv2.cvtColor(frame,cv2.COLOR_BGR2RGBA)pilImage &#61; Image.fromarray(cvimage)pilImage &#61; pilImage.resize((360,360),Image.ANTIALIAS)tkImage &#61; ImageTk.PhotoImage(image &#61; pilImage)canvas.create_image(0,0,anchor &#61; "nw",image &#61; tkImage)root.update()root.after(10)
#选择文件
def select_pic():file_path &#61; tk.filedialog.askopenfilename(title&#61;"选择文件",initialdir &#61; (os.path.expanduser(r"")))image &#61; Image.open(file_path)image.save("tem1.png")gray &#61; image.convert("L")gray.save("tem2.png")global photo3,photo4
#将图片显示在界面上img3 &#61; Image.open("tem1.png")img3 &#61; image.resize((128,128))photo3 &#61; ImageTk.PhotoImage(img3)l3 &#61; tk.Label(root,bg&#61;"red",image &#61; photo3).place(x&#61;950,y&#61;100)img4 &#61; Image.open("tem2.png")img4 &#61; img4.resize((128,128))photo4 &#61; ImageTk.PhotoImage(img4)l4 &#61; tk.Label(root,bg&#61;"red",image &#61; photo4).place(x&#61;950,y&#61;250)

主函数部分&#xff1a;

if __name__ &#61;&#61;&#39;__main__&#39;:root &#61; tk.Tk()#第二步&#xff0c;给窗口的可视化起名字root.title(&#39;手写体数字识别&#39;)#第三步&#xff0c;设定窗口的大小&#xff08;长*宽&#xff09;root.geometry(&#39;1176x520&#39;) #这里的乘是小xroot.configure(bg &#61; "#C0C0C0")f &#61; Figure(figsize&#61;(6,6), dpi&#61;60)a&#61;f.add_subplot(1,2,1) #添加子图&#xff1a;1行1列第一个a.plot(0,0)b&#61;f.add_subplot(1,2,2) #添加子图&#xff0c;1行1列第二个b.plot(0,0)#将绘制的图形显示到tkinter&#xff1a;创建属于root的canvas画布&#xff0c;并将图f置于画布上 canvas&#61;FigureCanvasTkAgg(f,master&#61;root)canvas.draw()#注意show方法已经过时&#xff0c;改用drawcanvas.get_tk_widget().place(x&#61;60,y&#61;100)b1 &#61; tk.Button(root,text&#61;&#39;训练&#39;,bg&#61;&#39;white&#39;,font&#61;(&#39;Arial&#39;,12),width&#61;12,height&#61;1,command&#61;train).place(x&#61;168,y&#61;35)b2 &#61; tk.Button(root,text&#61;&#39;拍照&#39;,bg&#61;&#39;white&#39;,font&#61;(&#39;Arial&#39;,12),width&#61;12,height&#61;1,command&#61;frame).place(x&#61;550,y&#61;35)b3 &#61; tk.Button(root,text&#61;&#39;测试&#39;,bg&#61;&#39;white&#39;,font&#61;(&#39;Arial&#39;,12),width&#61;12,height&#61;1,command&#61;text).place(x&#61;960,y&#61;35)b4 &#61; tk.Button(root,text&#61;&#39;导入图片&#39;,bg&#61;&#39;white&#39;,font&#61;(&#39;Arial&#39;,12),width&#61;12,height&#61;1,command&#61;select_pic).place(x&#61;680,y&#61;35)b5 &#61; tk.Button(root,text&#61;&#39;识别结果&#39;,font&#61;(&#39;Arial&#39;,12),bg&#61;&#39;white&#39;).place(x&#61;990,y&#61;400)canvas&#61;tk.Canvas(root,bg&#61;"white",width&#61;360,height&#61;360) #绘制画布#控件位置设置canvas.place(x&#61;500,y&#61;100)b6&#61;tk.Button(root,text&#61;"保存",bg&#61;"white",width&#61;15,height&#61;2,command&#61;buttonl).place(x&#61;620,y&#61;420)#第六步&#xff0c;主窗口循环显示root.mainloop()

最后附上界面

在这里插入图片描述


推荐阅读
author-avatar
徐涛
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有