热门标签 | HotTags
当前位置:  开发笔记 > Android > 正文

AndroidTensorFlowLite深度学习识别手写数字mnistdemo

一.TensorFlowLite​TensorFlowLite特性.jpeg​TensorFlowLite是用于移动设备和嵌入式设备的轻量级解决方案。TensorFlowLite支

一. TensorFlow Lite

Android TensorFlow Lite 深度学习识别手写数字mnist demo
Android TensorFlow Lite 深度学习识别手写数字mnist demo
Android TensorFlow Lite 深度学习识别手写数字mnist demo
Android TensorFlow Lite 深度学习识别手写数字mnist demo

TensorFlow Lite特性.jpeg

Android TensorFlow Lite 深度学习识别手写数字mnist demo
Android TensorFlow Lite 深度学习识别手写数字mnist demo

TensorFlow Lite 是用于移动设备和嵌入式设备的轻量级解决方案。TensorFlow Lite 支持 Android、iOS 甚至树莓派等多种平台。

我们知道大多数的 AI 是在云端运算的,但是在移动端使用 AI 具有无网络延迟、响应更加及时、数据隐私等特性。

对于离线的场合,云端的 AI 就无法使用了,而此时可以在移动设备中使用 TensorFlow Lite。

二. tflite 格式

TensorFlow 生成的模型是无法直接给移动端使用的,需要离线转换成.tflite文件格式。

tflite 存储格式是 flatbuffers。

FlatBuffers 是由Google开源的一个免费软件库,用于实现序列化格式。它类似于Protocol Buffers、Thrift、Apache Avro。

因此,如果要给移动端使用的话,必须把 TensorFlow 训练好的 protobuf 模型文件转换成 FlatBuffers 格式。官方提供了 toco 来实现模型格式的转换。

三. 常用的 Java API

TensorFlow Lite 提供了 C ++ 和 Java 两种类型的 API。无论哪种 API 都需要加载模型和运行模型。

而 TensorFlow Lite 的 Java API 使用了 Interpreter 类(解释器)来完成加载模型和运行模型的任务。后面的例子会看到如何使用 Interpreter。

四. TensorFlow Lite + mnist 数据集实现识别手写数字

mnist 是手写数字图片数据集,包含60000张训练样本和10000张测试样本。 测试集也是同样比例的手写数字数据。每张图片有28x28个像素点构成,每个像素点用一个灰度值表示,这里是将28x28的像素展开为一个一维的行向量(每行784个值)。

mnist 数据集获取地址:http://yann.lecun.com/exdb/mnist/

下面的 demo 中已经包含了 mnist.tflite 模型文件。(如果没有的话,需要自己训练保存成pb文件,再转换成tflite 格式)

对于一个识别类,首先需要初始化 TensorFlow Lite 解释器,以及输入、输出。

    // The tensorflow lite file
    private lateinit var tflite: Interpreter

    // Input byte buffer
    private lateinit var inputBuffer: ByteBuffer

    // Output array [batch_size, 10]
    private lateinit var mnistOutput: Array>

    init {

        try {
            tflite = Interpreter(loadModelFile(activity))

            inputBuffer = ByteBuffer.allocateDirect(
                    BYTE_SIZE_OF_FLOAT * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE)
            inputBuffer.order(ByteOrder.nativeOrder())
            mnistOutput = Array(DIM_BATCH_SIZE) { FloatArray(NUMBER_LENGTH) }
            Log.d(TAG, "Created a Tensorflow Lite MNIST Classifier.")
        } catch (e: IOException) {
            Log.e(TAG, "IOException loading the tflite file failed.")
        }

    }

从 asserts 文件中加载 mnist.tflite 模型:

    /**
     * Load the model file from the assets folder
     */
    @Throws(IOException::class)
    private fun loadModelFile(activity: Activity): MappedByteBuffer {

        val fileDescriptor = activity.assets.openFd(MODEL_PATH)
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        val fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }

真正识别手写数字是在 classify() 方法:

val digit = mnistClassifier.classify(Bitmap.createScaledBitmap(paintView.bitmap, PIXEL_WIDTH, PIXEL_WIDTH, false))

classify() 方法包含了预处理用于初始化 inputBuffer、运行 mnist 模型、识别出数字。

    /**
     * Classifies the number with the mnist model.
     *
     * @param bitmap
     * @return the identified number
     */
    fun classify(bitmap: Bitmap): Int {

        if (tflite == null) {
            Log.e(TAG, "Image classifier has not been initialized; Skipped.")
        }

        preProcess(bitmap)
        runModel()
        return postProcess()
    }

    /**
     * Converts it into the Byte Buffer to feed into the model
     *
     * @param bitmap
     */
    private fun preProcess(bitmap: Bitmap?) {

        if (bitmap == null || inputBuffer == null) {
            return
        }

        // Reset the image data
        inputBuffer.rewind()

        val width = bitmap.width
        val height = bitmap.height

        // The bitmap shape should be 28 x 28
        val pixels = IntArray(width * height)
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height)

        for (i in pixels.indices) {
            // Set 0 for white and 255 for black pixels
            val pixel = pixels[i]
            // The color of the input is black so the blue channel will be 0xFF.
            val channel = pixel and 0xff
            inputBuffer.putFloat((0xff - channel).toFloat())
        }
    }

    /**
     * Run the TFLite model
     */
    private fun runModel() = tflite.run(inputBuffer, mnistOutput)

    /**
     * Go through the output and find the number that was identified.
     *
     * @return the number that was identified (returns -1 if one wasn't found)
     */
    private fun postProcess(): Int {

        for (i in 0 until mnistOutput[0].size) {
            val value = mnistOutput[0][i]
            if (value == 1f) {
                return i
            }
        }

        return -1
    }

