前言 这个项目的话我也是偶然在B站看到一个阿婆主(SvePana)在讲解这个,跟着他的视频敲的代码并学习起来的。并写在自己这里做个笔记也为大家提供代码哈哈哈哈 。
一、Keras? 1.Keras简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架。 Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结果,如果有如下需求,可以优先选择Keras。
2.为什么 目前Keras已经被TensorFlow收录,添加到TensorFlow 中,成为其默认的框架,成为TensorFlow官方的高级API。Keras简易和快速的原型设计(keras具有高度模块化,极简,和可扩充特性),用户友好:Keras是为人类而不是天顶星人设计的API。用户的使用体验始终是我们考虑的首要和中心内容。Keras遵循减少认知困难的最佳实践:Keras提供一致而简洁的API, 能够极大减少一般应用下用户的工作量,同时,Keras提供清晰和具有实践意义的bug反馈。
二、全连接神经网络实现 1.思路 导入数据- ------> 选择模型 ------>设计神经网络 ------->编译 ------->训练权重参数 ------->预测
2.实现代码 定义函数 train() 实现(导入数据 ———>训练权重参数 )。 定义函数 text() 实现 预测 及输出结果。
导入数据 :mnist = tf.keras.datasets.mnist #导入mnist
选择模型 :model = tf.keras.models.Sequential()
有两种类型的模型,序贯模型(Sequential)和函数式模型(Model),函数式模型应用更为广泛,序贯模型是函数式模型的一种特殊情况。 序贯模型(Sequential) :单输入单输出,一条路通到底,层与层之间只有相邻关系,没有跨层连接。这种模型编译速度快,操作也比较简单;
设计神经网络 :
tf. keras. layers. Flatten( input_shape= ( 28 , 28 ) ) , tf. keras. layers. Dense( 512 , activation= 'relu' ) , tf. keras. layers. Dense( 128 , activation= 'relu' ) , tf. keras. layers. Dense( 10 , activation= 'softmax' , kernel_regularizer= tf. keras. regularizers. l2( ) )
编译 :
model. compile ( optimizer = 优化器,loss = 损失函数,metrics = [ "准确率”] ')
训练权重参数:
history = model. fit( x_train, y_train, batch_size= 每次训练图片数量, epochs= 训练次数, validation_data= ( x_test, y_test) , validation_freq= 1 , callbacks= [ cp_callback] ) model. summary( )
train函数全部代码
def train ( ) : mnist = tf. keras. datasets. mnist ( x_train, y_train) , ( x_test, y_test) = mnist. load_data( ) x_train, x_test = x_train/ 255.0 , x_test/ 255.0 model = tf. keras. models. Sequential( [ tf. keras. layers. Flatten( input_shape= ( 28 , 28 ) ) , tf. keras. layers. Dense( 512 , activation= 'relu' ) , tf. keras. layers. Dense( 128 , activation= 'relu' ) , tf. keras. layers. Dense( 10 , activation= 'softmax' , kernel_regularizer= tf. keras. regularizers. l2( ) ) ] ) model. compile ( optimizer= 'adam' , loss= tf. keras. losses. SparseCategoricalCrossentropy( from_logits= False ) , metrics= [ 'sparse_categorical_accuracy' ] ) checkpoint_save_path= "C:/Users/VULCAN/sxti/TEST/Disconnect_detection/mnist.ckpt" if os. path. exists( checkpoint_save_path + '.index' ) : print ( '------load the model--------' ) model. load_weights( checkpoint_save_path) cp_callback = tf. keras. callbacks. ModelCheckpoint( filepath= checkpoint_save_path, save_weights_only= True , save_best_only= True ) history = model. fit( x_train, y_train, batch_size= 25 , epochs= 30 , validation_data= ( x_test, y_test) , validation_freq= 1 , callbacks= [ cp_callback] ) model. summary( ) acc = history. history[ 'sparse_categorical_accuracy' ] val_acc = history. history[ 'val_sparse_categorical_accuracy' ] loss = history. history[ 'loss' ] val_loss = history. history[ 'val_loss' ] f = Figure( figsize= ( 6 , 6 ) , dpi= 60 ) a = f. add_subplot( 1 , 2 , 1 ) a. plot( acc, label = 'Training Accuracy' ) a. plot( val_acc, label = 'Validation Accuracy' ) a. legend( ) b = f. add_subplot( 1 , 2 , 2 ) b. plot( loss, label = 'Training Loss' ) b. plot( val_loss, label = 'Validation Loss' ) b. legend( ) canvas = FigureCanvasTkAgg( f, master= root) canvas. draw( ) canvas. get_tk_widget( ) . place( x= 60 , y= 100 )
test函数全部代码
def text ( ) : model_save_path &#61; "C:/Users/VULCAN/sxti/TEST/Disconnect_detection/mnist.ckpt" model &#61; tf. keras. models. Sequential( [ tf. keras. layers. Flatten( input_shape&#61; ( 28 , 28 ) ) , tf. keras. layers. Dense( 512 , activation&#61; &#39;relu&#39; ) , tf. keras. layers. Dense( 128 , activation&#61; &#39;relu&#39; ) , tf. keras. layers. Dense( 10 , activation&#61; &#39;softmax&#39; , kernel_regularizer&#61; tf. keras. regularizers. l2( ) ) ] ) model. load_weights( model_save_path) for i in range ( 1 ) : img &#61; Image. open ( "tem2.png" ) img &#61; img. resize( ( 28 , 28 ) , Image. ANTIALIAS) img_arr &#61; np. array( img. convert( "L" ) ) for i in range ( 28 ) : for j in range ( 28 ) : if img_arr[ i] [ j] < 100 : img_arr[ i] [ j] &#61; 255 else : img_arr[ i] [ j] &#61; 0 img_arr &#61; img_arr/ 255.0 x_predict &#61; img_arr[ tf. newaxis, . . . ] result &#61; model. predict( x_predict) pred &#61; np. argmax( result , axis &#61; 1 ) e4 &#61; l &#61; tk. Label( root, text &#61; pred, bg&#61; "white" , font&#61; ( "Arial,12" ) , width&#61; 8 ) e4. place( x&#61; 990 , y&#61; 440 )
三、GUI设计 这部分我直接附上代码并在代码中作必要的注释。
全部所需的库函数&#xff1a;
import tkinter as tkimport tkinter. filedialogfrom matplotlib. backends. backend_tkagg import FigureCanvasTkAggfrom matplotlib. backends. backend_tkagg import NavigationToolbar2Tkfrom matplotlib. backend_bases import key_press_handlerfrom matplotlib. figure import Figureimport cv2import tensorflow as tfimport osimport numpy as npfrom matplotlib import pyplot as pltfrom PIL import Image, ImageTk
其他关于图片文件的导入及摄像头调用的函数定义代码&#xff1a;
def buttonl ( ) : capture &#61; cv2. VideoCapture( 0 ) while ( capture. isOpened( ) ) : ret, frame &#61; capture. read( ) frame &#61; frame[ : , 80 : 560 ] cv2. imwrite( "tem1.png" , frame) dig_Gray &#61; cv2. cvtColor( frame, cv2. COLOR_BGR2GRAY) ref2, dig_Gray &#61; cv2. threshold( dig_Gray, 100 , 255 , cv2. THRESH_BINARY) cv2. imwrite( "tem2.png" , dig_Gray) break global photo1, photo2img1 &#61; Image. open ( "tem1.png" ) img1 &#61; img1. resize( ( 128 , 128 ) ) photo1 &#61; ImageTk. PhotoImage( img1) l1 &#61; tk. Label( root, bg&#61; "red" , image &#61; photo1) . place( x&#61; 950 , y&#61; 100 ) img2 &#61; Image. open ( "tem2.png" ) img2 &#61; img2. resize( ( 128 , 128 ) ) photo2 &#61; ImageTk. PhotoImage( img2) l2 &#61; tk. Label( root, bg&#61; "red" , image &#61; photo2) . place( x&#61; 950 , y&#61; 250 ) def frame ( ) : capture &#61; cv2. VideoCapture( 0 ) while ( capture. isOpened( ) ) : ref, frame &#61; capture. read( ) frame &#61; frame[ : , 80 : 560 ] cvimage &#61; cv2. cvtColor( frame, cv2. COLOR_BGR2RGBA) pilImage &#61; Image. fromarray( cvimage) pilImage &#61; pilImage. resize( ( 360 , 360 ) , Image. ANTIALIAS) tkImage &#61; ImageTk. PhotoImage( image &#61; pilImage) canvas. create_image( 0 , 0 , anchor &#61; "nw" , image &#61; tkImage) root. update( ) root. after( 10 ) def select_pic ( ) : file_path &#61; tk. filedialog. askopenfilename( title&#61; "选择文件" , initialdir &#61; ( os. path. expanduser( r"" ) ) ) image &#61; Image. open ( file_path) image. save( "tem1.png" ) gray &#61; image. convert( "L" ) gray. save( "tem2.png" ) global photo3, photo4 img3 &#61; Image. open ( "tem1.png" ) img3 &#61; image. resize( ( 128 , 128 ) ) photo3 &#61; ImageTk. PhotoImage( img3) l3 &#61; tk. Label( root, bg&#61; "red" , image &#61; photo3) . place( x&#61; 950 , y&#61; 100 ) img4 &#61; Image. open ( "tem2.png" ) img4 &#61; img4. resize( ( 128 , 128 ) ) photo4 &#61; ImageTk. PhotoImage( img4) l4 &#61; tk. Label( root, bg&#61; "red" , image &#61; photo4) . place( x&#61; 950 , y&#61; 250 )
主函数部分&#xff1a;
if __name__ &#61;&#61; &#39;__main__&#39; : root &#61; tk. Tk( ) root. title( &#39;手写体数字识别&#39; ) root. geometry( &#39;1176x520&#39; ) root. configure( bg &#61; "#C0C0C0" ) f &#61; Figure( figsize&#61; ( 6 , 6 ) , dpi&#61; 60 ) a&#61; f. add_subplot( 1 , 2 , 1 ) a. plot( 0 , 0 ) b&#61; f. add_subplot( 1 , 2 , 2 ) b. plot( 0 , 0 ) canvas&#61; FigureCanvasTkAgg( f, master&#61; root) canvas. draw( ) canvas. get_tk_widget( ) . place( x&#61; 60 , y&#61; 100 ) b1 &#61; tk. Button( root, text&#61; &#39;训练&#39; , bg&#61; &#39;white&#39; , font&#61; ( &#39;Arial&#39; , 12 ) , width&#61; 12 , height&#61; 1 , command&#61; train) . place( x&#61; 168 , y&#61; 35 ) b2 &#61; tk. Button( root, text&#61; &#39;拍照&#39; , bg&#61; &#39;white&#39; , font&#61; ( &#39;Arial&#39; , 12 ) , width&#61; 12 , height&#61; 1 , command&#61; frame) . place( x&#61; 550 , y&#61; 35 ) b3 &#61; tk. Button( root, text&#61; &#39;测试&#39; , bg&#61; &#39;white&#39; , font&#61; ( &#39;Arial&#39; , 12 ) , width&#61; 12 , height&#61; 1 , command&#61; text) . place( x&#61; 960 , y&#61; 35 ) b4 &#61; tk. Button( root, text&#61; &#39;导入图片&#39; , bg&#61; &#39;white&#39; , font&#61; ( &#39;Arial&#39; , 12 ) , width&#61; 12 , height&#61; 1 , command&#61; select_pic) . place( x&#61; 680 , y&#61; 35 ) b5 &#61; tk. Button( root, text&#61; &#39;识别结果&#39; , font&#61; ( &#39;Arial&#39; , 12 ) , bg&#61; &#39;white&#39; ) . place( x&#61; 990 , y&#61; 400 ) canvas&#61; tk. Canvas( root, bg&#61; "white" , width&#61; 360 , height&#61; 360 ) canvas. place( x&#61; 500 , y&#61; 100 ) b6&#61; tk. Button( root, text&#61; "保存" , bg&#61; "white" , width&#61; 15 , height&#61; 2 , command&#61; buttonl) . place( x&#61; 620 , y&#61; 420 ) root. mainloop( )
最后附上界面