热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

世界人工智能大赛OCR赛题方案!

 Datawhale干货 作者:阿水,北京航空航天大学,Datawhale成员本文以世界人工智能创新大赛(AIWIN)手写体OCR识别竞赛为实践背景,给出了OCR实践的常见思路和流

 Datawhale干货 

作者:阿水,北京航空航天大学,Datawhale成员

本文以世界人工智能创新大赛(AIWIN)手写体 OCR 识别竞赛为实践背景,给出了OCR实践的常见思路和流程。本项目使用PaddlePaddle 2.0动态图实现的CRNN文字识别模型,全文代码及思路如下。后台回复 211112 可获取完整代码。

代码地址:https://aistudio.baidu.com/aistudio/projectdetail/2612313

《世界人工智能大赛OCR赛题方案!》

赛题背景

银行日常业务中涉及到各类凭证的识别录入,例如身份证录入、支票录入、对账单录入等。以往的录入方式主要是以人工录入为主,效率较低,人力成本较高。近几年来,OCR相关技术以其自动执行、人为干预较少等特点正逐步替代传统的人工录入方式。但OCR技术在实际应用中也存在一些问题,在各类凭证字段的识别中,手写体由于其字体差异性大、字数不固定、语义关联性较低、凭证背景干扰等原因,导致OCR识别率准确率不高,需要大量人工校正,对日常的银行录入业务造成了一定的影响。

比赛地址:http://ailab.aiwin.org.cn/competitions/65

赛题任务

本次赛题将提供手写体图像切片数据集,数据集从真实业务场景中,经过切片脱敏得到,参赛队伍通过识别技术,获得对应的识别结果。即:

  • 输入:手写体图像切片数据集

  • 输出:对应的识别结果

《世界人工智能大赛OCR赛题方案!》

代码说明

本项目是PaddlePaddle 2.0动态图实现的CRNN文字识别模型,可支持长短不一的图片输入。CRNN是一种端到端的识别模式,不需要通过分割图片即可完成图片中全部的文字识别。CRNN的结构主要是CNN+RNN+CTC,它们分别的作用是:

  • 使用深度CNN,对输入图像提取特征,得到特征图;

  • 使用双向RNN(BLSTM)对特征序列进行预测,对序列中的每个特征向量进行学习,并输出预测标签(真实值)分布;

  • 使用 CTC Loss,把从循环层获取的一系列标签分布转换成最终的标签序列。《世界人工智能大赛OCR赛题方案!》

CRNN的结构如下,一张高为32的图片,宽度随意,一张图片经过多层卷积之后,高度就变成了1,经过paddle.squeeze()就去掉了高度,也就说从输入的图片BCHW经过卷积之后就成了BCW。然后把特征顺序从BCW改为WBC输入到RNN中,经过两次的RNN之后,模型的最终输入为(W, B, Class_num)。这恰好是CTCLoss函数的输入。《世界人工智能大赛OCR赛题方案!》

代码详情

使用环境:

  • PaddlePaddle 2.0.1

  • Python 3.7

!\rm -rf __MACOSX/ 测试集/ 训练集/ dataset/
!unzip 2021A_T1_Task1_数据集含训练集和测试集.zip > out.log

步骤1:生成额外的数据集

这一步可以跳过,如果想要获取更好的精度,可以自己添加。

import os
import time
from random import choice, randint, randrange
from PIL import Image, ImageDraw, ImageFont
# 验证码图片文字的字符集
characters = '拾伍佰正仟万捌贰整陆玖圆叁零角分肆柒亿壹元'
def selectedCharacters(length):
    result = ''.join(choice(characters) for _ in range(length))
    return result
def getColor():
    r = randint(0, 100)
    g = randint(0, 100)
    b = randint(0, 100)
    return (r, g, b)
