热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于attention机制实现CRNNOCR文字识别

定义网络结构实现BahdanauAttention,其中socre的实现方法为perceptron形式classBahdanauAttention(tf.keras

定义网络结构

实现 BahdanauAttention,其中socre的实现方法为 perceptron 形式

class BahdanauAttention(tf.keras.Model):def __init__(self, units):super(BahdanauAttention, self).__init__()self.W1 = tf.keras.layers.Dense(units)self.W2 = tf.keras.layers.Dense(units)self.V = tf.keras.layers.Dense(1)def call(self, features, hidden):# feature 为encoder 生成的source编码矩阵 , hidden为 i-1 时刻的隐元状态hidden_with_time_axis = tf.expand_dims(hidden, 1)# score shape == (batch_size, output length, hidden_size)score = self.V(tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis)))# we get 1 at the last axis because we are applying score to self.Vattention_weights = tf.nn.softmax(score, axis=1)# context_vector shape after sum == (batch_size, hidden_size)context_vector = attention_weights * featurescontext_vector = tf.reduce_sum(context_vector, axis=1)return context_vector, attention_weights

定义GRU单元

def gru(units):if tf.test.is_gpu_available():return tf.keras.layers.CuDNNGRU(units,return_sequences=True,return_state=True,recurrent_initializer='glorot_uniform')else:return tf.keras.layers.GRU(units,return_sequences=True,return_state=True,recurrent_activation='sigmoid',recurrent_initializer='glorot_uniform')

使用CRNN feature 提取层 和 单层GRU生成编码器Encoder

class Encoder(tf.keras.Model):"""enc_units: encoder 隐元数量batch_sz: batch size"""def __init__(self, enc_units, batch_sz):super(Encoder, self).__init__()self.batch_sz = batch_szself.enc_units = enc_unitsself.cnn = tf.keras.Sequential([tf.keras.layers.Conv2D(64, [3, 3], padding="same", activation='relu'),tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2),tf.keras.layers.Conv2D(128, [3, 3], padding="same", activation='relu'),tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2),tf.keras.layers.Conv2D(256, [3, 3], padding="same", activation='relu'),tf.keras.layers.Conv2D(256, [3, 3], padding="same", activation='relu'),tf.keras.layers.MaxPool2D(pool_size=[2, 1], strides=[2, 1]),tf.keras.layers.Conv2D(512, [3, 3], padding="same", activation='relu'),tf.keras.layers.BatchNormalization(),tf.keras.layers.Conv2D(512, [3, 3], padding="same", activation='relu'),tf.keras.layers.BatchNormalization(),tf.keras.layers.MaxPool2D(pool_size=[2, 1], strides=[2, 1]),tf.keras.layers.Conv2D(512, [2, 2], strides=[2, 1], padding="same", activation='relu'),tf.keras.layers.Reshape((25, 512))])self.gru = gru(self.enc_units)def call(self, x):x = self.cnn(x)output, state = self.gru(x)return output, statedef initialize_hidden_state(self):return tf.zeros((self.batch_sz, self.enc_units))

定义 attention 机制和 GRU 单元的解码器 Decoder

class Decoder(tf.keras.Model):def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):super(Decoder, self).__init__()self.batch_sz = batch_szself.dec_units = dec_unitsself.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)self.gru = gru(self.dec_units)self.fc = tf.keras.layers.Dense(vocab_size)self.attention = BahdanauAttention(self.dec_units)def call(self, x, hidden, enc_output):context_vector, attention_weights = self.attention(enc_output, hidden)# x shape after passing through embedding == (batch_size, 1, embedding_dim)x1 = self.embedding(x)# x shape after cOncatenation== (batch_size, 1, embedding_dim + hidden_size)x2 = tf.concat([tf.expand_dims(context_vector, 1), x1], axis=-1)# passing the concatenated vector to the GRUoutput, state = self.gru(x2)# output shape == (batch_size * 1, hidden_size)output = tf.reshape(output, (-1, output.shape[2]))# output shape == (batch_size * 1, vocab)x = self.fc(output)return x, state, attention_weights

准备数据

数据集采用mjsynth.tar.gz,这个数据集有些问题,某些样本大小写未分开标注,某些样本颜色梯度不够,可以先训练一个模型后对数据集做筛选,然后再fine tuen.

定义字典

