热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

手把手教你用深度学习做物体检测(四):模型使用

上一篇《手把手教你用深度学习做物体检测(三):模型训练》中介绍了如何使用yolov3训练我们自己的物体检测模型,本篇文章将重点介绍如何使用我们训练好的模型来检测图片或视频中的物体

上一篇《手把手教你用深度学习做物体检测(三):模型训练》中介绍了如何使用yolov3训练我们自己的物体检测模型,本篇文章将重点介绍如何使用我们训练好的模型来检测图片或视频中的物体。

  如果你看过了上一篇文章,那么就知道我们用的是 AlexeyAB/darknet项目,该项目虽然提供了物体检测的方法,分别是基于c++和python编写的物体检测代码,但是有几个问题如下:

  • 都不支持中文显示。
  • 都没有显示置信度。
  • 程序检测框样式都不够友好。
  • python编写的物体检测代码执行总是报类型相关错误,估计是底层c++程序的问题。

  其中,中文显示乱码的问题和opencv有关,网上也有很多文章有所介绍,但是都十分繁琐,所以我基于python,借鉴 qqwweee/keras-yolo3项目的代码,重新写了一套物体检测程序,主要思想是用python的PIL库代替opencv来绘制检测信息到图像上,当然还有其它一些细节改动,就不一一说明了,直接上代码:
  darknet.py文件主要是修改了detect_image方法

def detect_image(class_names, net, meta, im, thresh=.5, hier_thresh=.5, nms=.45, debug=False):
    num = c_int(0)
    if debug: print("Assigned num")
    pnum = pointer(num)
    if debug: print("Assigned pnum")
    predict_image(net, im)
    if debug: print("did prediction")

    dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum, 0)
    if debug: print("Got dets")
    num = pnum[0]
    if debug: print("got zeroth index of pnum")
    if nms:
        do_nms_sort(dets, num, meta.classes, nms)
    if debug: print("did sort")
    res = []
    if debug: print("about to range")
    for j in range(num):
        if debug: print("Ranging on " + str(j) + " of " + str(num))
        if debug: print("Classes: " + str(meta), meta.classes, meta.names)
        for i in range(meta.classes):
            if debug: print("Class-ranging on " + str(i) + " of " + str(meta.classes) + "= " + str(dets[j].prob[i]))
            if dets[j].prob[i] > 0.0:
                b = dets[j].bbox
                if altNames is None:
       # nameTag = meta.names[i] 该步骤会导致段错误,初步判断应该是和c++程序有关,所以直接传入类别列表参数,以绕过该问题。
                    nameTag = class_names[i]
                    print(nameTag)
                else:
                    nameTag = altNames[i]
                    print(nameTag)
                if debug:
                    print("Got bbox", b)
                    print(nameTag)
                    print(dets[j].prob[i])
                    print((b.x, b.y, b.w, b.h))
                res.append((nameTag, dets[j].prob[i], (b.x, b.y, b.w, b.h)))
    if debug: print("did range")
    res = sorted(res, key=lambda x: -x[1])
    if debug: print("did sort")
    free_detections(dets, num)
    if debug: print("freed detections")
    return res

  添加darknet_video_custom.py,内容如下

# -*- coding: utf-8 -*-
"""
本模块使用yolov3模型探测目标在图片或视频中的位置
"""
__author__ = \'程序员一一涤生\'import colorsys
import os
from timeit import default_timer as timer
import cv2
import numpy as np
from PIL import ImageDraw, ImageFont, Image
import darknet

def _convertBack(x, y, w, h):
    xmin = int(round(x - (w / 2)))
    xmax = int(round(x + (w / 2)))
    ymin = int(round(y - (h / 2)))
    ymax = int(round(y + (h / 2)))
    return xmin, ymin, xmax, ymax