def main(size=(200, 100), characterNumber=6, bgcolor=(255, 255, 255)):
    # 创建空白图像和绘图对象
    imageTemp = Image.new('RGB', size, bgcolor)
    draw01 = ImageDraw.Draw(imageTemp)
    # 生成并计算随机字符串的宽度和高度
    text = selectedCharacters(characterNumber)
    print(text)
    font = ImageFont.truetype(font_path, 40)
    width, height = draw01.textsize(text, font)
    if width + 2 * characterNumber > size[0] or height > size[1]:
        print('尺寸不合法')
        return
    # 绘制随机字符串中的字符
    startX = 0
    widthEachCharater = width // characterNumber
    for i in range(characterNumber):
        startX += widthEachCharater + 1
        position = (startX, (size[1] - height) // 2)
        draw01.text(xy=position, text=text[i], fOnt=font, fill=getColor())
    # 对像素位置进行微调,实现扭曲的效果
    imageFinal = Image.new('RGB', size, bgcolor)
    pixelsFinal = imageFinal.load()
    pixelsTemp = imageTemp.load()
    for y in range(size[1]):
        offset = randint(-1, 0)
        for x in range(size[0]):
            newx = x + offset
            if newx >= size[0]:
                newx = size[0] - 1
            elif newx < 0:
                newx = 0
            pixelsFinal[newx, y] = pixelsTemp[x, y]
    # 绘制随机颜色随机位置的干扰像素
    draw02 = ImageDraw.Draw(imageFinal)
    for i in range(int(size[0] * size[1] * 0.07)):
        draw02.point((randrange(0, size[0]), randrange(0, size[1])), fill=getColor())
    # 保存并显示图片
    imageFinal.save("dataset/images/%d_%s.jpg" % (round(time.time() * 1000), text))
def create_list():
    images = os.listdir('dataset/images')
    f_train = open('dataset/train_list.txt', 'w', encoding='utf-8')
    f_test = open('dataset/test_list.txt', 'w', encoding='utf-8')
    for i, image in enumerate(images):
        image_path = os.path.join('dataset/images', image).replace('\\', '/')
        label = image.split('.')[0].split('_')[1]
        if i % 100 == 0:
            f_test.write('%s\t%s\n' % (image_path, label))
        else:
            f_train.write('%s\t%s\n' % (image_path, label))
def creat_vocabulary():
    # 生成词汇表
    with open('dataset/train_list.txt', 'r', encoding='utf-8') as f:
        lines = f.readlines()
    v = set()
    for line in lines:
        _, label = line.replace('\n', '').split('\t')
        for c in label:
            v.add(c)
    vocabulary_path = 'dataset/vocabulary.txt'
    with open(vocabulary_path, 'w', encoding='utf-8') as f:
        f.write(' \n')
        for c in v:
            f.write(c + '\n')
if __name__ == '__main__':
    if not os.path.exists('dataset/images'):
        os.makedirs('dataset/images')

步骤2:安装依赖环境

!pip install Levenshtein
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: Levenshtein in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.16.0)
Requirement already satisfied: rapidfuzz<1.9,>=1.8.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Levenshtein) (1.8.2)

步骤3:读取数据集

import glob, codecs, json, os
import numpy as np
date_jpgs = glob.glob('./训练集/date/images/*.jpg')
amount_jpgs = glob.glob('./训练集/amount/images/*.jpg')
lines = codecs.open('./训练集/date/gt.json', encoding='utf-8').readlines()
lines = ''.join(lines)
date_gt = json.loads(lines.replace(',\n}', '}'))
lines = codecs.open('./训练集/amount/gt.json', encoding='utf-8').readlines()
lines = ''.join(lines)
amount_gt = json.loads(lines.replace(',\n}', '}'))

data_path = date_jpgs + amount_jpgs
date_gt.update(amount_gt)
s = ''
for x in date_gt:
    s += date_gt[x]
char_list = list(set(list(s)))
char_list = char_list

步骤4:构造训练集

