热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于ONNX在ML.NET中使用Pytorch训练的垃圾分类模型

ML.NET在经典机器学习范畴内,对分类、回归、异常检测等问题开发模型已经有非常棒的表现了,我之前的文章都有过介绍。当然我们希望在更高层次的领域加以使用,例如计算机视觉、自然语言处

ML.NET 在经典机器学习范畴内,对分类、回归、异常检测等问题开发模型已经有非常棒的表现了,我之前的文章都有过介绍。当然我们希望在更高层次的领域加以使用,例如计算机视觉、自然语言处理和信号处理等等领域。

图像识别是计算机视觉的一类分支,AI研发者们较为熟悉的是使用TensorFlow、Pytorch、Keras、MXNET等框架来训练深度神经网络模型,其中会涉及到CNN(卷积神经网络)、DNN(深度神经网络)的相关算法。

ML.NET 在较早期的版本是无法支持这类研究的,可喜的是最新的版本不但能很好地集成 TensorFlow 的模型做迁移学习,还可以直接导入 DNN 常见预编译模型:AlexNet、ResNet18、ResNet50、ResNet101 实现对图像的分类、识别等。

基于 ONNX 在 ML.NET 中使用 Pytorch 训练的垃圾分类模型

我特别想推荐的是,ML.NET 最新版本对 ONNX 的支持也是非常强劲,通过 ONNX 可以把众多其他优秀深度学习框架的模型引入到 .NET Core 运行时中,极大地扩充了 .NET 应用在智能认知服务的丰富程度。在 Microsoft Docs 中已经提供了一个基于 ONNX 使用 Tiny YOLOv2 做对象检测的例子。为了展现 ML.NET 在其他框架上的通用性,本文将介绍使用 Pytorch 训练的垃圾分类的模型,基于 ONNX 导入到 ML.NET 中完成预测。

在2019年9月华为云举办了一次人工智能大赛·垃圾分类挑战杯,首次将AI与环保主题结合,展现人工智能技术在生活中的运用。有幸我看到了本次大赛亚军方案的分享,并且在 github 上找到了开源代码,按照 README 说明,我用 Pytorch 训练出了一个模型,并保存为garbage.pt 文件。

生成 ONNX 模型

首先,我使用以下 Pytorch 代码来生成一个garbage.pt 对应的文件,命名为 garbage.onnx

