热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Java调用Tensorflow训练出来的模型

训练用的网络见上篇博客Tensorflow直接对验证码进行3通道卷积后识别对于这上篇博客的网络稍作修改,利于Java调用importtensorflowastfimp

训练用的网络见上篇博客

Tensorflow 直接对验证码进行3通道卷积后识别
对于这上篇博客的网络稍作修改,利于Java调用

import tensorflow as tf
import numpy as np
from PIL import Image
import os
import random

train_data_dir = r'C:\Users\HUPENG\Desktop\check_code_crack\check_code\train'
test_data_dir = r''


train_file_name_list = os.listdir(train_data_dir)

def gen_train_data(batch_size=32):
    selected_train_file_name_list = random.sample(train_file_name_list, batch_size)
    x_data = []
    y_data = []
    for selected_train_file_name in selected_train_file_name_list:
        captcha_image = Image.open(train_data_dir + "/" + selected_train_file_name)
        captcha_image_np = np.array(captcha_image)
        x_data.append(captcha_image_np)
        y_data.append(np.array(list(selected_train_file_name.split('.')[0])).astype(np.int32))
    x_data = np.array(x_data)
    y_data = np.array(y_data)
    return x_data, y_data

X = tf.placeholder(tf.float32, name="input")
# Y = tf.placeholder(tf.int32)
# keep_prob = tf.placeholder(tf.float32)
# y_one_hot = tf.one_hot(Y, 10, 1, 0)
# y_one_hot = tf.cast(y_one_hot, tf.float32)
keep_prob = 1.0
def net(w_alpha=0.01, b_alpha=0.1):
    x_reshape = tf.reshape(X, (-1, 218, 82, 3))
    w_c1 = tf.Variable(w_alpha * tf.random_normal([3, 3, 3, 16]))
    b_c1 = tf.Variable(b_alpha * tf.random_normal([16]))
    conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x_reshape, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1))
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 4, 4, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv1 = tf.nn.dropout(conv1, keep_prob)

    w_c2 = tf.Variable(w_alpha * tf.random_normal([3, 3, 16, 16]))
    b_c2 = tf.Variable(b_alpha * tf.random_normal([16]))
    conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1, 1, 1, 1], padding='SAME'), b_c2))
    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv2 = tf.nn.dropout(conv2, keep_prob)

    w_c3 = tf.Variable(w_alpha * tf.random_normal([3, 3, 16, 16]))
    b_c3 = tf.Variable(b_alpha * tf.random_normal([16]))
    conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3))
    conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv3 = tf.nn.dropout(conv3, keep_prob)

    # Fully connected layer
    # 随机生成权重
    w_d = tf.Variable(w_alpha * tf.random_normal([28 * 11 * 16, 1024]))
    # 随机生成偏置
    b_d = tf.Variable(b_alpha * tf.random_normal([1024]))
    dense = tf.reshape(conv3, [-1, w_d.get_shape().as_list()[0]])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))

    w_out = tf.Variable(w_alpha * tf.random_normal([1024, 5 * 10]))
    b_out = tf.Variable(b_alpha * tf.random_normal([5 * 10]))
    out = tf.add(tf.matmul(dense, w_out), b_out)
    out = tf.reshape(out, (-1, 5, 10))


    out = tf.nn.softmax(out)
    out = tf.argmax(out, 2)
    out = tf.cast(out, tf.float32, name="output")

    return out
cnn = net()
# loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=cnn, labels=y_one_hot))
# optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

def train():
    saver = tf.train.Saver()
    with tf.Session() as sess:
        step = 0
        tf.global_variables_initializer().run()
        while True:
            x_data, y_data = gen_train_data(64)
            x_data = np.reshape(x_data, (-1))
            loss_, cnn_, y_one_hot_, optimizer_ = sess.run([loss, cnn, y_one_hot, optimizer],
                                                           feed_dict={Y: y_data, X: x_data, keep_prob: 0.75})

            print(loss_)
            if loss_ <0.001:
                saver.save(sess, "./crack_capcha.model", global_step=step)
                print("save model successful!")
                break

            # cnn_ = sess.run(cnn, feed_dict={Y:y_data, X:x_data})
            # print(cnn_.shape)
            # break
            step += 1
def exportModel():
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # 恢复模型参数
        saver.restore(sess, "crack_capcha.model-941")
        from tensorflow.python.framework.graph_util import convert_variables_to_constants
        output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
        with tf.gfile.FastGFile('model.pb', mode='wb') as f:
            f.write(output_graph_def.SerializeToString())

if __name__=='__main__':
    # exportModel()
    exportModel()
    # train()
    print("ok")

导出模型文件为model.pb
Java工程的依赖如下:
这里写图片描述
整个的项目结构如下:
这里写图片描述
下面写Java端的调用
ECardCaptchaCrack.java

package me.hupeng.sdk.ecardcaptchacrack;