对于 Android 有一个地方需要注意,必须在 app 模块的 build.gradle 中添加如下的语句,否则无法加载模型。

android {
    ......
    aaptOptions {
        noCompress "tflite"
    }
}

demo 运行效果如下:

识别手写数字5.png

Android TensorFlow Lite 深度学习识别手写数字mnist demo

识别手写数字7.png

五. 总结

本文只是 TF Lite 的初探,很多细节并没有详细阐述。应该会在未来的文章中详细介绍。

本文 demo 的 github 地址:https://github.com/xiaobingchan/TFLite-MnistDemo

当然,也可以跑一下官方的例子:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/examples/android/app

任何程序错误,以及技术疑问或需要解答的,请扫码添加作者VX

Android TensorFlow Lite 深度学习识别手写数字mnist demo

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

我来说两句

0 条评论

推荐阅读
  • 解决TensorFlow CPU版本安装中的依赖问题
    本文记录了在安装CPU版本的TensorFlow过程中遇到的依赖问题及解决方案,特别是numpy版本不匹配和动态链接库(DLL)错误。通过详细的步骤说明和专业建议,帮助读者顺利安装并使用TensorFlow。 ... [详细]
  • 新手指南:在Windows 10上搭建深度学习与PyTorch开发环境
    本文详细记录了一名新手在Windows 10操作系统上搭建深度学习环境的过程,包括安装必要的软件和配置环境变量等步骤,旨在帮助同样初入该领域的读者避免常见的错误。 ... [详细]
  • 在Ubuntu 16.04中使用Anaconda安装TensorFlow
    本文详细介绍了如何在Ubuntu 16.04系统上通过Anaconda环境管理工具安装TensorFlow。首先,需要下载并安装Anaconda,然后配置环境变量以确保系统能够识别Anaconda命令。接着,创建一个特定的Python环境用于安装TensorFlow,并通过指定的镜像源加速安装过程。最后,通过一个简单的线性回归示例验证TensorFlow的安装是否成功。 ... [详细]
  • 吴恩达推出TensorFlow实践课程,Python基础即可入门,四个月掌握核心技能
    量子位报道,deeplearning.ai最新发布了TensorFlow实践课程,适合希望使用TensorFlow开发AI应用的学习者。该课程涵盖机器学习模型构建、图像识别、自然语言处理及时间序列预测等多个方面。 ... [详细]
  • 本文探讨了图像标签的多种分类场景及其在以图搜图技术中的应用,涵盖了从基础理论到实际项目实施的全面解析。 ... [详细]
  • 本文详细介绍了 TensorFlow 的入门实践,特别是使用 MNIST 数据集进行数字识别的项目。文章首先解析了项目文件结构,并解释了各部分的作用,随后逐步讲解了如何通过 TensorFlow 实现基本的神经网络模型。 ... [详细]
  • TensorFlow 2.0 中的 Keras 数据归一化实践
    数据预处理是机器学习任务中的关键步骤,特别是在深度学习领域。通过将数据归一化至特定范围,可以在梯度下降过程中实现更快的收敛速度和更高的模型性能。本文探讨了如何使用 TensorFlow 2.0 和 Keras 进行有效的数据归一化。 ... [详细]
  • 吴裕雄探讨混合神经网络模型在深度学习中的应用:结合RNN与CNN优化网络性能
    本文由吴裕雄撰写,深入探讨了如何利用Python、Keras及TensorFlow构建混合神经网络模型,特别是通过结合递归神经网络(RNN)和卷积神经网络(CNN),实现对网络运行效率的有效提升。 ... [详细]
  • 本文探讨了如何在TensorFlow中使用张量来处理和分析数字图像,特别是通过具体的代码示例展示了张量在图像处理中的作用。 ... [详细]
  • 机器学习核心概念与技术
    本文系统梳理了机器学习的关键知识点,涵盖模型评估、正则化、线性模型、支持向量机、决策树及集成学习等内容,并深入探讨了各算法的原理和应用场景。 ... [详细]
  • 本文介绍了一个使用Keras框架构建的卷积神经网络(CNN)实例,主要利用了Keras提供的MNIST数据集以及相关的层,如Dense、Dropout、Activation等,构建了一个具有两层卷积和两层全连接层的CNN模型。 ... [详细]
  • 本文详细介绍了C++标准模板库(STL)中各容器的功能特性,并深入探讨了不同容器操作函数的异常安全性。 ... [详细]
  • TensorFlow核心函数解析与应用
    本文详细介绍了TensorFlow中几个常用的基础函数及其应用场景,包括常量创建、张量扩展以及二维卷积操作等,旨在帮助开发者更好地理解和使用这些功能。 ... [详细]
  • 本文探讨了如何在Python中处理长数据的完全显示问题,包括numpy数组、pandas DataFrame以及tensor类型的完整输出设置。 ... [详细]
  • 本文介绍了TensorFlow中的tf.reverse函数,该函数用于对张量进行维度上的反转操作。我们将通过具体示例代码来展示其用法和效果。 ... [详细]
author-avatar
帝姬
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有