热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

paddlepaddle使用笔记——使用自己的数据训练ocr模型

1、使用环境:ubuntu18.04,4gpu,nvidia410.78,cuda9.0,cudnn7.3&

1、使用环境:

ubuntu18.04,4gpu,nvidia410.78,cuda9.0,cudnn7.3,python3.6

2、使用代码:

官方提供的ocr模型代码

https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/ocr_recognition

3、将代码运行起来

为了方便看到运行的效果,我修改了参数,save_model_period,这样可以更快的保存数据,好知道运行是否有效

4、生成自己的数据

import random
import cv2
import numpy as np
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import os
from unit import segmentationpath_font='/home/zz/文字'
path_out='/media/zz/testtttt'
if not os.path.exists(path_out):os.mkdir(path_out)dicters='0123456789.'
CHARS='0123456789'
number=100
font_index=0
font_list=os.listdir(path_font)
font_list.sort()
f_ind=0
f_size&#61;0def get_word(length):global font_index, CHARSf &#61; &#39;&#39;for i in range(length):f &#61; f &#43; random.choice(CHARS)font_index &#61; font_index &#43; 1return fdef get_txt():txt&#61;&#39;&#39;f3&#61;random.randint(0,2) # 2/3的可能会出现.len_num&#61;random.randint(1,8)num&#61;get_word(len_num)txt&#61;txt&#43;numif f3>0 and len(num)>&#61;3:txt&#61;txt[:-2]&#43;&#39;.&#39;&#43;txt[-2:]return txtdef get_txt_test():global dictersreturn dictersdef get_bg(color, w,h):# w_l&#61;random.randint(1,500)# w_r&#61;random.randint(1,500)# h_t&#61;random.randint(1,50)# h_b&#61;random.randint(1,50)w_l&#61;50w_r&#61;50h_t&#61;20h_b&#61;20bg&#61;np.zeros((h&#43;h_b&#43;h_t,w&#43;w_l&#43;w_r),dtype&#61;&#39;uint8&#39;)bg&#61;bg&#43;color# bg&#61;Image.fromarray(bg)return bg,w_l,h_tdef oblique(bg,fc):bg&#61;np.array(bg)imgh,imgw&#61;bg.shapenew_bg&#61;np.zeros((imgh,imgw&#43;2*imgh),dtype&#61;&#39;uint8&#39;)new_bg[:,imgh:imgh&#43;imgw]&#61;bgstep&#61;random.randint(15,25)st&#61;random.randint(0,step)while st&#43;imgh<&#61;imgw&#43;2*imgh:pt1&#61;(st,0)pt2&#61;(st&#43;imgh,imgh)new_bg&#61;cv2.line(new_bg, pt1, pt2, fc, 1, 4)st&#61;st&#43;step# cv2.imshow(&#39;a&#39;,new_bg)# cv2.waitKey()bg&#61;new_bg[:,imgh:imgh&#43;imgw]# cv2.imshow(&#39;a&#39;,new_bg)# cv2.imshow(&#39;b&#39;,bg)# cv2.waitKey()bg&#61;Image.fromarray(bg)return bgdef interfere(bg,x,y,w,h,bc,fc):global f_sizeif f_size>&#61;50 and not random.randint(0,3):bg&#61;oblique(bg,fc)return bgdef gen_data(co):global f_ind,f_size# 确定颜色f_color&#61;random.randint(0,255)bg_color&#61;f_colorwhile abs(bg_color-f_color)<30:bg_color&#61;random.randint(0,255)# 字体f_ind &#61; f_ind % len(font_list)f_size&#61;random.randint(15,75)txt &#61; get_txt() # 文字内容font_text &#61; ImageFont.truetype(&#39;{}/{}&#39;.format(path_font, font_list[f_ind]), f_size)print(&#39;{}---{}&#39;.format(f_ind,font_list[f_ind]))background_bg,x1,y1 &#61; get_bg(bg_color, font_text.getsize(txt)[0], font_text.getsize(txt)[1])background_bg &#61; Image.fromarray(background_bg, mode&#61;"L")draw_txt &#61; ImageDraw.Draw(background_bg) # 确认输出文字的背景图片draw_txt.text((x1, y1), txt, fill&#61;(f_color), font&#61;font_text)background_bg&#61;interfere(background_bg,x1,y1,font_text.getsize(txt)[0], font_text.getsize(txt)[1],bg_color,f_color)txt &#61; txt.replace(&#39;.&#39;, &#39;&#43;&#39;)background_bg.save(&#39;{}/{:08d}_{}.jpg&#39;.format(path_out, co, txt))f_ind &#61; f_ind &#43; 1returnif __name__ &#61;&#61; &#39;__main__&#39;:for i in range(number):print(i)gen_data(i)

