热门标签 | HotTags
当前位置:  开发笔记 > 人工智能 > 正文

cnn识别不定长英文验证码

本程序作为一个基础版的CNN使用教程,以识别简单的英文验证码作为目标完成一个简单的实例。在这个实例中,我们会涉及到以下三步,并通过这三部曲

       本程序作为一个基础版的CNN使用教程,以识别简单的英文验证码作为目标完成一个简单的实例。在这个实例中,我们会涉及到以下三步,并通过这三部曲来带大家体验深度学习的魅力。

仅2分,可直接运行的程序:https://download.csdn.net/download/qq_32791307/10651274

     1. 对数据的基本读取、处理并生成batch 

     2. 搭建简单的CNN模型 

     3. 训练并得出结果

          2.  其次有了样本之后,我们需要对样本进行以下的操作。

          第一步:获取图片和图片名称的代码,非常简单,注释如下。

def get_image_name():#获得目录下所有图片all_image = os.listdir('D:/tensorflow3_captcha/train/')#打散训练集random_file = random.randint(0,8000) #拓展名获取, eg: 123.txt 分为[0] 123 ,[1] txtbase = os.path.basename('D:/tensorflow_captcha/train/' + all_image[random_file])name = os.path.splitext(base)[0]#读取图片并转为灰度图image = cv2.imread('D:/tensorflow_captcha/train/' + all_image[random_file],cv2.IMREAD_GRAYSCALE)
# cv2.imshow("tu",image)
# cv2.waitKey(0)
# print(name)return name,image

            第二步:确定字母次序为0-9a-zA-Z得到字母对应位置标签. 如 a 对应标签为 11

#ord(c)随机给定一个无符号数,ord('0')=48
#字符固定位置
def char2pos(c): if c =='_':k=62return kk= ord(c)-48# ord(c)为0-9时的位置
#ord(c)为大写字母时的位置,因为0-9占了10位,所以是ord(c)-ord('A')+10=ord(c)-55if k > 9: #如ord('C')=67-55=12k = ord(c) - 55
#ord(c)为小写字母时的位置,0-z占了35位 if k > 35: #ord('c')=99-61=38k = ord(c) - 61if k > 61:raise ValueError('No Map')return k

            第三步:名字标签与one-hot编码的形式互相转换。(注:1,2,3对应one-hot示例 [ [0 ,1 ,0 ,0], [0 ,0 ,1 ,0], [0 ,0 ,0 ,1]])

#名字转向量标签
def name2vec(name):vector = np.zeros(max_captcha*char_set) for i, c in enumerate(name): idx = i * char_set + char2pos(c)vector[idx] = 1return vector #return 不能忘!!导致loss = nan
# print (vector);print(idx)#向量标签转名字,为了结果可以对应,也可以名字向量互转用字典代替。
def vec2name(vec):char_pos = vec.nonzero()[0]name=[]for i, c in enumerate(char_pos):char_idx = c % char_setif char_idx <10:char_code = char_idx + ord(&#39;0&#39;)elif char_idx <36:char_code = char_idx - 10 + ord(&#39;A&#39;)elif char_idx <62:char_code = char_idx- 36 + ord(&#39;a&#39;)elif char_idx == 62:char_code = ord(&#39;_&#39;)else:raise ValueError(&#39;error&#39;)name.append(chr(char_code))return "".join(name)

            第四步:制作batch。(注:batch_size 一般大小为2的n次方,常用大小为32,64)

def get_next_batch(batch_size):batch_x = np.zeros([batch_size,image_height*image_width])batch_y = np.zeros([batch_size,max_captcha*char_set])for i in range(batch_size): #格式注意range,np.zeros([])name,image = get_image_name()batch_x[i,:] = image.flatten()/255batch_y[i,:] = name2vec(name)return batch_x,batch_y

二. 搭建简单的CNN模型

        通常,我会将搭建过程分为前后两期,前期定义输入变量,定义weight、bias 、conv2d、max_pool等所需要的‘积木块’。后期开始堆积积木块。各大著名网络结构可以看为积木搭建的图纸,只要我们有安装图纸,前期只要定义好各类所需的‘积木块’,后期就可以按图索骥就好。本次只是简单的堆积了一个模型,大家也可以去体验更多的网络结构。下面,我们开始吧!

 

        第一步:定义各类变量

X = tf.placeholder(tf.float32,[None,image_height*image_width])
Y = tf.placeholder(tf.float32,[None,max_captcha*char_set])
y_in = tf.sparse_placeholder(tf.int32)
seq_length=tf.placeholder(tf.int32,[None])
keep_prob = tf.placeholder(tf.float32) # dropout#这里的learn_rate比较详细,不麻烦的话也可以直接用learn_rate=0.0001作为学习速率
global_step = tf.Variable(0, trainable=False)
learn_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE,global_step,DECAY_STEPS,LEARNING_RATE_DECAY_FACTOR,staircase=True)#define weight biaes
def weight(shape):#tf.truncated_normal(shape, mean, stddev)initial = 0.01*tf.random_normal(shape)return tf.Variable(initial)def biases(shape):initial = 0.1*tf.random_normal(shape) #tf.constant(0.1,shape=shape)return tf.Variable(initial)#define conv and pool ,padding有&#39;SAME&#39;和&#39;VALID&#39;两种
def conv2d(x,w):return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding=&#39;SAME&#39;)
def max_pool(x):#ksize是pool的窗口大小=[1,height,width,1]也就是卷积核大小;return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding=&#39;SAME&#39;)#因为是2x2的池化,所以输出是池化一次就/2 ,和权重的卷积取多少无关

         第二步:积木堆积,本次所用图片大小为80*170, 所以第一次的卷积输出为40*85

