热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于PaddleOCR2.4的天池街景字符编码识别Baseline

基于PaddleOCR2.4的天池街景字符编码识别Baseline-一、天池街景字符编码识别比赛比赛地址:https:tianchi.aliyun.comcompetitione

一、 天池街景字符编码识别比赛

比赛地址:https://tianchi.aliyun.com/competition/entrance/531795/information

1.数据来源

赛题来源自Google街景图像中的门牌号数据集(The Street View House Numbers Dataset, SVHN),并根据一定方式采样得到比赛数据集。

2.数据基本情况

该数据来自真实场景的门牌号。训练集数据包括3W张照片,验证集数据包括1W张照片,每张照片包括颜色图像和对应的编码类别和具体位置;为了保证比赛的公平性,测试集A包括4W张照片,测试集B包括4W张照片。

enter image description here

3.数据集样本展示

4.字段表

所有的数据(训练集、验证集和测试集)的标注使用JSON格式,并使用文件名进行索引。如果一个文件中包括多个字符,则使用列表将字段进行组合。

Field Description
top 左上角坐标Y
height 字符高度
left 左上角坐标X
width 字符宽度
label 字符编码

注:数据集来源自SVHN,网页链接http://ufldl.stanford.edu/housenumbers/,并进行匿名处理和噪音处理,请各位选手使用比赛给定的数据集完成训练。

二、环境设置

PaddleOCR https://github.com/paddlepaddle/PaddleOCR 是一款全宇宙最强的用的OCR工具库,开箱即用,速度杠杠的。

# 从gitee上下载PaddleOCR代码,也可以从GitHub链接下载
!git clone https://gitee.com/paddlepaddle/PaddleOCR.git --depth=1
# 升级pip
!pip install -U pip 
# 安装依赖
%cd ~/PaddleOCR
%pip install -r requirements.txt
%cd ~/PaddleOCR/
!tree   -L 1
/home/aistudio/PaddleOCR
.
├── benchmark
├── configs
├── deploy
├── doc
├── __init__.py
├── LICENSE
├── MANIFEST.in
├── paddleocr.py
├── ppocr
├── PPOCRLabel
├── ppstructure
├── README_ch.md
├── README.md
├── requirements.txt
├── setup.py
├── StyleText
├── test_tipc
├── tools
└── train.sh

10 directories, 9 files

三、数据准备

据悉train数据集共10万张,解压,并划分出10000张作为测试集。

1.数据下载解压

#  解压缩数据集
%cd ~
!unzip -qoa data/data124095/street_code_rec_data.zip -d ~/data/
/home/aistudio
# 重命名文件夹
!mv data/街景编码识别 data/street_code_rec_data
# 解压test数据集
!unzip -qoa data/street_code_rec_data/mchar_test_a.zip -d data/street_code_rec_data/
# 解压eval据集
!unzip -qoa data/street_code_rec_data/mchar_val.zip -d data/street_code_rec_data/
# 解压train数据集
!unzip -qoa data/street_code_rec_data/mchar_train.zip -d data/street_code_rec_data/
# 使用命令查看训练数据文件夹下数据量是否是3张
!cd data/street_code_rec_data/mchar_train &&  ls -l | grep "^-" | wc -l
30000
# 使用命令查看test数据文件夹下数据量是否是4万张
!cd data/street_code_rec_data/mchar_test_a  &&  ls -l | grep "^-" | wc -l
40000
# 使用命令查看test数据文件夹下数据量是否是1万张
!cd data/street_code_rec_data/mchar_val &&  ls -l | grep "^-" | wc -l
10000
%cd data/street_code_rec_data
!rm *.zip
%cd ~
/home/aistudio/data/street_code_rec_data
/home/aistudio

2. 数据标签处理

import json
def trans(path):
    with open(path + '.json', 'r') as f:
        json_data = json.load(f)
        print(len(json_data))
        with open(path + '.csv', 'w') as ff:
            for item in json_data:
                label = json_data[item]['label']
                label = [str(x) for x in label]
                label = ''.join(label)
                ff.write(item + '\t' + label + '\n')
trans('data/street_code_rec_data/mchar_val')
trans('data/street_code_rec_data/mchar_train')
10000
30000

3. 数据查看