生成灰度图片&#xff0c;规则&#xff1a;

1、背景和文字的颜色差大于30

2、字的个数在1-8个

3、左右上下有一个随机的扩大范围

4、如果出现小数点&#xff0c;保留两位小数

5、干扰&#xff0c;斜线

 

5、预处理

包括&#xff0c;包括裁剪&#xff0c;二值化&#xff0c;统计max size&#xff0c;这里由于之后需要resize&#xff0c;所以统计的是max ration&#xff0c;就是w/h的最大值

二值化使用opencv提供的otsu方法

在处理过程中缩减周围可缩小范围


import numpy as np
import cv2
import os
import time
from unit import segmentationt1 &#61; 0
t2 &#61; 0
name&#61;&#39;&#39;def linkseg(segx,dis,img):nst&#61;0nen&#61;0nseg&#61;[]h&#61;min(img.shape[0]//10,5)for i in range(len(segx)):st,en&#61;segx[i]img_p&#61;img[:,st:en]y&#61;np.sum(img_p,axis&#61;1)y&#61;y-np.average(y)pt&#61;len(np.where(y!&#61;0)[0])if pt0.5:img&#61;255-img# cv2.imshow(&#39;binar&#39;, img)# cv2.waitKey()# 裁剪文字&#xff0c;将之裁剪成一个一个字h,w&#61;img.shapey&#61;np.sum(img,axis&#61;1)y&#61;y-min(y)segy&#61;segmentation(y,0)# 根据横向的空白&#xff0c;将裁剪下来的数据分成多条&#xff0c;认为每一条的纵向上面没有干扰# 生成数据&#xff0c;上下没有干扰&#xff0c;所以删除裁剪不正确的数据if len(segy)>1:cv2.imwrite(&#39;{}/{}&#39;.format(path_err,filename),img)returnfor st,en in segy:img_p&#61;img[st:en,:]x&#61;np.sum(img_p,axis&#61;0)x&#61;x-min(x)segx&#61;segmentation(x,0)# link segsegx&#61;linkseg(segx,5,img_p)# 根据纵向的空白&#xff0c;将裁剪下来的数据分成多块&#xff0c;认为每一块上一个数字for st,en in segx:img_n&#61;img_p[:,st:en]if count>1:cv2.imshow(&#39;a&#39;,img_n)cv2.waitKey()img_n &#61; 255 - img_ncv2.imwrite(&#39;{}/{}&#39;.format(path_out, filename), img_n)count &#61; count &#43; 1# cv2.imshow(&#39;c1&#39;, img_n)# cv2.waitKey()if __name__ &#61;&#61; &#39;__main__&#39;:# img&#61;cv2.imread(&#39;/home/zz/图片/桌面/test/lALPDgQ9qxS5d1jNAwDNBVY_1366_768(第 5 个复件).png&#39;)# binar(img)path&#61;&#39;/media/zz/AE1AD9D91AD99F21/book-zz/digit&#39;path_out&#61;&#39;/media/zz/AE1AD9D91AD99F21/book-zz/digit-cut&#39;path_err&#61;&#39;/media/zz/AE1AD9D91AD99F21/book-zz/digit-err&#39;if not os.path.exists(path_out):os.mkdir(path_out)if not os.path.exists(path_err):os.mkdir(path_err)img_list&#61;os.listdir(path)img_list.sort()for i in range(544,len(img_list)):filename&#61;img_list[i]name &#61; &#39;{:03d}&#39;.format(i)print(filename)img&#61;cv2.imread(&#39;{}/{}&#39;.format(path,filename))binar(img)print(t1)print(t2)

 

 