模仿的VGG16

def cnn(weight,biases): img_x = tf.reshape(X,[-1,image_height,image_width,1])#### conv layer 1 ####w_conv1 = weight([3,3,1,32]) #3X3的卷积核,第一层输出32b_conv1 = biases([32])h_conv1 = tf.nn.relu(conv2d(img_x,w_conv1)+b_conv1)h_pool1 = max_pool(h_conv1) #outsize = 40*85 conv1 = tf.nn.dropout(h_pool1, keep_prob)#### conv layer 2 #### w_conv2 = weight([3,3,32,64]) #3X3的卷积核,第2层输出64b_conv2 = biases([64])h_conv2 = tf.nn.relu(conv2d(conv1,w_conv2)+b_conv2)h_pool2 = max_pool(h_conv2) #20,42.5 conv2 = tf.nn.dropout(h_pool2, keep_prob) #### conv layer 3 ####w_conv3 = weight([3,3,64,64]) #3X3的卷积核,第2层输出64b_conv3 = biases([64])h_conv3 = tf.nn.relu(conv2d(conv2,w_conv3)+b_conv3)h_pool3 = max_pool(h_conv3) #10, 21.25--->10,22 conv3 = tf.nn.dropout(h_pool3, keep_prob) ## ## Fully connected layer 1 ####w_fc1 = weight([10*22*64,1024])b_fc1 = biases([1024])
# h_pool3_flat = tf.reshape(h_pool3,[-1,10*22*64])h_pool3_flat = tf.reshape(conv3, [-1, w_fc1.get_shape().as_list()[0]])h_fc1 = tf.nn.relu(tf.matmul(h_pool3_flat,w_fc1)+b_fc1)#[64,1024]h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob)## ## out 输出特征 ####w_out = weight([1024,max_captcha*char_set])b_out = biases([max_captcha*char_set])#out用的sigmoid,loss对应sigmoid#out = [64,1024]*[1024,378]+378 =[64,378]out = tf.matmul(h_fc1_drop,w_out)+b_out #sigmoid ,注意行列顺序,HXWreturn out , h_fc1

三. 训练并验证

         1. 开始训练前容易犯的错误是没有初始化变量。

         2.sess.close不加也可以,但是加上是个好习惯。

        3.其实也可以单独定义 loss,accuracy。

        4.及时使用tf.train.Saver()来储存模型。这样可以多次训练模型。用saver.restore即可快速恢复训练。