import net.coobird.thumbnailator.Thumbnails;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class ECardCaptchaCrack {
    public static final int MAGIC_NUMBER = 20418991;
    static  byte[] graphDef = new byte[MAGIC_NUMBER];

    static {
        try {
            BufferedInputStream bis = new BufferedInputStream(ClassLoader.getSystemClassLoader().getResourceAsStream("model.pb"));
            int len = bis.read(graphDef, 0, MAGIC_NUMBER);
            bis.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static String crack(String filepath){
        File file = new File(filepath);
        if (!file.exists()){
            return "";
        }

        String result = "";
        try (Graph g = new Graph()) {
            g.importGraphDef(graphDef);
            BufferedImage im = null;
            try {
                im = Thumbnails.of(filepath).forceSize(218,82).outputFormat("bmp").asBufferedImage();
            } catch (IOException e) {
                e.printStackTrace();
                return "";
            }
            Raster raster = im.getData();
            float [] temp = new float[raster.getWidth() * raster.getHeight() * raster.getNumBands()];
            float [] pixels  = raster.getPixels(0,0,raster.getWidth(),raster.getHeight(),temp);
            Tensor input = Tensor.create(pixels, Float.class);
            try (Session s = new Session(g);
                 Tensor output = s.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) {
// System.out.println(output);
                float[][] output2floatArray = new float[1][5];
                output.copyTo(output2floatArray);
                for(int i=0; i<5; i++){
                    result += "" + (int)(output2floatArray[0][i]);
                }
            }
            return result;
        }
    }

    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null;
    }
}

Main.java

package me.hupeng.sdk.ecardcaptchacrack;

public class Main {

    public static void main(String[] args) {
        System.out.println(ECardCaptchaCrack.crack("1.jpg"));
    }
}

上面的代码预测的是1.jpg这张图
这里写图片描述
程序输出为:
这里写图片描述

注:只能Java SE桌面级应用可以用,Android版本调用等待更新!


推荐阅读
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • 解决问题:1、批量读取点云las数据2、点云数据读与写出3、csf滤波分类参考:https:github.comsuyunzzzCSF论文题目ÿ ... [详细]
  • 通过使用 `pandas` 库中的 `scatter_matrix` 函数,可以有效地绘制出多个特征之间的两两关系。该函数不仅能够生成散点图矩阵,还能通过参数如 `frame`、`alpha`、`c`、`figsize` 和 `ax` 等进行自定义设置,以满足不同的可视化需求。此外,`diagonal` 参数允许用户选择对角线上的图表类型,例如直方图或密度图,从而提供更多的数据洞察。 ... [详细]
  • 在 Vue 应用开发中,页面状态管理和跨页面数据传递是常见需求。本文将详细介绍 Vue Router 提供的两种有效方式,帮助开发者高效地实现页面间的数据交互与状态同步,同时分享一些最佳实践和注意事项。 ... [详细]
  • iOS snow animation
    CTSnowAnimationView.hCTMyCtripCreatedbyalexon1614.Copyright©2016年ctrip.Allrightsreserved.# ... [详细]
  • Ihavetwomethodsofgeneratingmdistinctrandomnumbersintherange[0..n-1]我有两种方法在范围[0.n-1]中生 ... [详细]
  • 本地存储组件实现对IE低版本浏览器的兼容性支持 ... [详细]
  • 如何使用 `org.opencb.opencga.core.results.VariantQueryResult.getSource()` 方法及其代码示例详解 ... [详细]
  • 优化Vite 1.0至2.0升级过程中遇到的某些代码块过大问题解决方案
    本文详细探讨了在将项目从 Vite 1.0 升级到 2.0 的过程中,如何解决某些代码块过大的问题。通过具体的编码示例,文章提供了全面的解决方案,帮助开发者有效优化打包性能。 ... [详细]
  • Python 序列图分割与可视化编程入门教程
    本文介绍了如何使用 Python 进行序列图的快速分割与可视化。通过一个实际案例,详细展示了从需求分析到代码实现的全过程。具体包括如何读取序列图数据、应用分割算法以及利用可视化库生成直观的图表,帮助非编程背景的用户也能轻松上手。 ... [详细]
  • 本文全面解析了 Python 中字符串处理的常用操作与技巧。首先介绍了如何通过 `s.strip()`, `s.lstrip()` 和 `s.rstrip()` 方法去除字符串中的空格和特殊符号。接着,详细讲解了字符串复制的方法,包括使用 `sStr1 = sStr2` 进行简单的赋值复制。此外,还探讨了字符串连接、分割、替换等高级操作,并提供了丰富的示例代码,帮助读者深入理解和掌握这些实用技巧。 ... [详细]
  • 利用 Python Socket 实现 ICMP 协议下的网络通信
    在计算机网络课程的2.1实验中,学生需要通过Python Socket编程实现一种基于ICMP协议的网络通信功能。与操作系统自带的Ping命令类似,该实验要求学生开发一个简化的、非标准的ICMP通信程序,以加深对ICMP协议及其在网络通信中的应用的理解。通过这一实验,学生将掌握如何使用Python Socket库来构建和解析ICMP数据包,并实现基本的网络探测功能。 ... [详细]
  • 分享一款基于Java开发的经典贪吃蛇游戏实现
    本文介绍了一款使用Java语言开发的经典贪吃蛇游戏的实现。游戏主要由两个核心类组成:`GameFrame` 和 `GamePanel`。`GameFrame` 类负责设置游戏窗口的标题、关闭按钮以及是否允许调整窗口大小,并初始化数据模型以支持绘制操作。`GamePanel` 类则负责管理游戏中的蛇和苹果的逻辑与渲染,确保游戏的流畅运行和良好的用户体验。 ... [详细]
  • 解决针织难题:R语言编程技巧与常见错误分析 ... [详细]
  • 在Android应用开发中,实现与MySQL数据库的连接是一项重要的技术任务。本文详细介绍了Android连接MySQL数据库的操作流程和技术要点。首先,Android平台提供了SQLiteOpenHelper类作为数据库辅助工具,用于创建或打开数据库。开发者可以通过继承并扩展该类,实现对数据库的初始化和版本管理。此外,文章还探讨了使用第三方库如Retrofit或Volley进行网络请求,以及如何通过JSON格式交换数据,确保与MySQL服务器的高效通信。 ... [详细]
author-avatar
mobiledu2502884677
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有