热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Tensorflow2.0YOLOV4tiny网络原理及代码解析(三)损失函数的构建

Tensorflow2.0—YOLOV4-tiny网络原理及代码解析(三)-损失函数的构建YOLOV4中的损失函数与V3还是有比较大的区别的ÿ
Tensorflow2.0—YOLO V4-tiny网络原理及代码解析(三)- 损失函数的构建

YOLO V4中的损失函数与V3还是有比较大的区别的,具体的可以看YOLOV4与YOLOV3的区别。
代码是在nets文件夹下面的loss.py文件中,在train.py中引用的是:

model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss',arguments={'anchors': anchors, 'num_classes': num_classes, 'ignore_thresh': 0.5, 'label_smoothing': label_smoothing})(loss_input)

先来看下yolo_loss的参数:

def yolo_loss(args, #是一个列表,其中包含了预测结果和进行编码之后的真实框的结果,# [, 预测结果1# , 预测结果2# , 真实框1# ] 真实框2anchors, #[[ 10. 14.]# [ 23. 27.]# [ 37. 58.]# [ 81. 82.]# [135. 169.]# [344. 319.]]num_classes, #['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'],共20个分类ignore_thresh=.5, #阈值label_smoothing=0.1, #标签平滑print_loss=False, #是否打印损失normalize=True): #是否做归一化

把预测的与真实的进行分割出来,
y_true是一个列表,包含两个特征层,shape分别为(m,13,13,75),(m,26,26,75)
yolo_outputs是一个列表,包含两个特征层,shape分别为(m,13,13,3,25),(m,26,26,3,25)

y_true = args[num_layers:]
yolo_outputs = args[:num_layers]

然后就开始以特征层数开始循环,这里就以(m,13,13,3,25)为例:
先获得真实框(m,13,13,3,25)的第5个位置的数据,如果在编码中该位置存在gt框,那么就设置为1,表示第(i,j)特征图中第k个锚点框包含物体,否则一切都设置为0

object_mask = y_true[l][..., 4:5]

然后,再获得真实框(m,13,13,3,25)的第6-26个位置的数据:

true_class_probs = y_true[l][..., 5:]

接着,进行标签平滑:

if label_smoothing:true_class_probs = _smooth_labels(true_class_probs, label_smoothing)

def _smooth_labels(y_true, label_smoothing):num_classes = tf.cast(K.shape(y_true)[-1], dtype=K.floatx())label_smoothing = K.constant(label_smoothing, dtype=K.floatx())return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes


接着,就是yolo_head:

yolo_head(yolo_outputs[l], #yolo模型预测的结果,shape=(None,13,13,75)anchors[anchor_mask[l]], #预测结果所对应的anchor,这里为#[[ 81. ,82.],# [135. ,169.],# [344. ,319.]]num_classes, #voc数据集的class数量,为20input_shape, #(416,416)calc_loss=True) #是否计算损失函数,在模型预测阶段,该参数为False

下面这一段代码有点难理解,其实它的作用就是创建(13,13,1,2)的网格

grid_shape = K.shape(feats)[1:3] # height, width,为(13,13)
grid_y = K.tile(K.reshape(K.arange(0, stop=grid_shape[0]), [-1, 1, 1, 1]),[1, grid_shape[1], 1, 1])
grid_x = K.tile(K.reshape(K.arange(0, stop=grid_shape[1]), [1, -1, 1, 1]),[grid_shape[0], 1, 1, 1])
grid = K.concatenate([grid_x, grid_y])
grid = K.cast(grid, K.dtype(feats))

举个例子:
在这里插入图片描述
在这幅图片中就可以很好的看到最终其实就是生成了(13,13,1,2)的网格。

feats = K.reshape(feats, [-1, grid_shape[0], grid_shape[1], num_anchors, num_classes + 5])

将预测结果(m,13,13,75)分割成(m,13,13,3,25)。

box_xy = (K.sigmoid(feats[..., :2]) + grid) / K.cast(grid_shape[...,::-1], K.dtype(feats))
box_wh = K.exp(feats[..., 2:4]) * anchors_tensor / K.cast(input_shape[...,::-1], K.dtype(feats))
box_confidence = K.sigmoid(feats[..., 4:5])
box_class_probs = K.sigmoid(feats[..., 5:])

这四行代码的目的是:将预测值转换为真实值!
第一行:将预测结果的xy(None,13,13,3,2)先加上grid,然后除以(13,13),就得到了转化后的进行归一化的xy,shape=(None,13,13,3,2)。
第二行:将预测结果的wh(None,13,13,3,2)先乘anchor的尺寸,然后除以输入大小,最后再进行指数计算,得到转换后的wh,shape同xy。
第三行和第四行:将预测得到的confidence和class_prob进行sigmoid转化。

最后,就可以将预测结果进行decode为预测真实框的形式~

if calc_loss == True:return grid, feats, box_xy, box_whreturn box_xy, box_wh, box_confidence, box_class_probs

在计算loss的时候返回grid, feats, box_xy, box_wh
在预测的时候返回box_xy, box_wh, box_confidence, box_class_probs

pred_box = K.concatenate([pred_xy, pred_wh])

将上述box_xy, box_wh进行合并,shape=(None,13,13,3,4)


下面,还有一个loop_body函数,来看看它是干嘛的。

#-----------------------------------------------------------## 对每一张图片计算ignore_mask#-----------------------------------------------------------#def loop_body(b, ignore_mask):#-----------------------------------------------------------## 取出n个真实框&#xff1a;n,4#-----------------------------------------------------------#true_box &#61; tf.boolean_mask(y_true[l][b,...,0:4], object_mask_bool[b,...,0])#-----------------------------------------------------------## 计算预测框与真实框的iou# pred_box 13,13,3,4 预测框的坐标# true_box n,4 真实框的坐标# iou 13,13,3,n 预测框和真实框的iou#-----------------------------------------------------------#iou &#61; box_iou(pred_box[b], true_box)#-----------------------------------------------------------## best_iou 13,13,3 每个特征点与真实框的最大重合程度#-----------------------------------------------------------#best_iou &#61; K.max(iou, axis&#61;-1)#-----------------------------------------------------------## 判断预测框和真实框的最大iou小于ignore_thresh# 则认为该预测框没有与之对应的真实框# 该操作的目的是&#xff1a;# 忽略预测结果与真实框非常对应特征点&#xff0c;因为这些框已经比较准了# 不适合当作负样本&#xff0c;所以忽略掉。#-----------------------------------------------------------#ignore_mask &#61; ignore_mask.write(b, K.cast(best_iou<ignore_thresh, K.dtype(true_box)))return b&#43;1, ignore_mask#-----------------------------------------------------------## 在这个地方进行一个循环、循环是对每一张图片进行的#-----------------------------------------------------------#_, ignore_mask &#61; tf.while_loop(lambda b,*args: b<m, loop_body, [0, ignore_mask])

步骤解释&#xff1a;
1.先取出真实框存在物体的框的xywh
2.与预测框进行iou计算
3.找到对应每个真实框最大的iou的预测框&#xff0c;best_iou &#61; &#xff08;13&#xff0c;13&#xff0c;3&#xff09;
4.判断预测框和真实框的最大iou小于ignore_thresh&#xff0c;则认为该预测框没有与之对应的真实框&#xff0c;这么做的目的是&#xff1a;忽略预测结果与真实框非常对应特征点&#xff0c;因为这些框已经比较准了&#xff0c;不适合当作负样本&#xff0c;所以忽略掉
5.把这些框放入ignore_mask中

#-----------------------------------------------------------## 真实框越大&#xff0c;比重越小&#xff0c;小框的比重更大。#-----------------------------------------------------------#box_loss_scale &#61; 2 - y_true[l][...,2:3]*y_true[l][...,3:4]#-----------------------------------------------------------## 计算Ciou loss#-----------------------------------------------------------#raw_true_box &#61; y_true[l][...,0:4]ciou &#61; box_ciou(pred_box, raw_true_box)ciou_loss &#61; object_mask * box_loss_scale * (1 - ciou)

这一部分是进行预测与真实的ciou损失&#xff01;但是代码实现中与论文中还是有一点区别的&#xff0c;在代码中还考虑到了真实框大小的因素&#xff1a;真实框越大&#xff0c;比重越小&#xff0c;小框的比重更大。&#xff08;这里我就不写ciou的代码了&#xff0c;有机会单独写一个blog~&#xff09;&#xff0c;得到的ciou_loss &#61; &#xff08;None,13,13,3,1&#xff09;。

最后&#xff0c;就是置信度损失和类别损失的计算了&#xff1a;

confidence_loss &#61; object_mask * K.binary_crossentropy(object_mask, raw_pred[...,4:5], from_logits&#61;True)&#43; \(1-object_mask) * K.binary_crossentropy(object_mask, raw_pred[...,4:5], from_logits&#61;True) * ignore_maskclass_loss &#61; object_mask * K.binary_crossentropy(true_class_probs, raw_pred[...,5:], from_logits&#61;True)

第一行&#xff1a;先计算真实框存在的confidence_loss&#xff0c;加上不存在真实框的confidence_loss&#xff08;这里要忽略那些ignore_mask里面的框&#xff09;。
第二行&#xff1a;直接计算预测和真实的类别损失。

num_pos &#43;&#61; tf.maximum(K.sum(K.cast(object_mask, tf.float32)), 1)loss &#43;&#61; location_loss &#43; confidence_loss &#43; class_lossloss &#61; K.expand_dims(loss, axis&#61;-1)if normalize:loss &#61; loss / num_poselse:loss &#61; loss / mfreturn loss

最后就是进行损失求和&#xff0c;若进行归一化&#xff0c;就将总损失除以正样本&#xff0c;要是不进行归一化&#xff0c;那就除以批次大小。最终得到最后的loss~


推荐阅读
  • 使用Python构建网页版图像编辑器
    本文详细介绍了一款基于Python开发的网页版图像编辑工具,具备多种图像处理功能,如黑白转换、铅笔素描效果等。 ... [详细]
  • 酷家乐 Serverless FaaS 产品实践探索
    本文探讨了酷家乐在 Serverless FaaS 领域的实践与经验,重点介绍了 FaaS 平台的构建、业务收益及未来发展方向。 ... [详细]
  • 字符、字符串和文本的处理之Char类型
    .NetFramework中处理字符和字符串的主要有以下这么几个类:(1)、System.Char类一基础字符串处理类(2)、System.String类一处理不可变的字符串(一经 ... [详细]
  • 本文介绍了多种Eclipse插件,包括XML Schema Infoset Model (XSD)、Graphical Editing Framework (GEF)、Eclipse Modeling Framework (EMF)等,涵盖了从Web开发到图形界面编辑的多个方面。 ... [详细]
  • Nagios可视化插件开发指南 —— 配置详解
    本文详细介绍了Nagios监控系统的配置过程,包括数据库的选择与安装、Nagios插件的安装及配置文件的解析。同时,针对常见的配置错误提供了具体的解决方法。 ... [详细]
  • 本教程旨在指导开发者如何在Android应用中通过ViewPager组件实现图片轮播功能,适用于初学者和有一定经验的开发者,帮助提升应用的视觉吸引力。 ... [详细]
  • 深入解析Java并发之ArrayBlockingQueue
    本文详细探讨了ArrayBlockingQueue,这是一种基于数组实现的阻塞队列。ArrayBlockingQueue在初始化时需要指定容量,因此它是一个有界的阻塞队列。文章不仅介绍了其基本概念和数据结构,还深入分析了其源码实现,包括各种入队、出队、获取元素和删除元素的方法。 ... [详细]
  • 使用R语言进行Foodmart数据的关联规则分析与可视化
    本文探讨了如何利用R语言中的arules和arulesViz包对Foodmart数据集进行关联规则的挖掘与可视化。文章首先介绍了数据集的基本情况,然后逐步展示了如何进行数据预处理、规则挖掘及结果的图形化呈现。 ... [详细]
  • 本文总结了在多人协作开发环境中使用 Git 时常见的问题及其解决方案,包括错误合并分支的处理、使用 SourceTree 查找问题提交、Git 自动生成的提交信息解释、删除远程仓库文件夹而不删除本地文件的方法、合并冲突时的注意事项以及如何将多个提交合并为一个。 ... [详细]
  • 本文介绍了如何使用 Python 的 Pyglet 库加载并显示图像。Pyglet 是一个用于开发图形用户界面应用的强大工具,特别适用于游戏和多媒体项目。 ... [详细]
  • 深入解析Unity3D游戏开发中的音频播放技术
    在游戏开发中,音频播放是提升玩家沉浸感的关键因素之一。本文将探讨如何在Unity3D中高效地管理和播放不同类型的游戏音频,包括背景音乐和效果音效,并介绍实现这些功能的具体步骤。 ... [详细]
  • SQLite是一种轻量级的关系型数据库管理系统,尽管体积小巧,却能支持高达2TB的数据库容量,每个数据库以单个文件形式存储。本文将详细介绍SQLite在Android开发中的应用,包括其数据存储机制、事务处理方式及数据类型的动态特性。 ... [详细]
  • 基于OpenCV的小型图像检索系统开发指南
    本文详细介绍了如何利用OpenCV构建一个高效的小型图像检索系统,涵盖从图像特征提取、视觉词汇表构建到图像数据库创建及在线检索的全过程。 ... [详细]
  • 我在尝试将组合框转换为具有自动完成功能时遇到了一个问题,即页面上的列表框也被转换成了自动完成下拉框,而不是保持原有的多选列表框形式。 ... [详细]
  • 汇编语言标识符和表达式(四)(表达式与符号定义语句)
    7、表达式表达式是程序设计课程里的一个重要的基本概念,它可由运算符、操作符、括号、常量和一些符号连在一起的式子。在汇编语言中,表达式分为:数值表达式和地址表达式。(1)进制伪指令R ... [详细]
author-avatar
zackcoolgirl_497
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有