!head data/street_code_rec_data/mchar_val.csv
000000.png  5
000001.png  210
000002.png  6
000003.png  1
000004.png  9
000005.png  1
000006.png  183
000007.png  65
000008.png  144
000009.png  16
!head data/street_code_rec_data/mchar_train.csv
000000.png  19
000001.png  23
000002.png  25
000003.png  93
000004.png  31
000005.png  33
000006.png  28
000007.png  744
000008.png  128
000009.png  16
from PIL import Image

img=Image.open('data/street_code_rec_data/mchar_train/000000.png')
print(img.size)
img
(741, 350)

四、配置训练参数

以PaddleOCR/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml为基准进行配置

1.配置模型网络

使用CRNN算法,backbone是MobileNetV3,损失函数是CTCLoss

Architecture:
  model_type: rec
  algorithm: CRNN
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: small
    small_stride: [1, 2, 2, 2]
  Neck:
    name: SequenceEncoder
    encoder_type: rnn
    hidden_size: 48
  Head:
    name: CTCHead
    fc_decay: 0.00001

2.配置数据

对Train.data_dir, Train.label_file_list, Eval.data_dir, Eval.label_file_list进行配置

Train:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/data/street_code_rec_data/mchar_train
    label_file_list: ["/home/aistudio/data/street_code_rec_data/mchar_train.csv"]
...
...
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/data/street_code_rec_data/mchar_val
    label_file_list: ["/home/aistudio/data/street_code_rec_data/mchar_val.csv"]

3. 显卡、评估设置

use_gpu、cal_metric_during_train分别是GPU、评估开关

Global:
  use_gpu: false             # true 使用GPU
  .....
  cal_metric_during_train: False   # true 打开评估

4. 多线程任务

Train.loader.num_workers:4
Eval.loader.num_workers: 4

5.完整配置

Global:
  use_gpu: True
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_en_number_lite
  save_epoch_step: 3
  # evaluation is run every 5000 iterations after the 4000th iteration
  eval_batch_step: [1000, 100]
  # if pretrained_model is saved in static mode, load_static_weights must set to True
  cal_metric_during_train: True
  pretrained_model: ./en_number_mobile_v2.0_rec_train/best_accuracy.pdparams
  checkpoints: 
  save_inference_dir:
  use_visualdl: False
  infer_img:
  # for data or label process
  character_dict_path: ppocr/utils/en_dict.txt
  max_text_length: 25
  infer_mode: False
  use_space_char: True


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0.00001

Architecture:
  model_type: rec
  algorithm: CRNN
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: small
    small_stride: [1, 2, 2, 2]
  Neck:
    name: SequenceEncoder
    encoder_type: rnn
    hidden_size: 48
  Head:
    name: CTCHead
    fc_decay: 0.00001

Loss:
  name: CTCLoss

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/data/street_code_rec_data/mchar_train
    label_file_list: ["/home/aistudio/data/street_code_rec_data/mchar_train.csv"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RecAug: 
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 320]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 256
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /home/aistudio/data/street_code_rec_data/mchar_val
    label_file_list: ["/home/aistudio/data/street_code_rec_data/mchar_val.csv"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 320]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 8

6.使用预训练模型

据悉使用预训练模型,训练速度更快!!!

PaddleOCR提供的可下载模型包括推理模型训练模型预训练模型slim模型,模型区别说明如下:

模型类型 模型格式 简介
推理模型 inference.pdmodel、inference.pdiparams 用于预测引擎推理,详情
训练模型、预训练模型 *.pdparams、*.pdopt、*.states 训练过程中保存的模型的参数、优化器状态和训练中间信息,多用于模型指标评估和恢复训练
slim模型 *.nb 经过飞桨模型压缩工具PaddleSlim压缩后的模型,适用于移动端/IoT端等端侧部署场景(需使用飞桨Paddle Lite部署)。

各个模型的关系如下面的示意图所示。

文本检测模型

英文识别模型
模型名称 模型简介 配置文件 推理模型大小 下载地址
en_number_mobile_slim_v2.0_rec slim裁剪量化版超轻量模型,支持英文、数字识别 rec_en_number_lite_train.yml 2.7M 推理模型 / 训练模型
en_number_mobile_v2.0_rec 原始超轻量模型,支持英文、数字识别 rec_en_number_lite_train.yml 2.6M 推理模型 / 训练模型

%cd ~/PaddleOCR/
# mobile模型

!wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_train.tar
!tar -xf en_number_mobile_v2.0_rec_train.tar
/home/aistudio/PaddleOCR
--2022-01-02 00:10:41--  https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_train.tar
Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.229, 182.61.200.195, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.229|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9123840 (8.7M) [application/x-tar]
Saving to: ‘en_number_mobile_v2.0_rec_train.tar’

en_number_mobile_v2 100%[===================>]   8.70M  8.63MB/s    in 1.0s    

2022-01-02 00:10:42 (8.63 MB/s) - ‘en_number_mobile_v2.0_rec_train.tar’ saved [9123840/9123840]

五、训练

%cd ~/PaddleOCR/
# mobile模型
!python tools/train.py -c ./configs/rec/multi_language/rec_en_number_lite_train.yml -o Global.checkpoints=./output/rec_en_number_lite/latest

1.选择合适的batch size

2.训练日志

2022/01/02 01:28:23] root INFO: save model in ./output/rec_en_number_lite/latest
[2022/01/02 01:28:23] root INFO: Initialize indexs of datasets:['/home/aistudio/data/street_code_rec_data/mchar_train.csv']
[2022/01/02 01:28:54] root INFO: epoch: [27/500], iter: 180, lr: 0.000986, loss: 1.043328, acc: 0.765624, norm_edit_dis: 0.863509, reader_cost: 2.26051 s, batch_cost: 2.59590 s, samples: 7168, ips: 276.12724
[2022/01/02 01:29:18] root INFO: epoch: [27/500], iter: 190, lr: 0.000986, loss: 1.056450, acc: 0.765624, norm_edit_dis: 0.864510, reader_cost: 1.18228 s, batch_cost: 1.65932 s, samples: 10240, ips: 617.12064
[2022/01/02 01:29:34] root INFO: epoch: [27/500], iter: 200, lr: 0.000985, loss: 1.069025, acc: 0.759277, norm_edit_dis: 0.860254, reader_cost: 0.74316 s, batch_cost: 1.15521 s, samples: 10240, ips: 886.42030
eval model:: 100%|██████████████████████████████| 10/10 [00:07<00:00,  2.12it/s]
[2022/01/02 01:29:42] root INFO: cur metric, acc: 0.6261999373800062, norm_edit_dis: 0.7362716930394972, fps: 4054.7339744968563
[2022/01/02 01:29:42] root INFO: save best model is to ./output/rec_en_number_lite/best_accuracy
[2022/01/02 01:29:42] root INFO: best metric, acc: 0.6261999373800062, start_epoch: 21, norm_edit_dis: 0.7362716930394972, fps: 4054.7339744968563, best_epoch: 27

3. visualdl可视化

  • 本地安装visualdl pip install visualdl
  • 下载日志至本地
  • 启动visualdl可视化 visualdl --logdir ./
  • 打开浏览器查看 http://localhost:8040/

六、模型评估

# GPU 评估, Global.checkpoints 为待测权重
%cd ~/PaddleOCR/
# mobile模型

!python  -m paddle.distributed.launch tools/eval.py -c ./configs/rec/multi_language/rec_en_number_lite_train.yml \
    -o Global.checkpoints=./output/rec_en_number_lite/best_accuracy.pdparams
/home/aistudio/PaddleOCR
-----------  Configuration Arguments -----------
backend: auto
elastic_server: None
force: False
gpus: None
heter_devices: 
heter_worker_num: None
heter_workers: 
host: None
http_port: None
ips: 127.0.0.1
job_id: None
log_dir: log
np: None
nproc_per_node: None
run_mode: None
scale: 0
server_num: None
servers: 
training_script: tools/eval.py
training_script_args: ['-c', './configs/rec/multi_language/rec_en_number_lite_train.yml', '-o', 'Global.checkpoints=./output/rec_en_number_lite/best_accuracy.pdparams']
worker_num: None
workers: 
------------------------------------------------
WARNING 2022-01-02 01:32:26,892 launch.py:423] Not found distinct arguments and compiled with cuda or xpu. Default use collective mode
launch train in GPU mode!
INFO 2022-01-02 01:32:26,894 launch_utils.py:528] Local start 1 processes. First process distributed environment info (Only For Debug): 
    +=======================================================================================+
    |                        Distributed Envs                      Value                    |
    +---------------------------------------------------------------------------------------+
    |                       PADDLE_TRAINER_ID                        0                      |
    |                 PADDLE_CURRENT_ENDPOINT                 127.0.0.1:33420               |
    |                     PADDLE_TRAINERS_NUM                        1                      |
    |                PADDLE_TRAINER_ENDPOINTS                 127.0.0.1:33420               |
    |                     PADDLE_RANK_IN_NODE                        0                      |
    |                 PADDLE_LOCAL_DEVICE_IDS                        0                      |
    |                 PADDLE_WORLD_DEVICE_IDS                        0                      |
    |                     FLAGS_selected_gpus                        0                      |
    |             FLAGS_selected_accelerators                        0                      |
    +=======================================================================================+