def train_cnn():output, conv_3 = cnn(weight,biases) #output 64*378
# loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output,labels=Y))loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=Y))train_step = tf.train.AdamOptimizer(0.001).minimize(loss)predict = tf.reshape(output,[-1,max_captcha,char_set])label = tf.reshape(Y,[-1,max_captcha,char_set])
# print(label,&#39;\n&#39;,predict)max_idx_pre = tf.argmax(predict,2) #沿char_set找最大值索引即63个里最大的。max_idx_lab = tf.argmax(label,2)accuracy = tf.metrics.accuracy(labels=max_idx_lab,predictiOns=max_idx_pre) #tuplesaver = tf.train.Saver()sess = tf.Session()init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) # the local var is for accuracy_opsess.run(init_op)step = 0while True:
######需要恢复模型,继续训练时可使用以下注释部分
# model_file=tf.train.latest_checkpoint(&#39;model/&#39;)
# saver.restore(sess,model_file ) batch_x,batch_y = get_next_batch(64)_,loss_ ,out= sess.run([train_step,loss,output] ,feed_dict={X:batch_x,Y:batch_y,keep_prob: 0.6})
# print(&#39;step:&#39;,step)
# print(out.shape)if step%100 == 0:batch_x_test,batch_y_test = get_next_batch(64)acc = sess.run(accuracy,feed_dict={X:batch_x_test,Y:batch_y_test, keep_prob: 1.})[1]print(&#39;train loss: &#39; , loss_,&#39;train acc: &#39;, acc)if acc > 0.8 :saver.save(sess,"model/my_captcha.ckpt", global_step=step)break step += 1sess.close()
train_cnn()

四. 训练过程中可能出现的问题

1. 训练的 Loss和Acc 可能值会上下波动,这是正常现象,只要整体Loss趋势变小,Acc变大即可

2. 训练时间需要长一些,因为样本比较复杂,且没预处理,请耐心等待。

3. 如果训练很长时间测试的Acc却达不到很高的水平,可以尝试增加样本数量。


推荐阅读
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 阿里Treebased Deep Match(TDM) 学习笔记及技术发展回顾
    本文介绍了阿里Treebased Deep Match(TDM)的学习笔记,同时回顾了工业界技术发展的几代演进。从基于统计的启发式规则方法到基于内积模型的向量检索方法,再到引入复杂深度学习模型的下一代匹配技术。文章详细解释了基于统计的启发式规则方法和基于内积模型的向量检索方法的原理和应用,并介绍了TDM的背景和优势。最后,文章提到了向量距离和基于向量聚类的索引结构对于加速匹配效率的作用。本文对于理解TDM的学习过程和了解匹配技术的发展具有重要意义。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 本文介绍了Python语言程序设计中文件和数据格式化的操作,包括使用np.savetext保存文本文件,对文本文件和二进制文件进行统一的操作步骤,以及使用Numpy模块进行数据可视化编程的指南。同时还提供了一些关于Python的测试题。 ... [详细]
  • mapreduce数据去重的实现方法
    本文介绍了利用mapreduce实现数据去重的方法,同时还介绍了人工智能AI领域中常用的框架和工具,包括Keras、PyTorch、MXNet、TensorFlow和PaddlePaddle,并提供了深度学习实战的代码下载链接。 ... [详细]
  • Python15行代码实现免费发送手机短信,推送消息「建议收藏」
    Python15行代码实现免费发 ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • 知识图谱表示概念:知识图谱是由一些相互连接的实体和他们的属性构成的。换句话说,知识图谱是由一条条知识组成,每条知识表示为一个SPO三元组(Subject-Predicate-Obj ... [详细]
  • adfs是什么_培训与开发的概念
    adfs是什么_培训与开发的概念(如您转载本文,必须标明本文作者及出处。如有任何疑问请与我联系me@nap7.com)ADFS相关开发技术的中文资料相对匮乏,之前在弄这个东西的时候 ... [详细]
  • 3年半巨亏242亿!商汤高估了深度学习,下错了棋?
    转自:新智元三年半研发开支近70亿,累计亏损242亿。AI这门生意好像越来越不好做了。近日,商汤科技已向港交所递交IPO申请。招股书显示& ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • 当写稿机器人真有了观点和感情,我们是该高兴还是恐惧?
    目前,写稿机器人多是撰写以数据为主的稿件,当它们能够为文章注入观点之时,这些观点真的是其所“想”吗?最近,《南 ... [详细]
author-avatar
VASTEw
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有