5、保存label&#xff0c;并resize图片&#xff0c;分别保存到train和test

import os
import cv2
import re# 文件名称
path_f&#61;&#39;/media/cj1/data/digit_pic_lab&#39;
dir_img&#61;&#39;digit-cut&#39;
dir_train&#61;&#39;train_images&#39;
list_train&#61;&#39;train_list&#39;
dir_test&#61;&#39;test_images&#39;
list_test&#61;&#39;test_list&#39;# 计数器
count&#61;0# 字典
dicters&#61;&#39;0123456789.-&#xffe5;&#39;# 生成数据集
img_list&#61;os.listdir(&#39;{}/{}&#39;.format(path_f,dir_img))
img_list.sort()def get_lab_num(lab):s&#61;[]for l in lab:if l&#61;&#61;&#39;&#43;&#39;:ind&#61;dicters.index(&#39;.&#39;)else:ind&#61;dicters.index(l)s.append(str(ind))return sfs_train&#61;open(&#39;{}/{}&#39;.format(path_f,list_train),&#39;w&#39;)
fs_test&#61;open(&#39;{}/{}&#39;.format(path_f,list_test),&#39;w&#39;)for filename in img_list:print(filename)lab&#61;filename.split(&#39;.&#39;)[0].split(&#39;_&#39;)[-1] # 取出标签内容try:num_list&#61;get_lab_num(lab) # 转为数字标签except:# 转化错误的话&#xff0c;就直接下一个continuetry:img&#61;cv2.imread(&#39;{}/{}/{}&#39;.format(path_f,dir_img,filename))h,w,c&#61;img.shapenw&#61;int(w*IMG_H/h)img&#61;cv2.resize(img,(nw,IMG_H))img&#61;cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)img_n[:,:nw]&#61;imgexcept:# 读取错误&#xff0c;就直接下一个continuename&#61;re.sub(&#39;\D&#39;,&#39;&#39;,lab) # 取一个只有数字的名字newfilename&#61;&#39;{}_{}.jpg&#39;.format(filename.split(&#39;_&#39;)[0],name)if count%100&#61;&#61;0:# 每100张训练&#xff0c;保存一张测试if not os.path.exists(&#39;{}/{}&#39;.format(path_f,dir_test)):os.mkdir(&#39;{}/{}&#39;.format(path_f,dir_test))fs_test.write(&#39;{} {} {} {}\n&#39;.format(w,h,newfilename,&#39;,&#39;.join(num_list)))cv2.imwrite(&#39;{}/{}/{}&#39;.format(path_f,dir_test,newfilename),img)else:if not os.path.exists(&#39;{}/{}&#39;.format(path_f,dir_train)):os.mkdir(&#39;{}/{}&#39;.format(path_f,dir_train))fs_train.write(&#39;{} {} {} {}\n&#39;.format(w, h, newfilename, &#39;,&#39;.join(num_list)))cv2.imwrite(&#39;{}/{}/{}&#39;.format(path_f, dir_train, newfilename),img)count&#61;count&#43;1

在同一个目录下生成如上四个文件&#xff0c;然后文件夹中保存的是图片&#xff0c;test_list保存的是标签

6、修改模型代码

因为我们把数据格式根模型读取的数据格式生成的一样&#xff0c;所以大部分不用修改&#xff0c;只需要data_reader里面一些内容就可以了

1、分类数和图片大小&#xff0c;根据自己实际需要修改

我的这里&#xff1a;

NUM_CLASSES &#61;10
DATA_SHAPE &#61; [1, 32,300]

2、文件读取路径

这个直接写成自己的

data_dir&#61;&#39;zz/data&#39;

 

7、错误汇总

--------------------------------------

遇到了一个问题&#xff0c;模型是上面的我的数据的新模型&#xff0c;数据是从生成数据中截留的一点测试数据

