热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

ssd(SingleShotMultiBoxDetector)代码解读之(三)multiboxloss损失函数

multiboxloss是SSD的损失函数跟交叉熵那些官方给出接口的损失函数不同。multiboxloss需要自己定义的。建议:边看代码边看此博客。代码来源&#

multibox loss 是SSD的损失函数

跟交叉熵那些官方 给出接口的损失函数不同。multibox loss需要自己定义的。

建议:边看代码边看此博客。

代码来源:https://github.com/amdegroot/ssd.pytorch

multibox loss的流程:

1.获取ssd网络的输出

ssd网络输出有三个,其中两个是预测值,一个是所有default box的集合。预测值一个是回归值(框的位置信息),一个是分类值(框的类别信息)。

2.match过程

match过程主要做这几件事:

(1)把所有default box与GroundTrue框做一个杰卡德相似计算(与IOU相似),得到一个矩阵 [GroundTrue数,default box数] ,里面的值就是它们对应的杰卡德相似值。

(2)分别计算各个GroundTrue对应的default box,各个default box对应的GroundTrue框。得到[GroundTrue数,1]和[1,default box数]。得到四个矩阵如下:

best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) best_truth_idx.squeeze_(0) #[1,num_priors]best_truth_overlap.squeeze_(0) #[1,num_priors]best_prior_idx.squeeze_(1) #[num_objs,1]best_prior_overlap.squeeze_(1) #[num_objs,1]best_truth_overlap.index_fill_(0, best_prior_idx, 2) # 保证最best的default box的杰卡德系数不因为太小而被下面的代码过滤掉

(3)获得default box对应的GroundTrue的回归值和分类值

上面(2)得到了best_truth_idx矩阵,因为每个default box知道了自己对应的是哪一个GroundTrue框,因此我们可以让default box知道自己对应的GroundTrue框的回归值和分类值:

(4)过滤杰卡德相似系数少于阈值(=0.5)的default  box,并把它的类别设置成0,即设置成背景类。(处理conf矩阵),得到的矩阵命名为 conf_t

(5)encode过程

encode过程是算出 default box与其对应的GroundTrue的回归偏置值。偏置值包括中心坐标偏置和边长偏置。俗话说就是得到default box与其对应的GroundTrue的变化值,default box是怎么变才能变换到对应的GroundTrue上去。当然有encode就有decode,decode是把偏置值加上其对应的default box 得到预测的框。但训练过程是不需要decode的,测试或者要看预测框的时候才需要用到decode。encode处理后的矩阵命名为 loc_t

所以综上所述,match过程就是返回 loc_t 和 conf_t 矩阵。

代码:

def encode(matched, priors, variances): #获取扩张/平移系数"""Encode the variances from the priorbox layers into the ground truth boxeswe have matched (based on jaccard overlap) with the prior boxes.Args:matched: (tensor) Coords of ground truth for each prior in point-formShape: [num_priors, 4].priors: (tensor) Prior boxes in center-offset formShape: [num_priors,4].variances: (list[float]) Variances of priorboxesReturn:encoded boxes (tensor), Shape: [num_priors, 4]"""
#xmin, ymin, xmax, ymax# dist b/t match center and prior's center 计算match的box和priors的box的中间点的距离g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] #matched的中间点 - priors的中间点-# encode varianceg_cxcy /= (variances[0] * priors[:, 2:]) #除长宽,得出priors的框和GT框的相对移动距离# match wh / prior whg_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] #先得出match的长宽,再除框的长宽g_wh = torch.log(g_wh) / variances[1]# return target for smooth_l1_lossreturn torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 得出priors自身和对应的GT框的变化

3.寻找正样本和负样本

(1)正样本:

把conf_t矩阵中大于0的值置1,等于0的值置零。即把背景框置0,非背景框置1.把结果放到pos矩阵中:

得到 pos矩阵后,就知道了default box中哪些框是属于背景,哪些框不是背景。知道那些不是背景的框后,就把这些框的回归值拿出来作为新的loc_t矩阵。