INFO 2022-01-02 01:32:26,894 launch_utils.py:532] details abouts PADDLE_TRAINER_ENDPOINTS can be found in log/endpoints.log, and detail running logs maybe found in log/workerlog.0
launch proc_id:1384 idx:0
[2022/01/02 01:32:28] root INFO: Architecture : 
[2022/01/02 01:32:28] root INFO:     Backbone : 
[2022/01/02 01:32:28] root INFO:         model_name : small
[2022/01/02 01:32:28] root INFO:         name : MobileNetV3
[2022/01/02 01:32:28] root INFO:         scale : 0.5
[2022/01/02 01:32:28] root INFO:         small_stride : [1, 2, 2, 2]
[2022/01/02 01:32:28] root INFO:     Head : 
[2022/01/02 01:32:28] root INFO:         fc_decay : 1e-05
[2022/01/02 01:32:28] root INFO:         name : CTCHead
[2022/01/02 01:32:28] root INFO:     Neck : 
[2022/01/02 01:32:28] root INFO:         encoder_type : rnn
[2022/01/02 01:32:28] root INFO:         hidden_size : 48
[2022/01/02 01:32:28] root INFO:         name : SequenceEncoder
[2022/01/02 01:32:28] root INFO:     Transform : None
[2022/01/02 01:32:28] root INFO:     algorithm : CRNN
[2022/01/02 01:32:28] root INFO:     model_type : rec
[2022/01/02 01:32:28] root INFO: Eval : 
[2022/01/02 01:32:28] root INFO:     dataset : 
[2022/01/02 01:32:28] root INFO:         data_dir : /home/aistudio/data/street_code_rec_data/mchar_val
[2022/01/02 01:32:28] root INFO:         label_file_list : ['/home/aistudio/data/street_code_rec_data/mchar_val.csv']
[2022/01/02 01:32:28] root INFO:         name : SimpleDataSet
[2022/01/02 01:32:28] root INFO:         transforms : 
[2022/01/02 01:32:28] root INFO:             DecodeImage : 
[2022/01/02 01:32:28] root INFO:                 channel_first : False
[2022/01/02 01:32:28] root INFO:                 img_mode : BGR
[2022/01/02 01:32:28] root INFO:             CTCLabelEncode : None
[2022/01/02 01:32:28] root INFO:             RecResizeImg : 
[2022/01/02 01:32:28] root INFO:                 image_shape : [3, 32, 320]
[2022/01/02 01:32:28] root INFO:             KeepKeys : 
[2022/01/02 01:32:28] root INFO:                 keep_keys : ['image', 'label', 'length']
[2022/01/02 01:32:28] root INFO:     loader : 
[2022/01/02 01:32:28] root INFO:         batch_size_per_card : 1024
[2022/01/02 01:32:28] root INFO:         drop_last : False
[2022/01/02 01:32:28] root INFO:         num_workers : 8
[2022/01/02 01:32:28] root INFO:         shuffle : False
[2022/01/02 01:32:28] root INFO: Global : 
[2022/01/02 01:32:28] root INFO:     cal_metric_during_train : True
[2022/01/02 01:32:28] root INFO:     character_dict_path : ppocr/utils/en_dict.txt
[2022/01/02 01:32:28] root INFO:     checkpoints : ./output/rec_en_number_lite/best_accuracy.pdparams
[2022/01/02 01:32:28] root INFO:     debug : False
[2022/01/02 01:32:28] root INFO:     distributed : False
[2022/01/02 01:32:28] root INFO:     epoch_num : 500
[2022/01/02 01:32:28] root INFO:     eval_batch_step : [100, 100]
[2022/01/02 01:32:28] root INFO:     infer_img : None
[2022/01/02 01:32:28] root INFO:     infer_mode : False
[2022/01/02 01:32:28] root INFO:     log_smooth_window : 20
[2022/01/02 01:32:28] root INFO:     max_text_length : 25
[2022/01/02 01:32:28] root INFO:     pretrained_model : ./en_number_mobile_v2.0_rec_train/best_accuracy.pdparams
[2022/01/02 01:32:28] root INFO:     print_batch_step : 10
[2022/01/02 01:32:28] root INFO:     save_epoch_step : 3
[2022/01/02 01:32:28] root INFO:     save_inference_dir : None
[2022/01/02 01:32:28] root INFO:     save_model_dir : ./output/rec_en_number_lite
[2022/01/02 01:32:28] root INFO:     use_gpu : True
[2022/01/02 01:32:28] root INFO:     use_space_char : True
[2022/01/02 01:32:28] root INFO:     use_visualdl : False
[2022/01/02 01:32:28] root INFO: Loss : 
[2022/01/02 01:32:28] root INFO:     name : CTCLoss
[2022/01/02 01:32:28] root INFO: Metric : 
[2022/01/02 01:32:28] root INFO:     main_indicator : acc
[2022/01/02 01:32:28] root INFO:     name : RecMetric
[2022/01/02 01:32:28] root INFO: Optimizer : 
[2022/01/02 01:32:28] root INFO:     beta1 : 0.9
[2022/01/02 01:32:28] root INFO:     beta2 : 0.999
[2022/01/02 01:32:28] root INFO:     lr : 
[2022/01/02 01:32:28] root INFO:         learning_rate : 0.001
[2022/01/02 01:32:28] root INFO:         name : Cosine
[2022/01/02 01:32:28] root INFO:     name : Adam
[2022/01/02 01:32:28] root INFO:     regularizer : 
[2022/01/02 01:32:28] root INFO:         factor : 1e-05
[2022/01/02 01:32:28] root INFO:         name : L2
[2022/01/02 01:32:28] root INFO: PostProcess : 
[2022/01/02 01:32:28] root INFO:     name : CTCLabelDecode
[2022/01/02 01:32:28] root INFO: Train : 
[2022/01/02 01:32:28] root INFO:     dataset : 
[2022/01/02 01:32:28] root INFO:         data_dir : /home/aistudio/data/street_code_rec_data/mchar_train
[2022/01/02 01:32:28] root INFO:         label_file_list : ['/home/aistudio/data/street_code_rec_data/mchar_train.csv']
[2022/01/02 01:32:28] root INFO:         name : SimpleDataSet
[2022/01/02 01:32:28] root INFO:         transforms : 
[2022/01/02 01:32:28] root INFO:             DecodeImage : 
[2022/01/02 01:32:28] root INFO:                 channel_first : False
[2022/01/02 01:32:28] root INFO:                 img_mode : BGR
[2022/01/02 01:32:28] root INFO:             RecAug : None
[2022/01/02 01:32:28] root INFO:             CTCLabelEncode : None
[2022/01/02 01:32:28] root INFO:             RecResizeImg : 
[2022/01/02 01:32:28] root INFO:                 image_shape : [3, 32, 320]
[2022/01/02 01:32:28] root INFO:             KeepKeys : 
[2022/01/02 01:32:28] root INFO:                 keep_keys : ['image', 'label', 'length']
[2022/01/02 01:32:28] root INFO:     loader : 
[2022/01/02 01:32:28] root INFO:         batch_size_per_card : 1024
[2022/01/02 01:32:28] root INFO:         drop_last : True
[2022/01/02 01:32:28] root INFO:         num_workers : 8
[2022/01/02 01:32:28] root INFO:         shuffle : True
[2022/01/02 01:32:28] root INFO: profiler_options : None
[2022/01/02 01:32:28] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)
[2022/01/02 01:32:28] root INFO: Initialize indexs of datasets:['/home/aistudio/data/street_code_rec_data/mchar_val.csv']
W0102 01:32:28.580307  1384 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0102 01:32:28.584791  1384 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2022/01/02 01:32:33] root INFO: resume from ./output/rec_en_number_lite/best_accuracy
[2022/01/02 01:32:33] root INFO: metric in ckpt ***************
[2022/01/02 01:32:33] root INFO: acc:0.6261999373800062
[2022/01/02 01:32:33] root INFO: start_epoch:28
[2022/01/02 01:32:33] root INFO: norm_edit_dis:0.7362716930394972
[2022/01/02 01:32:33] root INFO: fps:4054.7339744968563
[2022/01/02 01:32:33] root INFO: best_epoch:27

