multibox loss 是SSD的损失函数
跟交叉熵那些官方 给出接口的损失函数不同。multibox loss需要自己定义的。
建议:边看代码边看此博客。
代码来源:https://github.com/amdegroot/ssd.pytorch
multibox loss的流程:ssd网络输出有三个,其中两个是预测值,一个是所有default box的集合。预测值一个是回归值(框的位置信息),一个是分类值(框的类别信息)。
match过程主要做这几件事:
(1)把所有default box与GroundTrue框做一个杰卡德相似计算(与IOU相似),得到一个矩阵 [GroundTrue数,default box数] ,里面的值就是它们对应的杰卡德相似值。
(2)分别计算各个GroundTrue对应的default box,各个default box对应的GroundTrue框。得到[GroundTrue数,1]和[1,default box数]。得到四个矩阵如下:
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) best_truth_idx.squeeze_(0) #[1,num_priors]best_truth_overlap.squeeze_(0) #[1,num_priors]best_prior_idx.squeeze_(1) #[num_objs,1]best_prior_overlap.squeeze_(1) #[num_objs,1]best_truth_overlap.index_fill_(0, best_prior_idx, 2) # 保证最best的default box的杰卡德系数不因为太小而被下面的代码过滤掉
(3)获得default box对应的GroundTrue的回归值和分类值
上面(2)得到了best_truth_idx矩阵,因为每个default box知道了自己对应的是哪一个GroundTrue框,因此我们可以让default box知道自己对应的GroundTrue框的回归值和分类值:
(4)过滤杰卡德相似系数少于阈值(=0.5)的default box,并把它的类别设置成0,即设置成背景类。(处理conf矩阵),得到的矩阵命名为 conf_t
(5)encode过程
encode过程是算出 default box与其对应的GroundTrue的回归偏置值。偏置值包括中心坐标偏置和边长偏置。俗话说就是得到default box与其对应的GroundTrue的变化值,default box是怎么变才能变换到对应的GroundTrue上去。当然有encode就有decode,decode是把偏置值加上其对应的default box 得到预测的框。但训练过程是不需要decode的,测试或者要看预测框的时候才需要用到decode。encode处理后的矩阵命名为 loc_t
所以综上所述,match过程就是返回 loc_t 和 conf_t 矩阵。
代码:
def encode(matched, priors, variances): #获取扩张/平移系数"""Encode the variances from the priorbox layers into the ground truth boxeswe have matched (based on jaccard overlap) with the prior boxes.Args:matched: (tensor) Coords of ground truth for each prior in point-formShape: [num_priors, 4].priors: (tensor) Prior boxes in center-offset formShape: [num_priors,4].variances: (list[float]) Variances of priorboxesReturn:encoded boxes (tensor), Shape: [num_priors, 4]"""
#xmin, ymin, xmax, ymax# dist b/t match center and prior's center 计算match的box和priors的box的中间点的距离g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] #matched的中间点 - priors的中间点-# encode varianceg_cxcy /= (variances[0] * priors[:, 2:]) #除长宽,得出priors的框和GT框的相对移动距离# match wh / prior whg_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] #先得出match的长宽,再除框的长宽g_wh = torch.log(g_wh) / variances[1]# return target for smooth_l1_lossreturn torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 得出priors自身和对应的GT框的变化
(1)正样本:
把conf_t矩阵中大于0的值置1,等于0的值置零。即把背景框置0,非背景框置1.把结果放到pos矩阵中:
得到 pos矩阵后,就知道了default box中哪些框是属于背景,哪些框不是背景。知道那些不是背景的框后,就把这些框的回归值拿出来作为新的loc_t矩阵。
设 SSD网络输出的回归值名为 loc_data矩阵,loc_data矩阵记录着8732个预测框的偏置值。
代码:
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) #在最后加一维度batchsize,num_priors,4]
loc_p = loc_data[pos_idx].view(-1, 4) #loc_p 是网络预测出来的
loc_t = loc_t[pos_idx].view(-1, 4) #loc_t 是overlap>阈值的 prioris和对应GT框偏置值
loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) #预测偏置值
其中 新的loc_t矩阵中的框就是 正样本的框。
(2)负样本
得到正样本后,我们就开始寻找负样本了。
首先,计算 default box 中所有框的分类损失,得到 loss_c矩阵:
然后,过滤掉 loss_c为正样本的框,留下负样本的框:
代码:
batch_conf = conf_data.view(-1, self.num_classes)
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
loss_c = loss_c.view(num, -1) #一个batchsize每张图的 loss_c
loss_c[pos] = 0 # 剩下分类为背景的框的损失值
由上图得到loss_c矩阵后,就知道了default box中属于背景的框的损失值,然后给loss_c的损失值进行排名,例如1:0.25,2:0.51,则表示第一个框属于背景的损失是0.25,第二个框属于背景的损失是0.51,那么进行排名后:1:2,2:1表示第一个框的损失是在整个矩阵中排第二的,第二个框则是排第一,因此就得到了 loss_rank矩阵,最后得到记录负样本的neg矩阵,如下图:
代码:
得到正样本和负样本后,我们要怎么运用起来呢? 首先我们要知道SSD网络输出的分类预测是矩阵conf_data如下: conf_data记录的是每个default box对应的 num_class个分类的预测值。num_class个值中最大的那个就作为此框的类别。 由于 pos矩阵和neg矩阵中的值都是1或者0。因此pos矩阵相当于记录了default box中正样本框的序号,同理neg矩阵记录的就是负样本框的序号。 我们要明白一件事就是 pos矩阵中为1的肯定是正样本,但并非说为0的就一定是负样本。neg矩阵中为1的肯定是负样本,也并非说为0的就一定是正样本,因为由于Hard Negative Mining的关系,负样本的数目被设置成了是正样本的三倍,所以这就导致了并非所有pos中为0的框都是负样本。 然后使用一个技巧,所有default box中 pos 和 neg中的值加起来大于0的框,才作为真正的训练样本 conf_p。意思就是conf_p中的框不是正样本就是负样本。意在去除pos=0且neg=0的框,这类框的特点是,与GroundTrue不怎么沾边,被分类成背景框,且作为背景框的损失值还很小,就是说这类框必是背景框了。要作为训练数据的框,要么是与GroundTrue交叠比较大的正样本,要么是损失值大的背景框,就是说它看上去是背景,但好像又不是背景的框。 最后把 conf_t和targets_weighted输入到交叉熵损失函数就可以得到最终的分类损失值。 得到回归损失值和分类损失值后,就加起来就可以计算出最终的 multibox 损失函数值了。 代码: _, loss_idx = loss_c.sort(1, descending=True)[batch,num_priors]
num_pos = pos.long().sum(1, keepdim=True)
num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
neg = idx_rank 4.计算分类损失
pos_idx = pos.unsqueeze(2).expand_as(conf_data) #[batchsize,num_priors,num_class]为1的是大于阈值的框neg_idx = neg.unsqueeze(2).expand_as(conf_data) #[batchsize,num_priors,num_class]为1的是负样本conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)#pos=0是小于阈值的框,neg=0是损失很小的框#pos_idx+neg_idx大于0的数据保留,其余舍去targets_weighted = conf_t[(pos+neg).gt(0)]loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)
multibox loss完整代码:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from data import coco as cfg
from ..box_utils import match, log_sum_expclass MultiBoxLoss(nn.Module):"""SSD Weighted Loss FunctionCompute Targets:1) Produce Confidence Target Indices by matching ground truth boxeswith (default) 'priorboxes' that have jaccard index > threshold parameter(default threshold: 0.5).2) Produce localization target by 'encoding' variance into offsets of groundtruth boxes and their matched 'priorboxes'.3) Hard negative mining to filter the excessive number of negative examplesthat comes with using a large number of default bounding boxes.(default negative:positive ratio 3:1)Objective Loss:L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / NWhere, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Lossweighted by α which is set to 1 by cross val.Args:c: class confidences,l: predicted boxes,g: ground truth boxesN: number of matched default boxesSee: https://arxiv.org/pdf/1512.02325.pdf for more details."""def __init__(self, num_classes, overlap_thresh, prior_for_matching,bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,use_gpu=True):super(MultiBoxLoss, self).__init__()self.use_gpu = use_gpuself.num_classes = num_classesself.threshold = overlap_threshself.background_label = bkg_labelself.encode_target = encode_targetself.use_prior_for_matching = prior_for_matchingself.do_neg_mining = neg_miningself.negpos_ratio = neg_posself.neg_overlap = neg_overlapself.variance = cfg['variance']def forward(self, predictions, targets):''' output = ( --->predictionsloc.view(loc.size(0), -1, 4),conf.view(conf.size(0), -1, self.num_classes),self.priors'''"""Multibox LossArgs:predictions (tuple): A tuple containing loc preds, conf preds,and prior boxes from SSD net.conf shape: torch.size(batch_size,num_priors,num_classes)loc shape: torch.size(batch_size,num_priors,4)priors shape: torch.size(num_priors,4)targets (tensor): Ground truth boxes and labels for a batch,shape: [batch_size,num_objs,5] (last idx is the label)."""loc_data, conf_data, priors = predictions #loc是ssd生成的偏置值,priors是PriorBox方法画出来的boxnum = loc_data.size(0) #num = batch_size loc_data=[batchsize,8732,4] priors=[8732xbatchsize,4]priors = priors[:loc_data.size(1), :] #[num_priors,4] 逐个batchsize处理 [8732,4]num_priors = (priors.size(0)) #num_proirs,一个batchsize所有框num_classes = self.num_classes# match priors (default boxes) and ground truth boxesloc_t = torch.Tensor(num, num_priors, 4) #num=batchsizeconf_t = torch.LongTensor(num, num_priors)for idx in range(num): #num = batch_size 一张张图片拿出来truths = targets[idx][:, :-1].data #target 有5个整数,前四个是坐标,框长宽,最后一个是类别labels = targets[idx][:, -1].data #[num_objs,1]defaults = priors.data #[num_priors,4]match(self.threshold, truths, defaults, self.variance, labels, #truths:[num_objs,4]loc_t, conf_t, idx)if self.use_gpu:loc_t = loc_t.cuda()conf_t = conf_t.cuda() #conf_t维度原为[num_priors],通过阈值处理后,变成[