设 SSD网络输出的回归值名为 loc_data矩阵,loc_data矩阵记录着8732个预测框的偏置值。

代码:

pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) #在最后加一维度batchsize,num_priors,4]
loc_p = loc_data[pos_idx].view(-1, 4) #loc_p 是网络预测出来的
loc_t = loc_t[pos_idx].view(-1, 4) #loc_t 是overlap>阈值的 prioris和对应GT框偏置值
loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) #预测偏置值

其中 新的loc_t矩阵中的框就是 正样本的框。

(2)负样本

得到正样本后,我们就开始寻找负样本了。

首先,计算 default box 中所有框的分类损失,得到 loss_c矩阵:

然后,过滤掉 loss_c为正样本的框,留下负样本的框:

代码:

batch_conf = conf_data.view(-1, self.num_classes)
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
loss_c = loss_c.view(num, -1) #一个batchsize每张图的 loss_c
loss_c[pos] = 0 # 剩下分类为背景的框的损失值

由上图得到loss_c矩阵后,就知道了default box中属于背景的框的损失值,然后给loss_c的损失值进行排名,例如1:0.25,2:0.51,则表示第一个框属于背景的损失是0.25,第二个框属于背景的损失是0.51,那么进行排名后:1:2,2:1表示第一个框的损失是在整个矩阵中排第二的,第二个框则是排第一,因此就得到了 loss_rank矩阵,最后得到记录负样本的neg矩阵,如下图:

代码:

_, loss_idx = loss_c.sort(1, descending=True)[batch,num_priors]
num_pos = pos.long().sum(1, keepdim=True)
num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
neg = idx_rank

4.计算分类损失

得到正样本和负样本后,我们要怎么运用起来呢?

首先我们要知道SSD网络输出的分类预测是矩阵conf_data如下:

conf_data记录的是每个default box对应的 num_class个分类的预测值。num_class个值中最大的那个就作为此框的类别。

由于 pos矩阵和neg矩阵中的值都是1或者0。因此pos矩阵相当于记录了default box中正样本框的序号,同理neg矩阵记录的就是负样本框的序号。

我们要明白一件事就是 pos矩阵中为1的肯定是正样本,但并非说为0的就一定是负样本。neg矩阵中为1的肯定是负样本,也并非说为0的就一定是正样本,因为由于Hard Negative Mining的关系,负样本的数目被设置成了是正样本的三倍,所以这就导致了并非所有pos中为0的框都是负样本。

然后使用一个技巧,所有default box中 pos 和 neg中的值加起来大于0的框,才作为真正的训练样本 conf_p。意思就是conf_p中的框不是正样本就是负样本。意在去除pos=0且neg=0的框,这类框的特点是,与GroundTrue不怎么沾边,被分类成背景框,且作为背景框的损失值还很小,就是说这类框必是背景框了。要作为训练数据的框,要么是与GroundTrue交叠比较大的正样本,要么是损失值大的背景框,就是说它看上去是背景,但好像又不是背景的框。

最后把 conf_t和targets_weighted输入到交叉熵损失函数就可以得到最终的分类损失值。

得到回归损失值和分类损失值后,就加起来就可以计算出最终的 multibox 损失函数值了。

代码:

pos_idx = pos.unsqueeze(2).expand_as(conf_data) #[batchsize,num_priors,num_class]为1的是大于阈值的框neg_idx = neg.unsqueeze(2).expand_as(conf_data) #[batchsize,num_priors,num_class]为1的是负样本conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)#pos=0是小于阈值的框,neg=0是损失很小的框#pos_idx+neg_idx大于0的数据保留,其余舍去targets_weighted = conf_t[(pos+neg).gt(0)]loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)