# 将每个词汇映射为一个数字
class LanguageIndex():def __init__(self):self.word2idx = {}self.idx2word = {}self.vocab = cfg.CHAR_VECTORself.create_index()def create_index(self):self.word2idx['

'] = 0self.word2idx[''] = 1self.word2idx[''] = 2self.word2idx[''] = 3for index, word in enumerate(self.vocab):self.word2idx[word] = index + 4for word, index in self.word2idx.items():self.idx2word[index] = word

处理 label 为 .. 格式

root = "../mnt/ramdisk/max/90kDICT32px"def create_dataset_from_file(root, file_path):with open(file_path, "r") as f:readlines = f.readlines()img_paths = []for img_name in tqdm(readlines, desc="read dir:"):img_name = img_name.rstrip().strip()img_path = root + "/" + img_nameif osp.exists(img_path):img_paths.append(img_path)img_paths = img_paths[:1000000]labels = [img_path.split("/")[-1].split("_")[-2] for img_path in tqdm(img_paths, desc="generator label:")]return img_paths, labelsdef preprocess_label(label):label = label.rstrip().strip()w = ' 'for i in label:w += i + ' 'w += ' 'return wdef load_dataset(root):img_paths_tensor, labels = create_dataset_from_file(root, root + "/annotation_train.txt")labels = [label for label in labels]processed_labels = [preprocess_label(label) for label in tqdm(labels, desc="process label:")]label_lang = LanguageIndex(label for label in processed_labels)labels_tensor = [[label_lang.word2idx[s] for s in label.split(' ')] for label in processed_labels]label_max_len = max_length(labels_tensor)labels_tensor = tf.keras.preprocessing.sequence.pad_sequences(labels_tensor, maxlen=label_max_len, padding='post')return img_paths_tensor, labels_tensor, labels, label_lang, label_max_len

构建数据 dataset

def process_img(img_path):imread = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)imread = resize_image(imread, 100, 32)imread = np.expand_dims(imread, axis=-1)imread = np.array(imread, np.float32)return imreaddef resize_image(image, out_width, out_height):"""Resize an image to the "good" input size"""im_arr = imageh, w = np.shape(im_arr)[:2]ratio = out_height / him_arr_resized = cv2.resize(im_arr, (int(w * ratio), out_height))re_h, re_w = np.shape(im_arr_resized)[:2]if re_w >= out_width:final_arr = cv2.resize(im_arr, (out_width, out_height))else:final_arr = np.ones((out_height, out_width), dtype=np.uint8) * 255final_arr[:, 0:np.shape(im_arr_resized)[1]] = im_arr_resizedreturn final_arrimg_paths_tensor, labels_tensor, labels, label_lang, label_max_len = load_dataset(root)BATCH_SIZE = cfg.TRAIN_BATCH_SIZE
N_BATCH = len(img_paths_tensor) // BATCH_SIZE
embedding_dim = cfg.EMBEDDING_DIM
units = cfg.UNITSvocab_size = len(label_lang.word2idx)def map_func(img_path_tensor, label_tensor, label):imread = cv2.imread(img_path_tensor.decode('utf-8'), cv2.IMREAD_GRAYSCALE)imread = resize_image(imread, 100, 32)imread = np.expand_dims(imread, axis=-1)imread = np.array(imread, np.float32)return imread, label_tensor, labeldataset = tf.data.Dataset.from_tensor_slices((img_paths_tensor, labels_tensor, labels)) \.map(lambda item1, item2, item3: tf.py_func(map_func, [item1, item2, item3], [tf.float32, tf.int32, tf.string]),num_parallel_calls=8) \.shuffle(10000, reshuffle_each_iteration=True)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

定义Encoder、Decoder和Optimizer ,loss函数

encoder = Encoder(units, BATCH_SIZE)
decoder = Decoder(vocab_size, embedding_dim, units, BATCH_SIZE)optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)def loss_function(real, pred):mask = 1 - np.equal(real, 0)loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * maskreturn tf.reduce_mean(loss_)

开启训练

checkpoint_dir = './checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, encoder=encoder, decoder=decoder)EPOCHS = 100for epoch in range(EPOCHS):start = time.time()total_loss = 0for (batch, (inp, targ, ground_truths)) in enumerate(dataset):loss = 0results = np.zeros((BATCH_SIZE, targ.shape[1] - 1), np.int32)with tf.GradientTape() as tape:enc_output, enc_hidden = encoder(inp)dec_hidden = enc_hiddendec_input = tf.expand_dims([label_lang.word2idx['']] * BATCH_SIZE, 1)# Teacher forcing - feeding the target as the next inputfor t in range(1, targ.shape[1]):# passing enc_output to the decoderpredictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)predicted_id = tf.argmax(predictions, axis=-1).numpy()results[:, t - 1] = predicted_id# result = [result[i] + label_lang.idx2word[predicted_id[i]] for i in range(BATCH_SIZE)]loss += loss_function(targ[:, t], predictions)# using teacher forcingdec_input = tf.expand_dims(targ[:, t], 1)batch_loss = (loss / int(targ.shape[1]))total_loss += batch_lossvariables = encoder.variables + decoder.variablesgradients = tape.gradient(loss, variables)optimizer.apply_gradients(zip(gradients, variables))preds = [process_result(result, label_lang) for result in results]ground_truths = [l.numpy().decode() for l in ground_truths]acc = compute_accuracy(ground_truths, preds)if batch % 1 == 0:print('Epoch {} Batch {} Loss {:.4f} Mean Loss {:.4f} acc {:f}'.format(epoch + 1, batch,batch_loss.numpy(),total_loss / (batch + 1),acc))if batch % 10 == 0:for i in range(5):print("real:{:s} pred:{:s} acc:{:f}".format(ground_truths[i], preds[i],compute_accuracy([ground_truths[i]], [preds[i]])))# saving (checkpoint) the model every 2 epochsif (epoch + 1) % 2 == 0:checkpoint.save(file_prefix=checkpoint_prefix)print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