eval model::   0%|          | 0/10 [00:00

七、结果预测

预测脚本使用预测训练好的模型,并将结果保存成txt格式,可以直接送到比赛提交入口测评,文件默认保存在output/rec/predicts_chinese_lite_v2.0.txt

1.提交内容与格式

本次比赛要求参赛选手必须提交使用深度学习平台飞桨(PaddlePaddle)训练的模型。参赛者要求以.txt 文本格式提交结果,其中每一行是图片名称和文字预测的结果,中间以 “\t” 作为分割符,示例如下:

new_name value
0.jpg 文本0

2. infer_rec.py修改


    with open(save_res_path, "w") as fout:
        # 添加列头
        fout.write('file_name' + "," + 'file_code' +'\n')
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            if config['Architecture']['algorithm'] == "SRN":
                encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
                gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
                gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
                gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)

                others = [
                    paddle.to_tensor(encoder_word_pos_list),
                    paddle.to_tensor(gsrm_word_pos_list),
                    paddle.to_tensor(gsrm_slf_attn_bias1_list),
                    paddle.to_tensor(gsrm_slf_attn_bias2_list)
                ]
            if config['Architecture']['algorithm'] == "SAR":
                valid_ratio = np.expand_dims(batch[-1], axis=0)
                img_metas = [paddle.to_tensor(valid_ratio)]

            images = np.expand_dims(batch[0], axis=0)
            images = paddle.to_tensor(images)
            if config['Architecture']['algorithm'] == "SRN":
                preds = model(images, others)
            elif config['Architecture']['algorithm'] == "SAR":
                preds = model(images, img_metas)
            else:
                preds = model(images)
            post_result = post_process_class(preds)
            info = None
            if isinstance(post_result, dict):
                rec_info = dict()
                for key in post_result:
                    if len(post_result[key][0]) >= 2:
                        rec_info[key] = {
                            "label": post_result[key][0][0],
                            "score": float(post_result[key][0][1]),
                        }
                info = json.dumps(rec_info)
            else:
                if len(post_result[0]) >= 2:
                    info = post_result[0][0] + "\t" + str(post_result[0][1])

            if info is not None:
                logger.info("\t result: {}".format(info))
                fout.write(file + "," +  post_result[0][0] +'\n')
    logger.info("success!")