multibox loss完整代码:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from data import coco as cfg
from ..box_utils import match, log_sum_expclass MultiBoxLoss(nn.Module):"""SSD Weighted Loss FunctionCompute Targets:1) Produce Confidence Target Indices by matching ground truth boxeswith (default) 'priorboxes' that have jaccard index > threshold parameter(default threshold: 0.5).2) Produce localization target by 'encoding' variance into offsets of groundtruth boxes and their matched 'priorboxes'.3) Hard negative mining to filter the excessive number of negative examplesthat comes with using a large number of default bounding boxes.(default negative:positive ratio 3:1)Objective Loss:L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / NWhere, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Lossweighted by α which is set to 1 by cross val.Args:c: class confidences,l: predicted boxes,g: ground truth boxesN: number of matched default boxesSee: https://arxiv.org/pdf/1512.02325.pdf for more details."""def __init__(self, num_classes, overlap_thresh, prior_for_matching,bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,use_gpu=True):super(MultiBoxLoss, self).__init__()self.use_gpu = use_gpuself.num_classes = num_classesself.threshold = overlap_threshself.background_label = bkg_labelself.encode_target = encode_targetself.use_prior_for_matching = prior_for_matchingself.do_neg_mining = neg_miningself.negpos_ratio = neg_posself.neg_overlap = neg_overlapself.variance = cfg['variance']def forward(self, predictions, targets):''' output = ( --->predictionsloc.view(loc.size(0), -1, 4),conf.view(conf.size(0), -1, self.num_classes),self.priors'''"""Multibox LossArgs:predictions (tuple): A tuple containing loc preds, conf preds,and prior boxes from SSD net.conf shape: torch.size(batch_size,num_priors,num_classes)loc shape: torch.size(batch_size,num_priors,4)priors shape: torch.size(num_priors,4)targets (tensor): Ground truth boxes and labels for a batch,shape: [batch_size,num_objs,5] (last idx is the label)."""loc_data, conf_data, priors = predictions #loc是ssd生成的偏置值,priors是PriorBox方法画出来的boxnum = loc_data.size(0) #num = batch_size loc_data=[batchsize,8732,4] priors=[8732xbatchsize,4]priors = priors[:loc_data.size(1), :] #[num_priors,4] 逐个batchsize处理 [8732,4]num_priors = (priors.size(0)) #num_proirs,一个batchsize所有框num_classes = self.num_classes# match priors (default boxes) and ground truth boxesloc_t = torch.Tensor(num, num_priors, 4) #num=batchsizeconf_t = torch.LongTensor(num, num_priors)for idx in range(num): #num = batch_size 一张张图片拿出来truths = targets[idx][:, :-1].data #target 有5个整数,前四个是坐标,框长宽,最后一个是类别labels = targets[idx][:, -1].data #[num_objs,1]defaults = priors.data #[num_priors,4]match(self.threshold, truths, defaults, self.variance, labels, #truths:[num_objs,4]loc_t, conf_t, idx)if self.use_gpu:loc_t = loc_t.cuda()conf_t = conf_t.cuda() #conf_t维度原为[num_priors],通过阈值处理后,变成[ 0 #pos为 conf_t大于0的部分,pos维度[batchsize,num_priors] [0,1,1,1,0,1.....]num_pos = pos.sum(dim=1, keepdim=True) #不因jard被舍去的框的个数# Localization Loss (Smooth L1)# Shape: [batch,num_priors,4]pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) #在最后加一维度batchsize,num_priors,4]#下面两步,1是从预测中舍去不通过jard的框,2是从 match中舍去不通过jard的框loc_p = loc_data[pos_idx].view(-1, 4) #loc_p 是网络预测出来的loc_t = loc_t[pos_idx].view(-1, 4) #loc_t 是overlap>阈值的 prioris和对应GT框偏置值loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) #预测偏置值#loc_t 相当于 真正训练的label,loc_p则相当于真正训练的预测输入# Compute max conf across batch for hard negative miningbatch_conf = conf_data.view(-1, self.num_classes) #[batch_size x num_priors,num_classes]batch_conf包含全部框#conf_t 是每个num_priors 框最符合的类别label 维度[num_priors]#con_t是不少于阈值的每个prior对应的类别(最大overlay重叠的GT框的类别)#loss_c是每个batchsize中每个priors框的类别 和 预测出的类别的损失loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) #conf_t=(batchsize, num_priors)#上式中conf_t --->[batchsize X num_priors,1]#后一项是选出大于阈值的框 的分类#loss_c 是各个框分类的损失值# Hard Negative Miningloss_c = loss_c.view(num, -1) #一个batchsize每张图的 loss_closs_c[pos] = 0 # filter out pos boxes for now 剩下分类为背景的框的损失值_, loss_idx = loss_c.sort(1, descending=True) #整个batchsize的loss_c排序_, idx_rank = loss_idx.sort(1)#各个框loss_c(分类损失)的排名,从大到小 [batch,num_priors]num_pos = pos.long().sum(1, keepdim=True) # #不因jard被舍去的框的个数 #[batch,1]num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) ##[batch,1]neg = idx_rank

 


