训练用的网络见上篇博客
Tensorflow 直接对验证码进行3通道卷积后识别
本文对上篇博客中的网络稍作修改,以便于Java调用
import tensorflow as tf
import numpy as np
from PIL import Image
import os
import random
# Directory holding the training captcha images; gen_train_data() parses the
# part of each file name before the first '.' as the label digits.
train_data_dir = r'C:\Users\HUPENG\Desktop\check_code_crack\check_code\train'
# Test-set directory; left empty and not referenced by the visible code.
test_data_dir = r''
# Names of all entries in the training directory, listed once at import time.
train_file_name_list = os.listdir(train_data_dir)
def gen_train_data(batch_size=32):
    """Draw a random batch of captcha images together with their labels.

    File names are sampled without replacement from ``train_file_name_list``.
    The label of a sample is the part of its file name before the first
    ``'.'``, split into single characters and converted to ``int32``.

    Args:
        batch_size: Number of files to sample.

    Returns:
        A ``(x_data, y_data)`` tuple of NumPy arrays: the stacked image
        arrays and the stacked per-character label arrays.

    Raises:
        ValueError: If ``batch_size`` exceeds the number of available files
            (raised by ``random.sample``).
    """

    def load_sample(file_name):
        # Image first, label second, so a bad file fails at the same point
        # as it would in a plain per-file loop.
        pixels = np.array(Image.open(train_data_dir + "/" + file_name))
        digits = np.array(list(file_name.split('.')[0])).astype(np.int32)
        return pixels, digits

    chosen_names = random.sample(train_file_name_list, batch_size)
    samples = [load_sample(file_name) for file_name in chosen_names]
    x_data = np.array([pixels for pixels, _ in samples])
    y_data = np.array([digits for _, digits in samples])
    return x_data, y_data