%cd ~/PaddleOCR/
# mobile模型

!python tools/infer_rec.py -c configs/rec/multi_language/rec_en_number_lite_train.yml \
    -o Global.infer_img="/home/aistudio/data/street_code_rec_data/mchar_test_a" \
    Global.checkpoints=./output/rec_en_number_lite/best_accuracy.pdparams

预测日志

[2022/01/02 02:01:08] root INFO:     result: 2123   0.9544541
[2022/01/02 02:01:08] root INFO: infer_img: /home/aistudio/data/street_code_rec_data/mchar_test_a/039996.png
[2022/01/02 02:01:08] root INFO:     result: 341    0.8990403
[2022/01/02 02:01:08] root INFO: infer_img: /home/aistudio/data/street_code_rec_data/mchar_test_a/039997.png
[2022/01/02 02:01:08] root INFO:     result: 167    0.95185596
[2022/01/02 02:01:08] root INFO: infer_img: /home/aistudio/data/street_code_rec_data/mchar_test_a/039998.png
[2022/01/02 02:01:08] root INFO:     result: 235    0.9978804
[2022/01/02 02:01:08] root INFO: infer_img: /home/aistudio/data/street_code_rec_data/mchar_test_a/039999.png
[2022/01/02 02:01:08] root INFO:     result: 910    0.93325263
[2022/01/02 02:01:08] root INFO: success!
...
...

八、基于预测引擎的预测

1.模型大小限制

约束性条件1:模型总大小不超过10MB(以.pdmodel和.pdiparams文件非压缩状态磁盘占用空间之和为准);

2.解决办法

