1、使用环境:
ubuntu18.04,4gpu,nvidia410.78,cuda9.0,cudnn7.3,python3.6
2、使用代码:
官方提供的ocr模型代码
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/ocr_recognition
3、将代码运行起来
为了方便看到运行的效果,我修改了参数,save_model_period,这样可以更快的保存数据,好知道运行是否有效
![](https://img.php1.cn/3cd4a/1eebe/cd5/2fdc212433a29829.png)
![](https://img.php1.cn/3cd4a/1eebe/cd5/7cccb7e4b6cb5cb8.webp?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RpYW5hX1o=,size_16,color_FFFFFF,t_70)
4、生成自己的数据
import random
import cv2
import numpy as np
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import os
from unit import segmentation  # project-local helper (unused in this script; kept from the original)

path_font = '/home/zz/文字'       # directory holding the .ttf font files
path_out = '/media/zz/testtttt'  # output directory for the generated samples
if not os.path.exists(path_out):
    os.mkdir(path_out)

dicters = '0123456789.'   # full character dictionary (digits + decimal point)
CHARS = '0123456789'      # characters get_word() samples from
number = 100              # how many images to generate
font_index = 0            # running count of characters generated so far
font_list = os.listdir(path_font)
font_list.sort()
f_ind = 0                 # index into font_list for the current image
f_size = 0                # font size of the current image


def get_word(length):
    """Return a random digit string of the given length.

    Also advances the global ``font_index`` counter by one per character.
    """
    global font_index
    word = ''
    for _ in range(length):
        word += random.choice(CHARS)
        font_index += 1
    return word


def get_txt():
    """Build one label: 1-8 random digits, with a 2/3 chance of a decimal
    point inserted two places from the end (only when there are >= 3 digits)."""
    f3 = random.randint(0, 2)          # f3 > 0  ->  2/3 chance a '.' appears
    num = get_word(random.randint(1, 8))
    txt = num
    if f3 > 0 and len(num) >= 3:
        txt = txt[:-2] + '.' + txt[-2:]   # keep two decimal places
    return txt


def get_txt_test():
    """Return the whole dictionary — renders one sample containing every char."""
    return dicters


def get_bg(color, w, h):
    """Create a uint8 gray background of value ``color`` padded around a
    ``w`` x ``h`` text area.

    Margins are fixed (50 px left/right, 20 px top/bottom); the original
    experimented with randomized margins. Returns (background, left, top).
    """
    w_l, w_r = 50, 50
    h_t, h_b = 20, 20
    bg = np.zeros((h + h_t + h_b, w + w_l + w_r), dtype='uint8') + color
    return bg, w_l, h_t


def oblique(bg, fc):
    """Draw parallel oblique interference lines of color ``fc`` across ``bg``.

    The image is embedded in a canvas widened by one image-height on each side
    so the diagonal lines cover the whole visible area, then cropped back.
    Returns a PIL image.
    """
    bg = np.array(bg)
    imgh, imgw = bg.shape
    canvas = np.zeros((imgh, imgw + 2 * imgh), dtype='uint8')
    canvas[:, imgh:imgh + imgw] = bg
    step = random.randint(15, 25)      # spacing between the lines
    st = random.randint(0, step)       # random phase of the first line
    while st + imgh <= imgw + 2 * imgh:
        canvas = cv2.line(canvas, (st, 0), (st + imgh, imgh), fc, 1, 4)
        st += step
    return Image.fromarray(canvas[:, imgh:imgh + imgw])


def interfere(bg, x, y, w, h, bc, fc):
    """With probability 1/4, and only for font sizes >= 50, add oblique-line
    noise to ``bg``; otherwise return it unchanged."""
    global f_size
    if f_size >= 50 and not random.randint(0, 3):
        bg = oblique(bg, fc)
    return bg


def gen_data(co):
    """Render one labelled sample and save it as
    ``<path_out>/<co:08d>_<label>.jpg``.

    '.' is written as '+' in the file name; the label-export script maps '+'
    back to the index of '.' when building the training lists.
    """
    global f_ind, f_size
    # Pick foreground/background grays that differ by at least 30.
    f_color = random.randint(0, 255)
    bg_color = f_color
    while abs(bg_color - f_color) < 30:
        bg_color = random.randint(0, 255)
    # Cycle through the fonts; randomize size and text.
    f_ind = f_ind % len(font_list)
    f_size = random.randint(15, 75)
    txt = get_txt()
    font_text = ImageFont.truetype('{}/{}'.format(path_font, font_list[f_ind]), f_size)
    print('{}---{}'.format(f_ind, font_list[f_ind]))
    tw, th = font_text.getsize(txt)
    background_bg, x1, y1 = get_bg(bg_color, tw, th)
    background_bg = Image.fromarray(background_bg, mode="L")
    draw_txt = ImageDraw.Draw(background_bg)
    draw_txt.text((x1, y1), txt, fill=(f_color), font=font_text)
    background_bg = interfere(background_bg, x1, y1, tw, th, bg_color, f_color)
    txt = txt.replace('.', '+')
    background_bg.save('{}/{:08d}_{}.jpg'.format(path_out, co, txt))
    f_ind += 1


if __name__ == '__main__':
    for i in range(number):
        print(i)
        gen_data(i)
![](https://img.php1.cn/3cd4a/1eebe/cd5/b428d8f746fb8d47.webp?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RpYW5hX1o=,size_16,color_FFFFFF,t_70)
生成灰度图片，规则：
1、背景和文字的颜色差大于30
2、字的个数在1-8个
3、左右上下有一个随机的扩大范围
4、如果出现小数点，保留两位小数
5、干扰，斜线
5、预处理
包括裁剪、二值化、统计max size，这里由于之后需要resize，所以统计的是max ratio，即w/h的最大值
二值化使用opencv提供的otsu方法
在处理过程中缩减周围可缩小范围
import numpy as np
import cv2
import os
import time
# NOTE(review): this snippet (the crop/binarize preprocessing script) was
# corrupted when the post was exported to HTML: '=' became '&#61;', quotes
# became '&#39;', the original line breaks were dropped, and everything between
# a bare '<' and the following '>' was swallowed by the HTML parser. In
# particular the header of `def binar(...)` and part of `linkseg`'s body
# (everything between `if pt` and `0.5:`) are missing and cannot be recovered
# from this text alone — restore them from the author's original source before
# trying to run this.
# What remains suggests the intent: `linkseg(segx, dis, img)` merges adjacent
# column segments; `binar(img)` binarizes an image (the article says via
# OpenCV's Otsu method), inverts it when foreground dominates, splits it into
# horizontal strips with `segmentation` on row sums, then into per-digit
# columns with `segmentation` on column sums, writing each digit crop to
# `path_out` and mis-cut images to `path_err`.
from unit import segmentationt1 &#61; 0
t2 &#61; 0
name&#61;&#39;&#39;def linkseg(segx,dis,img):nst&#61;0nen&#61;0nseg&#61;[]h&#61;min(img.shape[0]//10,5)for i in range(len(segx)):st,en&#61;segx[i]img_p&#61;img[:,st:en]y&#61;np.sum(img_p,axis&#61;1)y&#61;y-np.average(y)pt&#61;len(np.where(y!&#61;0)[0])if pt0.5:img&#61;255-img# cv2.imshow(&#39;binar&#39;, img)# cv2.waitKey()# 裁剪文字&#xff0c;将之裁剪成一个一个字h,w&#61;img.shapey&#61;np.sum(img,axis&#61;1)y&#61;y-min(y)segy&#61;segmentation(y,0)# 根据横向的空白&#xff0c;将裁剪下来的数据分成多条&#xff0c;认为每一条的纵向上面没有干扰# 生成数据&#xff0c;上下没有干扰&#xff0c;所以删除裁剪不正确的数据if len(segy)>1:cv2.imwrite(&#39;{}/{}&#39;.format(path_err,filename),img)returnfor st,en in segy:img_p&#61;img[st:en,:]x&#61;np.sum(img_p,axis&#61;0)x&#61;x-min(x)segx&#61;segmentation(x,0)# link segsegx&#61;linkseg(segx,5,img_p)# 根据纵向的空白&#xff0c;将裁剪下来的数据分成多块&#xff0c;认为每一块上一个数字for st,en in segx:img_n&#61;img_p[:,st:en]if count>1:cv2.imshow(&#39;a&#39;,img_n)cv2.waitKey()img_n &#61; 255 - img_ncv2.imwrite(&#39;{}/{}&#39;.format(path_out, filename), img_n)count &#61; count &#43; 1# cv2.imshow(&#39;c1&#39;, img_n)# cv2.waitKey()if __name__ &#61;&#61; &#39;__main__&#39;:# img&#61;cv2.imread(&#39;/home/zz/图片/桌面/test/lALPDgQ9qxS5d1jNAwDNBVY_1366_768(第 5 个复件).png&#39;)# binar(img)path&#61;&#39;/media/zz/AE1AD9D91AD99F21/book-zz/digit&#39;path_out&#61;&#39;/media/zz/AE1AD9D91AD99F21/book-zz/digit-cut&#39;path_err&#61;&#39;/media/zz/AE1AD9D91AD99F21/book-zz/digit-err&#39;if not os.path.exists(path_out):os.mkdir(path_out)if not os.path.exists(path_err):os.mkdir(path_err)img_list&#61;os.listdir(path)img_list.sort()for i in range(544,len(img_list)):filename&#61;img_list[i]name &#61; &#39;{:03d}&#39;.format(i)print(filename)img&#61;cv2.imread(&#39;{}/{}&#39;.format(path,filename))binar(img)print(t1)print(t2)
6、保存label，并resize图片，分别保存到train和test
import os
import re

# --- configuration ------------------------------------------------------------
path_f = '/media/cj1/data/digit_pic_lab'   # dataset root
dir_img = 'digit-cut'                      # input: cropped digit images
dir_train = 'train_images'                 # output image dirs / list files
list_train = 'train_list'
dir_test = 'test_images'
list_test = 'test_list'

# Character dictionary; a label character is encoded as its index in this string.
dicters = '0123456789.-￥'

# Target size fed to the OCR model (matches the article's DATA_SHAPE = [1, 32, 300]).
# NOTE(review): these constants were lost in the original paste — confirm values.
IMG_H = 32
IMG_W = 300


def get_lab_num(lab):
    """Convert a label string into a list of dictionary-index strings.

    '+' in a file name stands for '.' (see the generator script).
    Raises ValueError for characters outside ``dicters``.
    """
    indices = []
    for ch in lab:
        if ch == '+':
            indices.append(str(dicters.index('.')))
        else:
            indices.append(str(dicters.index(ch)))
    return indices


def main():
    """Resize every cropped image, write it to the train/test image dir, and
    append a '<w> <h> <name> <indices>' line to the matching list file.

    Every 100th sample goes to the test split, the rest to train.
    """
    # cv2/numpy are imported here so get_lab_num() stays importable without OpenCV.
    import cv2
    import numpy as np

    img_list = os.listdir('{}/{}'.format(path_f, dir_img))
    img_list.sort()
    os.makedirs('{}/{}'.format(path_f, dir_train), exist_ok=True)
    os.makedirs('{}/{}'.format(path_f, dir_test), exist_ok=True)
    count = 0
    with open('{}/{}'.format(path_f, list_train), 'w') as fs_train, \
            open('{}/{}'.format(path_f, list_test), 'w') as fs_test:
        for filename in img_list:
            print(filename)
            lab = filename.split('.')[0].split('_')[-1]   # label part of the name
            try:
                num_list = get_lab_num(lab)
            except ValueError:
                continue   # unknown character in the label: skip the sample
            try:
                img = cv2.imread('{}/{}/{}'.format(path_f, dir_img, filename))
                h, w, c = img.shape
                nw = int(w * IMG_H / h)                   # keep aspect ratio
                img = cv2.resize(img, (nw, IMG_H))
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
                # Pad to fixed width. NOTE(review): the original wrote the
                # unpadded `img` to disk and left `img_n` unused — probably a
                # bug, but kept as-is to preserve the pasted behavior.
                img_n = np.zeros((IMG_H, IMG_W), dtype='uint8')
                img_n[:, :nw] = img
            except Exception:
                continue   # unreadable or oversized image: skip the sample
            name = re.sub(r'\D', '', lab)                 # digits-only file name
            newfilename = '{}_{}.jpg'.format(filename.split('_')[0], name)
            line = '{} {} {} {}\n'.format(w, h, newfilename, ','.join(num_list))
            if count % 100 == 0:   # every 100th sample -> test split
                fs_test.write(line)
                cv2.imwrite('{}/{}/{}'.format(path_f, dir_test, newfilename), img)
            else:
                fs_train.write(line)
                cv2.imwrite('{}/{}/{}'.format(path_f, dir_train, newfilename), img)
            count += 1


if __name__ == '__main__':
    main()
![](https://img.php1.cn/3cd4a/1eebe/cd5/8ad8f3bf8da691df.webp?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RpYW5hX1o=,size_16,color_FFFFFF,t_70)
在同一个目录下生成如上四个文件，然后文件夹中保存的是图片，test_list保存的是标签
![](https://img.php1.cn/3cd4a/1eebe/cd5/72fd2c126203a875.webp?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RpYW5hX1o=,size_16,color_FFFFFF,t_70)
![](https://img.php1.cn/3cd4a/1eebe/cd5/bcafc120671304eb.webp)
7、修改模型代码
因为我们把数据格式跟模型读取的数据格式生成的一样，所以大部分不用修改，只需要改data_reader里面一些内容就可以了
1、分类数和图片大小，根据自己实际需要修改
![](https://img.php1.cn/3cd4a/1eebe/cd5/02c379d60086f382.webp)
我的这里：
NUM_CLASSES = 10
DATA_SHAPE = [1, 32, 300]
2、文件读取路径
![](https://img.php1.cn/3cd4a/1eebe/cd5/72fd2c126203a875.webp)
这个直接写成自己的
data_dir='zz/data'
8、错误汇总
--------------------------------------
遇到了一个问题，模型是上面的我的数据的新模型，数据是从生成数据中截留的一点测试数据
在官方的infer中执行结果：
![](https://img.php1.cn/3cd4a/1eebe/cd5/d05d9dfd09a56332.webp)
但是我train的时候准确率是高达0.99的，所以我用train修改了一个可以输出测试结果的代码
![](https://img.php1.cn/3cd4a/1eebe/cd5/ed19db63ee478b98.png)
发现，问题出现在这一步 indexes = prune(np.array(result[0]).flatten(), 0, 1)
由于我的输出当中全部都是数字，所以生成的结果通过0和1缩短一下以后就不成数据了。
---------------------------------------
Enforce failed. Expected x_mat_dims[1] == y_mat_dims[0], but received x_mat_dims[1]:768 != y_mat_dims[0]:512.
First matrix's width must be equal with second matrix's height. 768, 512 at [/paddle/paddle/fluid/operators/mul_op.cc:61]
错误原因：SHAPE的大小不对
我的图片是32*300的，但是shape的大小设置成了48*500，然后就会报这个错误