推荐阅读
  • 用Vue实现的Demo商品管理效果图及实现代码
    本文介绍了一个使用Vue实现的Demo商品管理的效果图及实现代码。 ... [详细]
  • vue使用
    关键词: ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 本文介绍了使用kotlin实现动画效果的方法,包括上下移动、放大缩小、旋转等功能。通过代码示例演示了如何使用ObjectAnimator和AnimatorSet来实现动画效果,并提供了实现抖动效果的代码。同时还介绍了如何使用translationY和translationX来实现上下和左右移动的效果。最后还提供了一个anim_small.xml文件的代码示例,可以用来实现放大缩小的效果。 ... [详细]
  • 基于layUI的图片上传前预览功能的2种实现方式
    本文介绍了基于layUI的图片上传前预览功能的两种实现方式:一种是使用blob+FileReader,另一种是使用layUI自带的参数。通过选择文件后点击文件名,在页面中间弹窗内预览图片。其中,layUI自带的参数实现了图片预览功能。该功能依赖于layUI的上传模块,并使用了blob和FileReader来读取本地文件并获取图像的base64编码。点击文件名时会执行See()函数。摘要长度为169字。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 目录实现效果:实现环境实现方法一:基本思路主要代码JavaScript代码总结方法二主要代码总结方法三基本思路主要代码JavaScriptHTML总结实 ... [详细]
  • 如何使用Java获取服务器硬件信息和磁盘负载率
    本文介绍了使用Java编程语言获取服务器硬件信息和磁盘负载率的方法。首先在远程服务器上搭建一个支持服务端语言的HTTP服务,并获取服务器的磁盘信息,并将结果输出。然后在本地使用JS编写一个AJAX脚本,远程请求服务端的程序,得到结果并展示给用户。其中还介绍了如何提取硬盘序列号的方法。 ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • XML介绍与使用的概述及标签规则
    本文介绍了XML的基本概念和用途,包括XML的可扩展性和标签的自定义特性。同时还详细解释了XML标签的规则,包括标签的尖括号和合法标识符的组成,标签必须成对出现的原则以及特殊标签的使用方法。通过本文的阅读,读者可以对XML的基本知识有一个全面的了解。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 本文讨论了如何在codeigniter中识别来自angularjs的请求,并提供了两种方法的代码示例。作者尝试了$this->input->is_ajax_request()和自定义函数is_ajax(),但都没有成功。最后,作者展示了一个ajax请求的示例代码。 ... [详细]
  • 移动端常用单位——rem的使用方法和注意事项
    本文介绍了移动端常用的单位rem的使用方法和注意事项,包括px、%、em、vw、vh等其他常用单位的比较。同时还介绍了如何通过JS获取视口宽度并动态调整rem的值,以适应不同设备的屏幕大小。此外,还提到了rem目前在移动端的主流地位。 ... [详细]
  • 如何在HTML中获取鼠标的当前位置
    本文介绍了在HTML中获取鼠标当前位置的三种方法,分别是相对于屏幕的位置、相对于窗口的位置以及考虑了页面滚动因素的位置。通过这些方法可以准确获取鼠标的坐标信息。 ... [详细]
  • express工程中的json调用方法
    本文介绍了在express工程中如何调用json数据,包括建立app.js文件、创建数据接口以及获取全部数据和typeid为1的数据的方法。 ... [详细]
author-avatar
王德筠宛吟玟华
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有