torch_model = torch.load("garbage.pt") # pytorch模型加载
    batch_size = 1  #批处理大小
    input_shape = (3,224,224)   #输入数据

    # # set the model to inference mode
    torch_model.eval()

    x = torch.randn(batch_size, *input_shape, device='cuda')        # 生成张量
    export_onnx_file = "garbage.onnx"                    # 目的ONNX文件名
 
    
    torch.onnx.export(torch_model.module,
                        x,
                        export_onnx_file,
                        input_names=["input"],        # 输入名
                        output_names=["output"]    # 输出名

准备 ML.NET 项目

创建一个 .NET Core 控制台应用,按如下结构创建好合适的目录。assets 目录下的 images 子目录将放置待预测的图片,而 Model 子目录放入前一个步骤生成的 garbage.onnx 模型文件。

基于 ONNX 在 ML.NET 中使用 Pytorch 训练的垃圾分类模型

ImageNetData 和 ImageNetPrediction 类定义了输入和输出的数据结构。

using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;

namespace GarbageDemo.DataStructures
{
    public class ImageNetData
    {
        [LoadColumn(0)]
        public string ImagePath;

        [LoadColumn(1)]
        public string Label;

        public static IEnumerable ReadFromFile(string imageFolder)
        {
            return Directory
               .GetFiles(imageFolder)
               .Where(filePath => Path.GetExtension(filePath) == ".jpg")
               .Select(filePath => new ImageNetData { ImagePath = filePath, Label = Path.GetFileName(filePath) });

        }
    }

    public class ImageNetPrediction : ImageNetData
    {
        public float[] Score;

        public string PredictedLabelValue;
    }
}

 OnnxModelScorer 类定义了 ONNX 模型加载、打分预测的过程。注意 ImageNetModelSettings 的属性和第一步中指定的输入输出字段名要一致。

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.Onnx;
using Microsoft.ML.Transforms.Image;
using GarbageDemo.DataStructures;

namespace GarbageDemo
{
    class OnnxModelScorer
    {
        private readonly string imagesFolder;
        private readonly string modelLocation;
        private readonly MLContext mlContext;


        public OnnxModelScorer(string imagesFolder, string modelLocation, MLContext mlContext)
        {
            this.imagesFolder = imagesFolder;
            this.modelLocation = modelLocation;
            this.mlCOntext= mlContext;
        }

        public struct ImageNetSettings
        {
            public const int imageHeight = 224;
            public const int imageWidth = 224;    
            public const float Mean = 127;
            public const float Scale = 1;
            public const bool ChannelsLast = false;
        } 
        
        public struct ImageNetModelSettings
        {
            // input tensor name
            public const string ModelInput = "input";

            // output tensor name
            public const string ModelOutput = "output";
        }

        private ITransformer LoadModel(string modelLocation)
        {
            Console.WriteLine("Read model");
            Console.WriteLine($"Model location: {modelLocation}");
            Console.WriteLine($"Default parameters: image size=({ImageNetSettings.imageWidth},{ImageNetSettings.imageHeight})");

            // Create IDataView from empty list to obtain input data schema
            var data = mlContext.Data.LoadFromEnumerable(new List());

            // Define scoring pipeline
            var pipeline = mlContext.Transforms.LoadImages(outputColumnName: ImageNetModelSettings.ModelInput, imageFolder: "", inputColumnName: nameof(ImageNetData.ImagePath))                           
                            .Append(mlContext.Transforms.ResizeImages(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                        imageWidth: ImageNetSettings.imageWidth, 
                                                                        imageHeight: ImageNetSettings.imageHeight, 
                                                                        inputColumnName: ImageNetModelSettings.ModelInput,
                                                                        resizing: ImageResizingEstimator.ResizingKind.IsoCrop,
                                                                        cropAnchor: ImageResizingEstimator.Anchor.Center
                                                                        ))
                            .Append(mlContext.Transforms.ExtractPixels(outputColumnName: ImageNetModelSettings.ModelInput, interleavePixelColors: ImageNetSettings.ChannelsLast))
                            .Append(mlContext.Transforms.NormalizeGlobalContrast(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                 inputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                 ensureZeroMean : true, 
                                                                                 ensureUnitStandardDeviation: true, 
                                                                                 scale: ImageNetSettings.Scale))
                            .Append(mlContext.Transforms.ApplyOnnxModel(modelFile: modelLocation, outputColumnNames: new[] { ImageNetModelSettings.ModelOutput }, inputColumnNames: new[] { ImageNetModelSettings.ModelInput }));

            // Fit scoring pipeline
            var model = pipeline.Fit(data);

            return model;
        }

        private IEnumerable<float[]> PredictDataUsingModel(IDataView testData, ITransformer model)
        {
            Console.WriteLine($"Images location: {imagesFolder}");
            Console.WriteLine("");
            Console.WriteLine("=====Identify the objects in the images=====");
            Console.WriteLine("");

            IDataView scoredData = model.Transform(testData);

            IEnumerable<float[]> probabilities = scoredData.GetColumn<float[]>(ImageNetModelSettings.ModelOutput);

            return probabilities;
        }

        public IEnumerable<float[]> Score(IDataView data)
        {
            var model = LoadModel(modelLocation);

            return PredictDataUsingModel(data, model);
        }
    }
}

Program 类中定义了调用过程,完成预测结果呈现。

using GarbageDemo.DataStructures;
using Microsoft.ML;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace GarbageDemo
{
    class Program
    {
        static void Main(string[] args)
        {
            var assetsRelativePath = @"../../../assets";
            string assetsPath = GetAbsolutePath(assetsRelativePath);
            var modelFilePath = Path.Combine(assetsPath, "Model", "garbage.onnx");
            var imagesFolder = Path.Combine(assetsPath, "images");// Initialize MLContext
            MLContext mlCOntext= new MLContext();

            try
            {
                // Load Data
                IEnumerable images = ImageNetData.ReadFromFile(imagesFolder);
                IDataView imageDataView = mlContext.Data.LoadFromEnumerable(images);

                // Create instance of model scorer
                var modelScorer = new OnnxModelScorer(imagesFolder, modelFilePath, mlContext);

                // Use model to score data
                IEnumerable<float[]> probabilities = modelScorer.Score(imageDataView);

                int index = 0;
                foreach (var probable in probabilities)
                {
                    var scores = Softmax(probable);

                    var (topResultIndex, topResultScore) = scores.Select((predictedClass, index) => (Index: index, Value: predictedClass))
                        .OrderByDescending(result => result.Value)
                        .First();
                    Console.WriteLine("图片:{3} \r\n 分类{2} {0}:{1}", labels[topResultIndex], topResultScore, topResultIndex, images.ElementAt(index).ImagePath);
                    Console.WriteLine("=============================");
                    index++;
                }

            }
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
            }

            Console.WriteLine("========= End of Process..Hit any Key ========");
            Console.ReadLine();
        }

        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;

            string fullPath = Path.Combine(assemblyFolderPath, relativePath);

            return fullPath;
        }

        private static float[] Softmax(float[] values)
        {
            var maxVal = values.Max();
            var exp = values.Select(v => Math.Exp(v - maxVal));
            var sumExp = exp.Sum();

            return exp.Select(v => (float)(v / sumExp)).ToArray();
        }

        private static string[] labels = new string[]
        {
            "其他垃圾/一次性快餐盒",
            "其他垃圾/污损塑料",
            "其他垃圾/烟蒂",
            "其他垃圾/牙签",
            "其他垃圾/破碎花盆及碟碗",
            "其他垃圾/竹筷",
            "厨余垃圾/剩饭剩菜",
            "厨余垃圾/大骨头",
            "厨余垃圾/水果果皮",
            "厨余垃圾/水果果肉",
            "厨余垃圾/茶叶渣",
            "厨余垃圾/菜叶菜根",
            "厨余垃圾/蛋壳",
            "厨余垃圾/鱼骨",
            "可回收物/充电宝",
            "可回收物/包",
            "可回收物/化妆品瓶",
            "可回收物/塑料玩具",
            "可回收物/塑料碗盆",
            "可回收物/塑料衣架",
            "可回收物/快递纸袋",
            "可回收物/插头电线",
            "可回收物/旧衣服",
            "可回收物/易拉罐",
            "可回收物/枕头",
            "可回收物/毛绒玩具",
            "可回收物/洗发水瓶",
            "可回收物/玻璃杯",
            "可回收物/皮鞋",
            "可回收物/砧板",
            "可回收物/纸板箱",
            "可回收物/调料瓶",
            "可回收物/酒瓶",
            "可回收物/金属食品罐",
            "可回收物/锅",
            "可回收物/食用油桶",
            "可回收物/饮料瓶",
            "有害垃圾/干电池",
            "有害垃圾/软膏",
            "有害垃圾/过期药物",
            "可回收物/毛巾",
            "可回收物/饮料盒",
            "可回收物/纸袋"
        };

选择一张图片放到 images 目录中,运行结果如下:

基于 ONNX 在 ML.NET 中使用 Pytorch 训练的垃圾分类模型

有 0.88 的得分说明照片中的物品属于污损塑料,让我们看一下图片真相。

基于 ONNX 在 ML.NET 中使用 Pytorch 训练的垃圾分类模型

果然是相当准确 ,并且把周边的附属物都过滤掉了。

对于 ML.NET 训练深度神经网络模型支持更复杂的场景是不是更有信心了!


推荐阅读
  • 云原生边缘计算之KubeEdge简介及功能特点
    本文介绍了云原生边缘计算中的KubeEdge系统,该系统是一个开源系统,用于将容器化应用程序编排功能扩展到Edge的主机。它基于Kubernetes构建,并为网络应用程序提供基础架构支持。同时,KubeEdge具有离线模式、基于Kubernetes的节点、群集、应用程序和设备管理、资源优化等特点。此外,KubeEdge还支持跨平台工作,在私有、公共和混合云中都可以运行。同时,KubeEdge还提供数据管理和数据分析管道引擎的支持。最后,本文还介绍了KubeEdge系统生成证书的方法。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • 本文介绍了在wepy中运用小顺序页面受权的计划,包含了用户点击作废后的从新受权计划。 ... [详细]
  • 本文介绍了在go语言中利用(*interface{})(nil)传递参数类型的原理及应用。通过分析Martini框架中的injector类型的声明,解释了values映射表的作用以及parent Injector的含义。同时,讨论了该技术在实际开发中的应用场景。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • Android中高级面试必知必会,积累总结
    本文介绍了Android中高级面试的必知必会内容,并总结了相关经验。文章指出,如今的Android市场对开发人员的要求更高,需要更专业的人才。同时,文章还给出了针对Android岗位的职责和要求,并提供了简历突出的建议。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 自动轮播,反转播放的ViewPagerAdapter的使用方法和效果展示
    本文介绍了如何使用自动轮播、反转播放的ViewPagerAdapter,并展示了其效果。该ViewPagerAdapter支持无限循环、触摸暂停、切换缩放等功能。同时提供了使用GIF.gif的示例和github地址。通过LoopFragmentPagerAdapter类的getActualCount、getActualItem和getActualPagerTitle方法可以实现自定义的循环效果和标题展示。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • Go GUIlxn/walk 学习3.菜单栏和工具栏的具体实现
    本文介绍了使用Go语言的GUI库lxn/walk实现菜单栏和工具栏的具体方法,包括消息窗口的产生、文件放置动作响应和提示框的应用。部分代码来自上一篇博客和lxn/walk官方示例。文章提供了学习GUI开发的实际案例和代码示例。 ... [详细]
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
  • 判断编码是否可立即解码的程序及电话号码一致性判断程序
    本文介绍了两个编程题目,一个是判断编码是否可立即解码的程序,另一个是判断电话号码一致性的程序。对于第一个题目,给出一组二进制编码,判断是否存在一个编码是另一个编码的前缀,如果不存在则称为可立即解码的编码。对于第二个题目,给出一些电话号码,判断是否存在一个号码是另一个号码的前缀,如果不存在则说明这些号码是一致的。两个题目的解法类似,都使用了树的数据结构来实现。 ... [详细]
  • cs231n Lecture 3 线性分类笔记(一)
    内容列表线性分类器简介线性评分函数阐明线性分类器损失函数多类SVMSoftmax分类器SVM和Softmax的比较基于Web的可交互线性分类器原型小结注:中文翻译 ... [详细]
author-avatar
--AppleChan--
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有