训练过程中保存的模型是checkpoints模型,保存的只有模型的参数,多用于恢复训练等。实际上,此处的约束条件限制的是inference 模型的大小。inference 模型一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成,模型大小也会小一些。

# 静态模型导出
%cd ~/PaddleOCR/
# mobile模型

!python tools/export_model.py -c   configs/rec/multi_language/rec_en_number_lite_train.yml \
    -o Global.checkpoints=./output/rec_en_number_lite/best_accuracy.pdparams \
    Global.save_inference_dir=./inference/rec_inference/
/home/aistudio/PaddleOCR
W0102 02:06:39.026404  4766 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0102 02:06:39.030951  4766 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2022/01/02 02:06:43] root INFO: resume from ./output/rec_en_number_lite/best_accuracy
[2022/01/02 02:06:45] root INFO: inference model is saved to ./inference/rec_inference/inference
%cd ~/PaddleOCR/
!du -sh ./inference/rec_inference/
/home/aistudio/PaddleOCR
2.8M    ./inference/rec_inference/
  • 可以看到,当前训练使用的CRNN算法导出inference后,仅有2.8M。
  • 导出的inference模型也可以用来预测,预测逻辑如下代码所示。
# 使用导出静态模型预测
%cd ~/PaddleOCR/
!python3.7 tools/infer/predict_rec.py  --rec_model_dir=./inference/rec_inference/  --image_dir="/home/aistudio/data/street_code_rec_data/mchar_test_a"

预测日志

[2022/01/02 02:08:37] root INFO: Predicts of /home/aistudio/data/street_code_rec_data/mchar_test_a/012500.png:('疗绚娇', 0.71012855)
[2022/01/02 02:08:37] root INFO: Predicts of /home/aistudio/data/street_code_rec_data/mchar_test_a/012501.png:('绚诚', 0.9246478)
[2022/01/02 02:08:37] root INFO: Predicts of /home/aistudio/data/street_code_rec_data/mchar_test_a/012502.png:('溜', 0.93994504)
[2022/01/02 02:08:37] root INFO: Predicts of /home/aistudio/data/street_code_rec_data/mchar_test_a/012503.png:('诚溜', 0.95832443)
[2022/01/02 02:08:37] root INFO: Predicts of /home/aistudio/data/street_code_rec_data/mchar_test_a/012504.png:('溜溜', 0.87103844)
[2022/01/02 02:08:37] root INFO: Predicts of /home/aistudio/data/street_code_rec_data/mchar_test_a/012505.png:('贿', 0.34199885)
[2022/01/02 02:08:37] root INFO: Predicts of /home/aistudio/data/street_code_rec_data/mchar_test_a/012506.png:('题', 0.9996681)
[2022/01/02 02:08:37] root INFO: Predicts of /home/aistudio/data/street_code_rec_data/mchar_test_a/012507.png:('绚绚', 0.9908391)
[2022/01/02 02:08:37] root INFO: Predicts of /home/aistudio/data/street_code_rec_data/mchar_test_a/012508.png:('绚', 0.58176464)
...
...

九、提交

预测结果保存到配置文件指定的 output/rec/predicts_chinese_lite_v2.0.txt文件,可直接提交即可。

%cd ~
!head PaddleOCR/output/rec/predicts_rec.txt
/home/aistudio
file_name,file_code
/home/aistudio/data/street_code_rec_data/mchar_test_a/000000.png,59
/home/aistudio/data/street_code_rec_data/mchar_test_a/000001.png,290
/home/aistudio/data/street_code_rec_data/mchar_test_a/000002.png,113
/home/aistudio/data/street_code_rec_data/mchar_test_a/000003.png,97
/home/aistudio/data/street_code_rec_data/mchar_test_a/000004.png,63
/home/aistudio/data/street_code_rec_data/mchar_test_a/000005.png,39
/home/aistudio/data/street_code_rec_data/mchar_test_a/000006.png,126
/home/aistudio/data/street_code_rec_data/mchar_test_a/000007.png,1475
/home/aistudio/data/street_code_rec_data/mchar_test_a/000008.png,48

随便跑跑82分,大家可以再处理处理,把检测数据也用上,优化优化,多跑几轮,一定可以取得更好的成绩。

Ai Studio项目地址: 基于PaddleOCR2.4的天池街景字符编码识别Baseline


推荐阅读
author-avatar
8090互助联盟
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有