在官方的infer中执行结果&#xff1a;

但是我train的时候准确率是高达0.99的&#xff0c;所以我用train修改了一个可以输出测试结果的代码

发现&#xff0c;问题出现在这一步indexes &#61; prune(np.array(result[0]).flatten(), 0, 1)

由于我的输出当中全部都是数字&#xff0c;所以生成的结果通过0和1缩短一下以后就不成数据了。

---------------------------------------

Enforce failed. Expected x_mat_dims[1] &#61;&#61; y_mat_dims[0], but received x_mat_dims[1]:768 !&#61; y_mat_dims[0]:512.
First matrix&#39;s width must be equal with second matrix&#39;s height. 768, 512 at [/paddle/paddle/fluid/operators/mul_op.cc:61]

错误原因&#xff1a;SHAPE的大小不对

我在的图片是32*300的&#xff0c;但是shape的大小设置成了48*500&#xff0c;然后就会报这个错误

 

 

 

 

 

 

 

 

 


推荐阅读
  • 本文探讨了为何相同的HTTP请求在两台不同操作系统(Windows与Ubuntu)的机器上会分别返回200 OK和429 Too Many Requests的状态码。我们将分析代码、环境差异及可能的影响因素。 ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • 在编译BSP包过程中,遇到了一个与 'gets' 函数相关的编译错误。该问题通常发生在较新的编译环境中,由于 'gets' 函数已被弃用并视为安全漏洞。本文将详细介绍如何通过修改源代码和配置文件来解决这一问题。 ... [详细]
  • 技术分享:从动态网站提取站点密钥的解决方案
    本文探讨了如何从动态网站中提取站点密钥,特别是针对验证码(reCAPTCHA)的处理方法。通过结合Selenium和requests库,提供了详细的代码示例和优化建议。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 本文详细介绍如何使用Python进行配置文件的读写操作,涵盖常见的配置文件格式(如INI、JSON、TOML和YAML),并提供具体的代码示例。 ... [详细]
  • 深入理解Tornado模板系统
    本文详细介绍了Tornado框架中模板系统的使用方法。Tornado自带的轻量级、高效且灵活的模板语言位于tornado.template模块,支持嵌入Python代码片段,帮助开发者快速构建动态网页。 ... [详细]
  • 本文详细介绍了如何使用Python编写爬虫程序,从豆瓣电影Top250页面抓取电影信息。文章涵盖了从基础的网页请求到处理反爬虫机制,再到多页数据抓取的全过程,并提供了完整的代码示例。 ... [详细]
  • 在Ubuntu 16.04 LTS上配置Qt Creator开发环境
    本文详细介绍了如何在Ubuntu 16.04 LTS系统中安装和配置Qt Creator,涵盖了从下载到安装的全过程,并提供了常见问题的解决方案。 ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
  • 自己用过的一些比较有用的css3新属性【HTML】
    web前端|html教程自己用过的一些比较用的css3新属性web前端-html教程css3刚推出不久,虽然大多数的css3属性在很多流行的浏览器中不支持,但我个人觉得还是要尽量开 ... [详细]
  • 选择适合生产环境的Docker存储驱动
    本文旨在探讨如何在生产环境中选择合适的Docker存储驱动,并详细介绍不同Linux发行版下的配置方法。通过参考官方文档和兼容性矩阵,提供实用的操作指南。 ... [详细]
  • 在创建新的Android项目时,您可能会遇到aapt错误,提示无法打开libstdc++.so.6共享对象文件。本文将探讨该问题的原因及解决方案。 ... [详细]
  • 本文介绍如何在Spring Boot项目中集成Redis,并通过具体案例展示其配置和使用方法。包括添加依赖、配置连接信息、自定义序列化方式以及实现仓储接口。 ... [详细]
  • 在PHP后端开发中遇到一个难题:通过第三方类文件发送短信功能返回的JSON字符串无法解析。本文将探讨可能的原因并提供解决方案。 ... [详细]
author-avatar
CJFONe
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有