def letterbox_image(image, size):
    \'\'\'resize image with unchanged aspect ratio using padding\'\'\'
    iw, ih = image.size
    w, h = size
    scale = min(w / iw, h / ih)
    nw = int(iw * scale)
    nh = int(ih * scale)
    image = image.resize((nw, nh), Image.BICUBIC)
    new_image = Image.new(\'RGB\', size, (128, 128, 128))
    new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
    return new_image

class YOLO(object):
    _defaults = {
        "configPath": "names-data/yolo-obj.cfg",
        "weightPath": "names-data/backup/yolo-obj_3000.weights",
        "metaPath": "names-data/voc.data",
        "classes_path": "names-data/voc.names",
        "thresh": 0.3,
        "iou_thresh": 0.5,
        # "model_image_size": (416, 416),
        # "model_image_size": (608, 608),
        "model_image_size": (800, 800),
        "gpu_num": 1,
    }

    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)  # set up default values
        self.__dict__.update(kwargs)  # and update with user overrides
        self.class_names = self._get_class()
        self.colors = self._get_colors()
        self.netMain = darknet.load_net_custom(self.configPath.encode("ascii"), self.weightPath.encode("ascii"), 0,
                                               1)  # batch size = 1
        self.metaMain = darknet.load_meta(self.metaPath.encode("ascii"))
        self.altNames = self._get_alt_names()

    def _get_class(self):
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path, encoding="utf-8") as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names

    def _get_colors(self):
        class_names = self._get_class()
        # Generate colors for drawing bounding boxes.
        hsv_tuples = [(x / len(class_names), 1., 1.)
                      for x in range(len(class_names))]
        colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
        np.random.seed(10101)  # Fixed seed for consistent colors across runs.
        np.random.shuffle(colors)  # Shuffle colors to decorrelate adjacent classes.
        np.random.seed(None)  # Reset seed to default.
        return colors

    def _get_alt_names(self):
        try:
            with open(self.metaPath) as metaFH:
                metaContents = metaFH.read()
                import re
                match = re.search("names *= *(.*)$", metaContents, re.IGNORECASE | re.MULTILINE)
                if match:
                    result = match.group(1)
                else:
                    result = None
                try:
                    if os.path.exists(result):
                        with open(result) as namesFH:
                            namesList = namesFH.read().strip().split("\n")
                            altNames = [x.strip() for x in namesList]
                except TypeError:
                    pass
        except Exception:
            pass
        return altNames

    def cvDrawBoxes(self, detections, image):
        # 字体相关设置,包括字体文件路径、字体大小
        fOnt= ImageFont.truetype(fOnt=\'font/simfang.ttf\',
                                  size=np.floor(3e-2 * image.size[1] + 0.5).astype(\'int32\'))
        # 检测框的边框厚度,该公式使得厚度可以根据图片的大小来自动调整
        thickness = (image.size[0] + image.size[1]) // 300  #
        # 遍历每个检测到的目标detection:(classname,probaility,(x,y,w,h))
        for c, detection in enumerate(detections):
            # 获取当前目标的类别和置信度分数
            classname = detection[0]
            # score = round(detection[1] * 100, 2)
            score = round(detection[1], 2)
            label = \'{} {:.2f}\'.format(classname, score)
            # 计算检测框左上角(xmin, ymin)和右下角的坐标(xmax, ymax)
            x, y, w, h = detection[2][0], \
                         detection[2][1], \
                         detection[2][2], \
                         detection[2][3]
            xmin, ymin, xmax, ymax = _convertBack(
                float(x), float(y), float(w), float(h))
            # 获取绘制实例
            draw = ImageDraw.Draw(image)
            # 获取将显示的文本的大小
            label_size = draw.textsize(label, font)
            # 将坐标对应到top, left, bottom, right,注意不要对应错了
            top, left, bottom, right = (ymin, xmin, ymax, xmax)
            top = max(0, np.floor(top + 0.5).astype(\'int32\'))
            left = max(0, np.floor(left + 0.5).astype(\'int32\'))
            bottom = min(image.size[1], np.floor(bottom + 0.5).astype(\'int32\'))
            right = min(image.size[0], np.floor(right + 0.5).astype(\'int32\'))
            print(label, (left, top), (right, bottom))
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])
            if c > len(self.class_names) - 1:
                c = 1
            # 绘制边框厚度
            for i in range(thickness):
                draw.rectangle(
                    [left + i, top + i, right - i, bottom - i],
                    outline=self.colors[c])
            # 绘制检测框的文本边界
            draw.rectangle(
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=self.colors[c])
            # 绘制文本
            draw.text(text_origin, label, fill=(0, 0, 0), fOnt=font)
            del draw
        return image

    def detect_video(self, video_path, output_path="",show=True):
        nw = self.model_image_size[0]
        nh = self.model_image_size[1]
        assert nw % 32 == 0, \'Multiples of 32 required\'
        assert nh % 32 == 0, \'Multiples of 32 required\'
        vid = cv2.VideoCapture(video_path)
        if not vid.isOpened():
            raise IOError("Couldn\'t open webcam or video")
        video_FourCC = cv2.VideoWriter_fourcc(*"mp4v")
        video_fps = vid.get(cv2.CAP_PROP_FPS)
        video_size = (nw,nh)
        isOutput = True if output_path != "" else False
        if isOutput:
            print("!!! TYPE:", type(output_path), type(video_FourCC), type(video_fps), type(video_size))
            out = cv2.VideoWriter(output_path, video_FourCC, video_fps, video_size)
        accum_time = 0
        curr_fps = 0
        fps = "FPS: ??"
        prev_time = timer()

        # Create an image we reuse for each detect
        darknet_image = darknet.make_image(nw, nh, 3)
        while True:
            return_value, frame = vid.read()
            if return_value:
                # 转成RGB格式,因为opencv默认使用BGR格式读取图片,而PIL是用RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image = Image.fromarray(frame_rgb)
                image_resized = image.resize(video_size, Image.LINEAR)
                darknet.copy_image_from_bytes(darknet_image, np.asarray(image_resized).tobytes())
                detections = darknet.detect_image(self.class_names, self.netMain, self.metaMain, darknet_image,
                                                  thresh=self.thresh, debug=True)
                image_resized = self.cvDrawBoxes(detections, image_resized)
                result = np.asarray(image_resized)
                # 转成BGR格式以便opencv处理
                result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
                curr_time = timer()
                exec_time = curr_time - prev_time
                prev_time = curr_time
                accum_time = accum_time + exec_time
                curr_fps = curr_fps + 1
                if accum_time > 1:
                    accum_time = accum_time - 1
                    fps = "FPS: " + str(curr_fps)
                    curr_fps = 0
                cv2.putText(result, text=fps, org=(3, 15), fOntFace=cv2.FONT_HERSHEY_SIMPLEX,
                            fontScale=0.50, color=(255, 0, 0), thickness=2)
                if show:
                    cv2.imshow("Object Detect", result)
                if isOutput:
                    print("start write...==========================================")
                    out.write(result)
                if cv2.waitKey(1) & 0xFF == ord(\'q\'):
                    break
            else:
                break
        out.release()
        vid.release()
        cv2.destroyAllWindows()

    def detect_image(self, image_path, save_path):
        nw = self.model_image_size[0]
        nh = self.model_image_size[1]
        assert nw % 32 == 0, \'Multiples of 32 required\'
        assert nh % 32 == 0, \'Multiples of 32 required\'
        try:
            image = Image.open(image_path)
        except:
            print(\'Open Error! Try again!\')
        else:
            image_resized = image.resize((nw, nh), Image.LINEAR)
            darknet_image = darknet.make_image(nw, nh, 3)
            darknet.copy_image_from_bytes(darknet_image, np.asarray(image_resized).tobytes())
            # 识别图片得到目标的类别、置信度、中心点坐标和检测框的高宽
            detectiOns= darknet.detect_image(self.class_names, self.netMain, self.metaMain, darknet_image,
                                              thresh=0.25, debug=True)
            # 在图片上将detections信息绘制出来
            image_resized = self.cvDrawBoxes(detections, image_resized)
            # 显示绘制后的图片
            image_resized.show()
            image_resized.save(save_path)

if __name__ == "__main__":
    _yolo = YOLO()
    _yolo.detect_image("names-data/images/food.JPG", "names-data/images/food_detect.JPG")
    # _yolo.detect_video("names-data/videos/food.mp4", "names-data/videos/food_detect.mp4",show=False)

  上面的代码的关键部分都附有相关的注释,这里就不一一解读了,另外附上中文字体文件,放到项目的font目录下即可。

  下载链接: https://github.com/Halfish/lstm-ctc-ocr/blob/master/fonts/simfang.ttf

下面是我收藏的一些其他字体,你可以挑选自己喜欢的字体使用。

链接:https://pan.baidu.com/s/1PWS7Hw1z3dkDyq7feZxqEQ
提取码:xu8q

  下面看看如何显示置信度,打开src/images.c文件,将draw_detections_cv_v3函数用如下代码替换,注意替换后要重新make一下项目:

void draw_detections_cv_v3(IplImage* show_img, detection *dets, int num, float thresh, char **names, image **alphabet, int classes, int ext_output){
    int i, j;
    if (!show_img) return;
    static int frame_id = 0;
    frame_id++;
    for (i = 0; i i) {
        char labelstr[4096] = { 0 };
        int class_id = -1;
        for (j = 0; j j) {
            int show = strncmp(names[j], "dont_show", 9);
            if (dets[i].prob[j] > thresh && show) {
                float score=dets[i].prob[j];//在label标签上加入置信度
                if (class_id <0) {
                    strcat(labelstr, names[j]);
                    strcat(labelstr, ", ");
                    sprintf(labelstr + strlen(labelstr), "%0.2f", score);
                    class_id = j;
                }
                else {
                    strcat(labelstr, ", ");
                    strcat(labelstr, names[j]);
                    strcat(labelstr, ", ");
                    sprintf(labelstr + strlen(labelstr), "%0.2f", score);
                }
                printf("%s: %.0f%% ", names[j], score * 100);
            }
        }
        if (class_id >= 0) {
            int width = show_img->height * .006;
            int offset = class_id * 123457 % classes;
            float red = get_color(2, offset, classes);
            float green = get_color(1, offset, classes);
            float blue = get_color(0, offset, classes);
            float rgb[3];
            rgb[0] = red;
            rgb[1] = green;
            rgb[2] = blue;
            box b = dets[i].bbox;
            b.w = (b.w <1) ? b.w : 1;
            b.h = (b.h <1) ? b.h : 1;
            b.x = (b.x <1) ? b.x : 1;
            b.y = (b.y <1) ? b.y : 1;
            int left = (b.x - b.w / 2.)*show_img->width;
            int right = (b.x + b.w / 2.)*show_img->width;
            int top = (b.y - b.h / 2.)*show_img->height;
            int bot = (b.y + b.h / 2.)*show_img->height;
            if (left <0) left = 0;
            if (right > show_img->width - 1) right = show_img->width - 1;
            if (top <0) top = 0;
            if (bot > show_img->height - 1) bot = show_img->height - 1;
            float const font_size = show_img->height / 1000.F;
            CvPoint pt1, pt2, pt_text, pt_text_bg1, pt_text_bg2;
            pt1.x = left;
            pt1.y = top;
            pt2.x = right;
            pt2.y = bot;
            pt_text.x = left;
            pt_text.y = top - 12;
            pt_text_bg1.x = left;
            pt_text_bg1.y = top - (10 + 25 * font_size);
            pt_text_bg2.x = right;
            pt_text_bg2.y = top;
            CvScalar color;
            color.val[0] = red * 256;
            color.val[1] = green * 256;
            color.val[2] = blue * 256;
            cvRectangle(show_img, pt1, pt2, color, width, 8, 0);
            if (ext_output)
                printf("\t(left_x: %4.0f   top_y: %4.0f   width: %4.0f   height: %4.0f)\n",
                    (float)left, (float)top, b.w*show_img->width, b.h*show_img->height);
            else
                printf("\n");
            cvRectangle(show_img, pt_text_bg1, pt_text_bg2, color, width, 8, 0);
            cvRectangle(show_img, pt_text_bg1, pt_text_bg2, color, CV_FILLED, 8, 0);    // filled
            CvScalar black_color;
            black_color.val[0] = 0;
            CvFont font;
            cvInitFont(&font, CV_FONT_HERSHEY_SIMPLEX, font_size, font_size, 0, font_size * 3, 8);
            cvPutText(show_img, labelstr, pt_text, &font, black_color);
        }
    }
    if (ext_output) {
        fflush(stdout);
    }
}

  以上操作都准备好了之后,执行python  darknet_video_custom.py即可开始检测图片或视频中的物体。效果如下:

  是不是很酷呢O(∩_∩)O~。本系列文章到此已经写了4篇,分别是《快速感受物体检测的酷炫》、《数据标注》、《模型训练》、《模型使用》,我们已经体验了整个物体检测的过程,对物体检测的过程有了一定的了解。下一篇《手把手教你用深度学习做物体检测(五):YOLO》会介绍一下YOLO算法的相关内容,让我们了解目标检测背后是如何工作的。

  ok,本篇就这么多内容啦~,感谢阅读O(∩_∩)O,88~

 
名句分享

 人不是向外奔走才是旅行,静静坐着思维也是旅行,凡是探索、追寻、触及那些不可知的情境,不论是风土的,或是心灵的,都是一种旅行。
—— 林清玄  
为您推荐
如何在阿里云租一台GPU服务器做深度学习?

手把手教你用深度学习做物体检测(三):模型训练

手把手教你用深度学习做物体检测(二):数据标注

手把手教你用深度学习做物体检测(一): 快速感受物体检测的酷炫

ubuntu16.04安装Anaconda3

Unbuntu下持续观察NvidiaGPU的状态

 
我的博客即将同步至腾讯云+社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=1kvpuxzlylh68

推荐阅读
  • 本文详细介绍如何使用Python进行配置文件的读写操作,涵盖常见的配置文件格式(如INI、JSON、TOML和YAML),并提供具体的代码示例。 ... [详细]
  • 本文介绍如何使用Python进行文本处理,包括分词和生成词云图。通过整合多个文本文件、去除停用词并生成词云图,展示文本数据的可视化分析方法。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 本文详细介绍了 Dockerfile 的编写方法及其在网络配置中的应用,涵盖基础指令、镜像构建与发布流程,并深入探讨了 Docker 的默认网络、容器互联及自定义网络的实现。 ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • 2023年京东Android面试真题解析与经验分享
    本文由一位拥有6年Android开发经验的工程师撰写,详细解析了京东面试中常见的技术问题。涵盖引用传递、Handler机制、ListView优化、多线程控制及ANR处理等核心知识点。 ... [详细]
  • 本文介绍如何使用 Python 获取文件和图片的创建、修改及拍摄日期。通过多种方法,如 PIL 库的 _getexif() 函数和 os 模块的 getmtime() 和 stat() 方法,详细讲解了这些技术的应用场景和注意事项。 ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • Linux 系统启动故障排除指南:MBR 和 GRUB 问题
    本文详细介绍了 Linux 系统启动过程中常见的 MBR 扇区和 GRUB 引导程序故障及其解决方案,涵盖从备份、模拟故障到恢复的具体步骤。 ... [详细]
  • 深入理解Tornado模板系统
    本文详细介绍了Tornado框架中模板系统的使用方法。Tornado自带的轻量级、高效且灵活的模板语言位于tornado.template模块,支持嵌入Python代码片段,帮助开发者快速构建动态网页。 ... [详细]
  • CentOS7源码编译安装MySQL5.6
    2019独角兽企业重金招聘Python工程师标准一、先在cmake官网下个最新的cmake源码包cmake官网:https:www.cmake.org如此时最新 ... [详细]
  • 本文介绍如何通过注册表编辑器自定义和优化Windows文件右键菜单,包括删除不需要的菜单项、添加绿色版或非安装版软件以及将特定应用程序(如Sublime Text)添加到右键菜单中。 ... [详细]
  • 本文介绍如何使用 Python 提取和替换 .docx 文件中的图片。.docx 文件本质上是压缩文件,通过解压可以访问其中的图片资源。此外,我们还将探讨使用第三方库 docx 的方法来简化这一过程。 ... [详细]
  • PHP 5.5.0rc1 发布:深入解析 Zend OPcache
    2013年5月9日,PHP官方发布了PHP 5.5.0rc1和PHP 5.4.15正式版,这两个版本均支持64位环境。本文将详细介绍Zend OPcache的功能及其在Windows环境下的配置与测试。 ... [详细]
author-avatar
Healthcen健康
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有