!mkdir dataset
!mkdir dataset/images
!cp 训练集/date/images/*.jpg dataset/images
!cp 训练集/amount/images/*.jpg dataset/images
mkdir: cannot create directory ‘dataset’: File exists
mkdir: cannot create directory ‘dataset/images’: File exists

with open('dataset/vocabulary.txt', 'w') as up:
    for x in char_list:
        up.write(x + '\n')
data_path = glob.glob('dataset/images/*.jpg')
np.random.shuffle(data_path)
with open('dataset/train_list.txt', 'w') as up:
    for x in data_path[:-100]:
        up.write(f'{x}\t{date_gt[os.path.basename(x)]}\n')
with open('dataset/test_list.txt', 'w') as up:
    for x in data_path[-100:]:
        up.write(f'{x}\t{date_gt[os.path.basename(x)]}\n')

执行上面程序生成的图片会放在dataset/images目录下,生成的训练数据列表和测试数据列表分别放在dataset/train_list.txtdataset/test_list.txt,最后还有个数据词汇表dataset/vocabulary.txt

数据列表的格式如下,左边是图片的路径,右边是文字标签。

dataset/images/1617420021182_c1dw.jpg c1dw
dataset/images/1617420021204_uvht.jpg uvht
dataset/images/1617420021227_hb30.jpg hb30
dataset/images/1617420021266_4nkx.jpg 4nkx
dataset/images/1617420021296_80nv.jpg 80nv

以下是数据集词汇表的格式,一行一个字符,第一行是空格,不代表任何字符。

f
s
2
7
3
n
d
w

训练自定义数据,参考上面的格式即可。

步骤5:训练模型

不管你是自定义数据集还是使用上面生成的数据,只要文件路径正确,即可开始进行训练。该训练支持长度不一的图片输入,但是每一个batch的数据的数据长度还是要一样的,这种情况下,笔者就用了collate_fn()函数,该函数可以把数据最长的找出来,然后把其他的数据补0,加到相同的长度。同时该函数还要输出它其中每条数据标签的实际长度,因为损失函数需要输入标签的实际长度。

  • 在训练过程中,程序会使用VisualDL记录训练结果

import paddle
import numpy as np
import os
from datetime import datetime
from utils.model import Model
from utils.decoder import ctc_greedy_decoder, label_to_string, cer
from paddle.io import DataLoader
from utils.data import collate_fn
from utils.data import CustomDataset
from visualdl import LogWriter
# 训练数据列表路径
train_data_list_path = 'dataset/train_list.txt'
# 测试数据列表路径
test_data_list_path = 'dataset/test_list.txt'
# 词汇表路径
voc_path = 'dataset/vocabulary.txt'
# 模型保存的路径
save_model = 'models/'
# 每一批数据大小
batch_size = 32
# 预训练模型路径
pretrained_model = None
# 训练轮数
num_epoch = 100
# 初始学习率大小
learning_rate = 1e-3
# 日志记录噐
writer = LogWriter(logdir='log')
def train():
    # 获取训练数据
    train_dataset = CustomDataset(train_data_list_path, voc_path, img_model.pdparams')))
        optimizer.set_state_dict(paddle.load(os.path.join(pretrained_model, 'optimizer.pdopt')))
    train_step = 0
    test_step = 0
    # 开始训练
    for epoch in range(num_epoch):
        for batch_id, (inputs, labels, input_lengths, label_lengths) in enumerate(train_loader()):
            out = model(inputs)
            # 计算损失
            input_lengths = paddle.full(shape=[batch_size], fill_value=out.shape[0], dtype='int64')
            loss = ctc_loss(out, labels, input_lengths, label_lengths)
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            # 多卡训练只使用一个进程打印
            if batch_id % 100 == 0:
                print('[%s] Train epoch %d, batch %d, loss: %f' % (datetime.now(), epoch, batch_id, loss))
                writer.add_scalar('Train loss', loss, train_step)
                train_step += 1
        # 执行评估
        if epoch % 10 == 0:
            model.eval()
            cer = evaluate(model, test_loader, train_dataset.vocabulary)
            print('[%s] Test epoch %d, cer: %f' % (datetime.now(), epoch, cer))
            writer.add_scalar('Test cer', cer, test_step)
            test_step += 1
            model.train()
        # 记录学习率
        writer.add_scalar('Learning rate', scheduler.last_lr, epoch)
        scheduler.step()
        # 保存模型
        paddle.save(model.state_dict(), os.path.join(save_model, 'model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(save_model, 'optimizer.pdopt'))
# 评估模型
def evaluate(model, test_loader, vocabulary):
    cer_result = []
    for batch_id, (inputs, labels, _, _) in enumerate(test_loader()):
        # 执行识别
        outs = model(inputs)
        outs = paddle.transpose(outs, perm=[1, 0, 2])
        outs = paddle.nn.functional.softmax(outs)
        # 解码获取识别结果
        labelss = []
        out_strings = []
        for out in outs:
            out_string = ctc_greedy_decoder(out, vocabulary)
            out_strings.append(out_string)
        for i, label in enumerate(labels):
            label_str = label_to_string(label, vocabulary)
            labelss.append(label_str)
        for out_string, label in zip(*(out_strings, labelss)):
            # 计算字错率
            c = cer(out_string, label) / float(len(label))
            cer_result.append(c)
    cer_result = float(np.mean(cer_result))
    return cer_result
if __name__ == '__main__':
    train()

步骤6:模型预测

训练结束之后,使用保存的模型进行预测。通过修改image_path指定需要预测的图片路径,解码方法,笔者使用了一个最简单的贪心策略。

import os
from PIL import Image
import numpy as np
import paddle
from utils.model import Model
from utils.data import process
from utils.decoder import ctc_greedy_decoder
with open('dataset/vocabulary.txt', 'r', encoding='utf-8') as f:
    vocabulary = f.readlines()
vocabulary = [v.replace('\n', '') for v in vocabulary]
save_model = 'models/'
model = Model(vocabulary, image_model.pdparams')))
model.eval()
def infer(path):
    data = process(path, img_float32')
    # 执行识别
    out = model(data)
    out = paddle.transpose(out, perm=[1, 0, 2])
    out = paddle.nn.functional.softmax(out)[0]
    # 解码获取识别结果
    out_string = ctc_greedy_decoder(out, vocabulary)
    # print('预测结果:%s' % out_string)
    return out_string
if __name__ == '__main__':
    image_path = 'dataset/images/0_8bb194207a248698017a854d62c96104.jpg'
    display(Image.open(image_path))
    print(infer(image_path))

《世界人工智能大赛OCR赛题方案!》


贰零贰零贰壹

from tqdm import tqdm, tqdm_notebook
result_dict = {}
for path in tqdm(glob.glob('./测试集/date/images/*.jpg')):
    text = infer(path)
    result_dict[os.path.basename(path)] = {
        'result': text,
        'confidence': 0.9
    }
for path in tqdm(glob.glob('./测试集/amount/images/*.jpg')):
    text = infer(path)
    result_dict[os.path.basename(path)] = {
        'result': text,
        'confidence': 0.9
    }

with open('answer.json', 'w', encoding='utf-8') as up:
    json.dump(result_dict, up, ensure_ascii=False, indent=4)
!zip answer.json.zip answer.json
  adding: answer.json (deflated 85%)

《世界人工智能大赛OCR赛题方案!》

整理不易,三连


推荐阅读
  • 本文探讨了互联网服务提供商(ISP)如何可能篡改或插入用户请求的数据流,并提供了有效的技术手段来防止此类劫持行为,确保网络环境的安全与纯净。 ... [详细]
  • 本文回顾了作者在求职阿里和腾讯实习生过程中,从最初的迷茫到最后成功获得Offer的心路历程。文中不仅分享了个人的面试经历,还提供了宝贵的面试准备建议和技巧。 ... [详细]
  • Python3爬虫入门:pyspider的基本使用[python爬虫入门]
    Python学习网有大量免费的Python入门教程,欢迎大家来学习。本文主要通过爬取去哪儿网的旅游攻略来给大家介绍pyspid ... [详细]
  • 本文详细介绍如何在SSM(Spring + Spring MVC + MyBatis)框架中实现分页功能。包括分页的基本概念、数据准备、前端分页栏的设计与实现、后端分页逻辑的编写以及最终的测试步骤。 ... [详细]
  • 【MySQL】frm文件解析
    官网说明:http:dev.mysql.comdocinternalsenfrm-file-format.htmlfrm是MySQL表结构定义文件,通常frm文件是不会损坏的,但是如果 ... [详细]
  • 解析Java虚拟机HotSpot中的GC算法实现
    本文探讨了Java虚拟机(JVM)中HotSpot实现的垃圾回收(GC)算法,重点介绍了根节点枚举、安全点及安全区域的概念和技术细节,以及这些机制如何影响GC的效率和准确性。 ... [详细]
  • 数据输入验证与控件绑定方法
    本文提供了多种数据输入验证函数及控件绑定方法的实现代码,包括电话号码、数字、传真、邮政编码、电子邮件和网址的验证,以及报表绑定和自动编号等功能。 ... [详细]
  • 本文详细介绍了如何在 Ubuntu 14.04 系统上搭建仅使用 CPU 的 Caffe 深度学习框架,包括环境准备、依赖安装及编译过程。 ... [详细]
  • 2020年9月15日,Oracle正式发布了最新的JDK 15版本。本次更新带来了许多新特性,包括隐藏类、EdDSA签名算法、模式匹配、记录类、封闭类和文本块等。 ... [详细]
  • MySQL 函数调用性能优化策略与实践
    MySQL函数调用的性能优化是提升数据库整体效率的关键。本文探讨了MySQL中函数的确定性和不确定性分类,以及如何通过优化这些函数调用来提高查询性能。确定性函数在给定相同输入时始终返回相同的结果,而非确定性函数则可能因环境或时间等因素而返回不同的结果。文章详细介绍了识别和优化非确定性函数的方法,以减少对数据库性能的影响,并提供了实际应用中的案例分析。 ... [详细]
  • 能够感知你情绪状态的智能机器人即将问世 | 科技前沿观察
    本周科技前沿报道了多项重要进展,包括美国多所高校在机器人技术和自动驾驶领域的最新研究成果,以及硅谷大型企业在智能硬件和深度学习技术上的突破性进展。特别值得一提的是,一款能够感知用户情绪状态的智能机器人即将问世,为未来的人机交互带来了全新的可能性。 ... [详细]
  • 本文探讨了利用Python编程语言开发自动化脚本来实现文件的全量和增量备份方法。通过详细分析不同备份策略的特点,文章介绍了如何使用Python标准库中的os和shutil模块来高效地管理和执行备份任务。此外,还提供了示例代码和最佳实践,帮助读者快速掌握自动化备份技术,确保数据的安全性和完整性。 ... [详细]
  • 基于Python PaddleSpeech实现语音文字处理
    基于Python PaddleSpeech实现语音文字处理-目录前言环境安装项目验证tts语音合成asr语音识别标点恢复总结前言这段时间一直在研究飞浆平台,最近试了试PaddleS ... [详细]
  • mapreduce数据去重的实现方法
    本文介绍了利用mapreduce实现数据去重的方法,同时还介绍了人工智能AI领域中常用的框架和工具,包括Keras、PyTorch、MXNet、TensorFlow和PaddlePaddle,并提供了深度学习实战的代码下载链接。 ... [详细]
  • 老电影和图片变清晰的秘密!分辨率提升400%的AI算法
    老电影和图片变清晰的秘密!分辨率提升400%的AI算法-如上图,从100x133pix→400x532pix,除了肉眼可见的清晰,拥有可以将分辨率提升400%的技术到底意味着什么 ... [详细]
author-avatar
泡沫茱_617
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有