# Network input. No static shape is declared; net() reshapes whatever is fed
# to (-1, 218, 82, 3). The op name "input" is what the Java caller feeds.
X = tf.placeholder(tf.float32, name="input")
# Dropout keep probability. A placeholder with a default of 1.0 keeps dropout
# disabled whenever nothing is fed (export and inference), while still being a
# graph tensor that train() can override through feed_dict -- a plain Python
# float cannot be used as a feed_dict key.
keep_prob = tf.placeholder_with_default(1.0, shape=(), name="keep_prob")
def net(w_alpha=0.01, b_alpha=0.1):
    """Build the captcha-recognition CNN on the module-level placeholder ``X``.

    Three conv -> ReLU -> max-pool -> dropout stages feed a 1024-unit dense
    layer and a 50-unit output layer. The output is reshaped to
    (batch, 5, 10), passed through softmax, and reduced with argmax over the
    last axis.

    Args:
        w_alpha: Scale applied to the random-normal initial weights.
        b_alpha: Scale applied to the random-normal initial biases.

    Returns:
        A float32 tensor named "output" holding, for each sample, the argmax
        index at each of the 5 positions.
    """

    def scaled_normal(scale, shape):
        # The variables are unnamed, so their creation order is kept stable
        # (weights before biases, layer by layer).
        # NOTE(review): checkpoint keys presumably depend on that order.
        return tf.Variable(scale * tf.random_normal(shape))

    def conv_stage(inputs, in_channels, out_channels, pool_size):
        # 3x3 convolution, bias, ReLU, stride-2 max-pool, then dropout.
        kernel = scaled_normal(w_alpha, [3, 3, in_channels, out_channels])
        bias = scaled_normal(b_alpha, [out_channels])
        convolved = tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME')
        activated = tf.nn.relu(tf.nn.bias_add(convolved, bias))
        pooled = tf.nn.max_pool(
            activated,
            ksize=[1, pool_size, pool_size, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
        )
        return tf.nn.dropout(pooled, keep_prob)

    images = tf.reshape(X, (-1, 218, 82, 3))
    stage1 = conv_stage(images, 3, 16, pool_size=4)
    stage2 = conv_stage(stage1, 16, 16, pool_size=2)
    stage3 = conv_stage(stage2, 16, 16, pool_size=2)

    # Three stride-2 'SAME' pools shrink 218x82 to 28x11, with 16 channels.
    flat_size = 28 * 11 * 16
    w_dense = scaled_normal(w_alpha, [flat_size, 1024])
    b_dense = scaled_normal(b_alpha, [1024])
    flattened = tf.reshape(stage3, [-1, flat_size])
    hidden = tf.nn.relu(tf.add(tf.matmul(flattened, w_dense), b_dense))

    w_logits = scaled_normal(w_alpha, [1024, 5 * 10])
    b_logits = scaled_normal(b_alpha, [5 * 10])
    logits = tf.add(tf.matmul(hidden, w_logits), b_logits)

    per_position = tf.nn.softmax(tf.reshape(logits, (-1, 5, 10)))
    predicted = tf.argmax(per_position, 2)
    # Cast to float32 under the op name "output", which the Java caller fetches.
    return tf.cast(predicted, tf.float32, name="output")
# Build the graph once at import time; net() returns the tensor named "output".
cnn = net()
def train():
    """Train until the batch loss drops below 0.001, then save a checkpoint.

    Each step draws 64 samples with ``gen_train_data``, flattens the images
    and runs one optimizer step. When the loss threshold is reached, the
    session is saved to ``./crack_capcha.model`` tagged with the step count
    and the loop ends.

    NOTE(review): ``Y``, ``loss`` and ``optimizer`` are not defined in the
    visible part of this file; presumably they come from the training script
    this file was adapted from -- confirm before calling.
    """
    saver = tf.train.Saver()
    # keep_prob may be a plain Python float (dropout fixed in the graph).
    # A float cannot be a feed_dict key, so feed it only when it is a tensor.
    dropout_is_feedable = isinstance(keep_prob, tf.Tensor)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        step = 0
        while True:
            x_data, y_data = gen_train_data(64)
            feed = {Y: y_data, X: np.reshape(x_data, (-1))}
            if dropout_is_feedable:
                feed[keep_prob] = 0.75
            # Only the loss value is used; the optimizer is fetched solely to
            # apply the update.
            loss_, _ = sess.run([loss, optimizer], feed_dict=feed)
            print(loss_)
            if loss_ < 0.001:
                saver.save(sess, "./crack_capcha.model", global_step=step)
                print("save model successful!")
                break
            step += 1
def exportModel(checkpoint_path="crack_capcha.model-941", output_path='model.pb'):
    """Freeze the trained graph and write it as a binary GraphDef file.

    Restores the variables from ``checkpoint_path``, converts them to
    constants, keeps the sub-graph needed for the node named ``output``, and
    serializes the result to ``output_path``.

    Args:
        checkpoint_path: Checkpoint prefix to restore. Defaults to the
            step-941 checkpoint that was previously hard-coded.
        output_path: Destination of the frozen graph. Defaults to
            ``model.pb``.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        output_graph_def = convert_variables_to_constants(
            sess, sess.graph_def, output_node_names=['output'])
        # GFile replaces the deprecated FastGFile wrapper.
        with tf.gfile.GFile(output_path, mode='wb') as f:
            f.write(output_graph_def.SerializeToString())
# Script entry point: export the frozen graph, then report completion.
if __name__ == '__main__':
    exportModel()
    print("ok")
导出模型文件为model.pb
Java工程的依赖如下:
整个的项目结构如下:
下面写Java端的调用
ECardCaptchaCrack.java
package me.hupeng.sdk.ecardcaptchacrack;
import net.coobird.thumbnailator.Thumbnails;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
/**
 * Recognises 5-digit captchas with the frozen TensorFlow graph bundled as the
 * classpath resource {@code model.pb}.
 */
public class ECardCaptchaCrack {
    /**
     * Byte size of the originally shipped model. Kept only for source
     * compatibility: the model is now read to its real length, so this value
     * no longer has to match the file.
     */
    public static final int MAGIC_NUMBER = 20418991;

    /** Serialized GraphDef of the model; empty when the resource could not be loaded. */
    static byte[] graphDef = loadGraphDef();

    /**
     * Reads the whole {@code model.pb} resource regardless of its size.
     *
     * @return the resource bytes, or an empty array if it is missing or unreadable
     */
    private static byte[] loadGraphDef() {
        // Fully-qualified names keep this block independent of the import list.
        try (java.io.InputStream in =
                     ClassLoader.getSystemClassLoader().getResourceAsStream("model.pb")) {
            if (in == null) {
                System.err.println("Resource model.pb not found on the classpath");
                return new byte[0];
            }
            java.io.ByteArrayOutputStream buffer = new java.io.ByteArrayOutputStream();
            byte[] chunk = new byte[1 << 16];
            int read;
            // A single read() may return fewer bytes than requested, so loop until EOF.
            while ((read = in.read(chunk)) != -1) {
                buffer.write(chunk, 0, read);
            }
            return buffer.toByteArray();
        } catch (IOException e) {
            e.printStackTrace();
            return new byte[0];
        }
    }

    /**
     * Recognises the captcha stored in the given image file.
     *
     * @param filepath path of the captcha image
     * @return the 5 recognised digits, or an empty string when the file does not exist,
     *         the model could not be loaded, or the image cannot be read
     */
    public static String crack(String filepath) {
        File file = new File(filepath);
        if (!file.exists() || graphDef.length == 0) {
            return "";
        }

        BufferedImage im;
        try {
            // Force the 218x82 size that the graph reshapes its input to.
            im = Thumbnails.of(filepath).forceSize(218, 82).outputFormat("bmp").asBufferedImage();
        } catch (IOException e) {
            e.printStackTrace();
            return "";
        }

        Raster raster = im.getData();
        // NOTE(review): the graph expects 3 bands per pixel; presumably the bmp
        // conversion guarantees that -- confirm for images with an alpha channel.
        float[] pixels = raster.getPixels(0, 0, raster.getWidth(), raster.getHeight(),
                new float[raster.getWidth() * raster.getHeight() * raster.getNumBands()]);

        StringBuilder result = new StringBuilder(5);
        // Graph, Session and both Tensors hold native memory; close all of them.
        try (Graph g = new Graph();
             Tensor input = Tensor.create(pixels, Float.class)) {
            g.importGraphDef(graphDef);
            try (Session s = new Session(g);
                 Tensor output = s.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) {
                float[][] digits = new float[1][5];
                output.copyTo(digits);
                for (float digit : digits[0]) {
                    result.append((int) digit);
                }
            }
        }
        return result.toString();
    }

    /**
     * Reads a file completely, terminating the JVM with exit code 1 on failure.
     * Not called by the code in this class.
     */
    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null;
    }
}
Main.java
package me.hupeng.sdk.ecardcaptchacrack;
/** Command-line entry point that prints the recognised text of one captcha image. */
public class Main {
    /**
     * @param args optional; {@code args[0]} is the path of the captcha image,
     *             defaulting to {@code 1.jpg} when no argument is given
     */
    public static void main(String[] args) {
        String imagePath = args.length > 0 ? args[0] : "1.jpg";
        System.out.println(ECardCaptchaCrack.crack(imagePath));
    }
}
上面的代码预测的是1.jpg这张图
程序输出为:
注:目前只有Java SE桌面级应用可以使用,Android版本的调用方式等待后续更新!