测试代码

import osfrom config import cfg
from lang_dict.lang import LanguageIndex
from net.net import *
from utils.img_utils import *os.environ["CUDA_VISIBLE_DEVICES"] = "1"label_lang = LanguageIndex()
vocab_size = len(label_lang.word2idx)BATCH_SIZE = 1
embedding_dim = cfg.EMBEDDING_DIM
units = cfg.UNITSencoder = Encoder(units, BATCH_SIZE)
decoder = Decoder(vocab_size, embedding_dim, units, BATCH_SIZE)checkpoint_dir = './checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(encoder=encoder, decoder=decoder)checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))def evaluate(encoder, decoder, img_path, label_lang):img = process_img(img_path)enc_output, enc_hidden = encoder(np.expand_dims(img, axis=0))dec_hidden = enc_hiddendec_input = tf.expand_dims([label_lang.word2idx['']] * BATCH_SIZE, 1)results = np.zeros((BATCH_SIZE, 25), np.int32)for t in range(1, 25):# passing enc_output to the decoderpredictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)predicted_id = tf.argmax(predictions, axis=-1).numpy()results[:, t - 1] = predicted_iddec_input = tf.expand_dims(predicted_id, 1)preds = [process_result(result, label_lang) for result in results]print("pred :" + preds[0])img_path = "./sample/1_bridleway_9530.jpg"evaluate(encoder=encoder, decoder=decoder, img_path=img_path, label_lang=label_lang)

添加attention后,crnn收敛非常迅速,基本一个epoch就能基本收敛

 

全部代码


推荐阅读
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • 本文介绍了南邮ctf-web的writeup,包括签到题和md5 collision。在CTF比赛和渗透测试中,可以通过查看源代码、代码注释、页面隐藏元素、超链接和HTTP响应头部来寻找flag或提示信息。利用PHP弱类型,可以发现md5('QNKCDZO')='0e830400451993494058024219903391'和md5('240610708')='0e462097431906509019562988736854'。 ... [详细]
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • Spring特性实现接口多类的动态调用详解
    本文详细介绍了如何使用Spring特性实现接口多类的动态调用。通过对Spring IoC容器的基础类BeanFactory和ApplicationContext的介绍,以及getBeansOfType方法的应用,解决了在实际工作中遇到的接口及多个实现类的问题。同时,文章还提到了SPI使用的不便之处,并介绍了借助ApplicationContext实现需求的方法。阅读本文,你将了解到Spring特性的实现原理和实际应用方式。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • XML介绍与使用的概述及标签规则
    本文介绍了XML的基本概念和用途,包括XML的可扩展性和标签的自定义特性。同时还详细解释了XML标签的规则,包括标签的尖括号和合法标识符的组成,标签必须成对出现的原则以及特殊标签的使用方法。通过本文的阅读,读者可以对XML的基本知识有一个全面的了解。 ... [详细]
  • 本文介绍了一个在线急等问题解决方法,即如何统计数据库中某个字段下的所有数据,并将结果显示在文本框里。作者提到了自己是一个菜鸟,希望能够得到帮助。作者使用的是ACCESS数据库,并且给出了一个例子,希望得到的结果是560。作者还提到自己已经尝试了使用"select sum(字段2) from 表名"的语句,得到的结果是650,但不知道如何得到560。希望能够得到解决方案。 ... [详细]
  • 本文详细介绍了Spring的JdbcTemplate的使用方法,包括执行存储过程、存储函数的call()方法,执行任何SQL语句的execute()方法,单个更新和批量更新的update()和batchUpdate()方法,以及单查和列表查询的query()和queryForXXX()方法。提供了经过测试的API供使用。 ... [详细]
  • springmvc学习笔记(十):控制器业务方法中通过注解实现封装Javabean接收表单提交的数据
    本文介绍了在springmvc学习笔记系列的第十篇中,控制器的业务方法中如何通过注解实现封装Javabean来接收表单提交的数据。同时还讨论了当有多个注册表单且字段完全相同时,如何将其交给同一个控制器处理。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文介绍了Android 7的学习笔记总结,包括最新的移动架构视频、大厂安卓面试真题和项目实战源码讲义。同时还分享了开源的完整内容,并提醒读者在使用FileProvider适配时要注意不同模块的AndroidManfiest.xml中配置的xml文件名必须不同,否则会出现问题。 ... [详细]
author-avatar
拍友2502878393
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有