热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

CornerNet代码解析——损失函数

CornerNet代码解析——损失函数文章目录CornerNet代码解析——损失函数前言总体损失1、Heatmap的损失2、Embedding的损失3、Offset的损失前言今天

CornerNet代码解析——损失函数


文章目录

  • CornerNet代码解析——损失函数
    • 前言
    • 总体损失
      • 1、Heatmap的损失
      • 2、Embedding的损失
      • 3、Offset的损失


前言

今天要解析的是CornerNet的Loss层源码,论文中Loss的解析在这:CornerNet的损失函数原理


总体损失

总体的损失函数如下图所示,三个输出分别对应三部分损失,每部分损失有着对应的权重。接下来分别讲述每一块的损失。

源码中将Loss写成一个类:class AELoss,在CornerNet\models\py_utils\kp.py中.

class AELoss(nn.Module):def __init__(self, pull_weight=1, push_weight=1, regr_weight=1, focal_loss=_neg_loss):super(AELoss, self).__init__()# pull_weight = αself.pull_weight = pull_weight# push_weight = βself.push_weight = push_weight# regr_weight = γself.regr_weight = regr_weight# 这其实就是heatmap的lossself.focal_loss = focal_loss# 这其实就是embedding的lossself.ae_loss = _ae_loss# 这其实就是offset的lossself.regr_loss = _regr_lossdef forward(self, outs, targets):stride = 6# ::跳着选'''首先明确两个输入:outs和targetsouts:这是网络的预测结果,outs是一个列表,列表维度为12,outs[0::stride]这些是表示列表的切片操作,意思是隔stride(6)个跳着选。举个例子outs = [1,2,3,4,5,6,7,8,9,10,11,12],outs[0::6]=[1, 7],其实这12个事6个两两成对,也就是左上角的heatmap有两个,右下角的heatmap有两个左上角的embedding有两个,右下角的embedding有两个,左上角的offset有两个,右下角的offset有两个,共12个,为什么要两份?应该跟上面的nstack有关,上述的nstack=2,所以循环出来outs不是6,而是12,映射到论文就是跟这句话:we also add intermediate supervision in training。这是中继监督,具体是啥我也还在看。也就是说下面的6个都是列表,每个列表里面都含有两个tensor,具体维度如下:'''# 两个都是[batch_size, 类别数, 128, 128]tl_heats = outs[0::stride]# 两个都是[batch_size, 类别数, 128, 128]br_heats = outs[1::stride]# 两个都是[batch_size, 128, 1]tl_tags = outs[2::stride]# 两个都是[batch_size, 128, 1]br_tags = outs[3::stride]# 两个都是[batch_size, 128, 2]tl_regrs = outs[4::stride]# 两个都是[batch_size, 128, 2]br_regrs = outs[5::stride]'''targets是gt,标准答案,也是个列表,但就只有下面5个,没有两份具体维度如下'''# [batch_size, 类别数, 128, 128]gt_tl_heat = targets[0]# [batch_size, 类别数, 128, 128]gt_br_heat = targets[1]# [3, 128]gt_mask = targets[2]# [3, 128, 2]gt_tl_regr = targets[3]# [3, 128, 2]gt_br_regr = targets[4]

上述就是传入的预测值和真实值,Loss也就是计算预测的和真实之间的误差,当Loss值越小,那么说明网络预测的结果越好。接下去有了预测和真实值,具体分析三个部分的Loss。


1、Heatmap的损失

Heatmap损失的理论理解在这,接下来是源码理解:

这部分代码在CornerNet\models\py_utils\kp.py中

# focal lossfocal_loss = 0# 到这里将heatmap经过sigmoid,将值映射到0-1之间,变成keypoint的响应值,还是列表,# 维度还是[batch_size, 类别数, 128, 128]tl_heats = [_sigmoid(t) for t in tl_heats]br_heats = [_sigmoid(b) for b in br_heats]# 在CornerNet\models\py_utils\kp_utils.py中详细讲述了focal_loss,这个focal loss就是_neg_loss,形参有体现focal_loss += self.focal_loss(tl_heats, gt_tl_heat)focal_loss += self.focal_loss(br_heats, gt_br_heat)

接着去到CornerNet\models\py_utils\kp_utils.py中详细讲述focal_loss:

'''
首先清楚函数的输入:
preds是列表:(2,),表示一个列表中含两个tensor,每个tensor的维度是(batch_size, 类别数, 128, 128)
gt是tensor:(batch_size, 类别数, 128, 128)
'''

def _neg_loss(preds, gt):# pos_inds是0、1tensor,维度[3,7,128,128]。# eq函数是遍历gt这个tensor每个element,和1比较,如果等于1,则返回1,否则返回0pos_inds = gt.eq(1)# otherwise则是表明ycij第c个通道的(i,j)坐标上值不为1# 遍历gt这个tensor每个element,和1比较,如果小于1,则返回1,否则返回0neg_inds = gt.lt(1)# 总结下上面两个变量:上面这两个0-1位置互补# 回头看这两个变量,再结合公式1,公式1后面有两个判断条件:if ycij=1 and otherwise# 这里就是那两个判断条件,ycij=1表示第c个通道的(i,j)坐标上值为1,也即是gt中这个位置有目标# 也就是pos_inds是ycij=1,neg_inds是otherwise# torch.pow是次幂函数,其中gt[neg_inds]表示取出neg_inds中值为1的gt的值# 所以gt[neg_inds]就变成一个向量了,那么维度就等于neg_inds中有多少为1的# 可以neg_inds.sum()看看,1 - gt[neg_inds]就是单纯的用1减去每个element,# 然后每个element开4次方,就成了neg_weights,这个neg_weights是一维向量# 把gt中每个小于1的数字取出来,然后用1减去,在开方,那不是更小了,# 就是原来就很小,现在又降权。# gt[neg_inds]就是公式(1)中的Ycij# neg_weights就是公式(1)中的(1-ycij)^β,β就是4neg_weights = torch.pow(1 - gt[neg_inds], 4)loss = 0# 循环2次,因为preds是一个列表,有2部分,每部分放着一个tensor,每个tensor的# 维度为[batch_size,类别数,128,128],也就是pred维度为[batch_size,类别数,128,128]for pred in preds:# 首先记住pos_inds中的1就是gt中有目标的地方,neg_inds中的1是gt中没有目标的地方# 将gt认为有目标的地方,pred也按这个地方取出数值,变成向量,pos_inds有多少个1,# pos_pred就多少维(一行向量)pos_pred = pred[pos_inds]# 将gt认为没有目标的地方,pred也按这个地方取出数值,变成向量,neg_inds有多少个1,# neg_pred就多少维(一行向量)neg_pred = pred[neg_inds]# 以上出现的pos_xxx, neg_xxx,命名的意思就是正样本positive和负样本negative# 这里对应的是论文中的公式(1),也就是heatmap的loss# 可以先根据公式把相应的变量确认下:pos_pred就是公式中的Pcij。# neg_pred就是公式中的要经过二维高斯的Pcij,neg_weights就是(1-ycij)^βpos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2)neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights# gt的那个tensor中,值为1的个数,num_pos对应公式(1)中的Nnum_pos = pos_inds.float().sum()# 累加pos_loss = pos_loss.sum()neg_loss = neg_loss.sum()# pos_pred是一维的。统计pos_pred中的元素个数,单纯的数个数而已,# 就算pos_pred中值为0的,也算一个if pos_pred.nelement() == 0:loss = loss - neg_losselse:# 用减号体现公式(1)中的-1loss = loss - (pos_loss + neg_loss) / num_pos# 返回最终的heatmap的lossreturn loss

2、Embedding的损失

Heatmap损失的理论理解在这,接下来是源码理解:

接着回到CornerNet\models\py_utils\kp.py,看怎么调用embedding的loss:

# tag loss# 初始化为0pull_loss = 0push_loss = 0# tl_tags、br_tags是列表,里面有两个tensor,每个tensor的维度为[batch_size, 128, 1]# 论文中说到的embedding是一维向量。也就是说,维度表示:一个batch_size一张图,用128*1的矩阵表示??# 那么这个for循环,循环2次,每次进去的是[batch_size, 128, 1]的tl_tag, br_tagfor tl_tag, br_tag in zip(tl_tags, br_tags):pull, push = self.ae_loss(tl_tag, br_tag, gt_mask)pull_loss += pullpush_loss += push# 算出来的loss乘以相应的权重pull_loss = self.pull_weight * pull_losspush_loss = self.push_weight * push_loss

接着去到CornerNet\models\py_utils\kp_utils.py中详细讲述ae_loss:

'''
embedding的损失
输入:tag0、tag1为左上右下各一个[batch_size, 128, 1]的tensor,再来一个gt中的mask,这个mask是
0、1矩阵,维度[batch_size, 128],也就是一张图用128维来表示??????
'''

def _ae_loss(tag0, tag1, mask):# mask是[batch_size, 128],这个就是第一维全部相加(sum),就是把每个batch的128个数字相加,所以num的# 维度是[batch_size, 1],1是128个数字的值相加变成一个数字,而mask还是0-1矩阵,所以这个num代表了# 每张图有多少个1.这个num代表公式(4)和(5)中的Nnum = mask.sum(dim=1, keepdim=True).float()# 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度# 所以tag0和tag1的维度变成了[batch_size, 128],和mask一样# tag0就是公式(4)中的etktag0 = tag0.squeeze()# tag0就是公式(4)中的ebktag1 = tag1.squeeze()# 单纯的求平均而已,这个tag_mean对应公式(4)和(5)中的ek,维度不变tag_mean = (tag0 + tag1) / 2# 这里能够体现是同类别的,因为累加只有一次,也就是Lpull用来缩小# 同类别左上右下角点的embedding vector的距离# 公式(4)前半段tag0 = torch.pow(tag0 - tag_mean, 2) / (num + 1e-4)# 这句能体现累加,这里tag0已经是单个数字tag0 = tag0[mask].sum()# 公式(4)后半段tag1 = torch.pow(tag1 - tag_mean, 2) / (num + 1e-4)# 这句能体现累加,这里tag1已经是单个数字tag1 = tag1[mask].sum()# 总的Lpullpull = tag0 + tag1# Lpush# 这里能够体现是不同类别的,因为累加有两次,公式(5)中的j不等于k,也就是Lpush用来扩大# 不同类别左上右下角点的embedding vector的距离# 这时候mask的维度由[3,128]-->[3,128,128]mask = mask.unsqueeze(1) + mask.unsqueeze(2)# 遍历mask这个tensor每个element,和2比较,如果等于2,则返回1,否则返回0,但为啥是2呢?mask = mask.eq(2)# num的维度[3, 1]-->[3, 1, 1]num = num.unsqueeze(2)# num2的维度[3, 1, 1],num2表示公式(5)中的N(N-1)num2 = (num - 1) * num# dist是公式(5)中绝对值之间的运算# dist维度[3, 128, 128]=[3, 1, 128]-[3, 128, 1]dist = tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2)# 1表示公式(5)三角形dist = 1 - torch.abs(dist)# 公式(5)就是relu,所以计算方式直接套reludist = nn.functional.relu(dist, inplace=True)dist = dist - 1 / (num + 1e-4)dist = dist / (num2 + 1e-4)# 这时候mask的维度[3,128,128],dist维度[3,128,128]dist = dist[mask]# sum之后就变成一个数字了push = dist.sum()# 返回两个loss,两个tensor的数字return pull, push

3、Offset的损失

Heatmap损失的理论理解在这,接下来是源码理解:

接着回到CornerNet\models\py_utils\kp.py,看怎么调用offset的loss:

# offsets lossregr_loss = 0# tl_regrs、br_regrs是列表,里面有两个tensor,每个tensor的维度为[batch_size, 128, 2]# 维度表示:一个batch_size一张图,用128*2的矩阵表示??# 那么这个for循环,循环2次,每次进去的是[batch_size, 128, 2]的tl_regr, br_regrfor tl_regr, br_regr in zip(tl_regrs, br_regrs):regr_loss += self.regr_loss(tl_regr, gt_tl_regr, gt_mask)regr_loss += self.regr_loss(br_regr, gt_br_regr, gt_mask)regr_loss = self.regr_weight * regr_loss# 总的lossloss = (focal_loss + pull_loss + push_loss + regr_loss) / len(tl_heats)# unsqueeze(i) 表示将第i维设置为1维return loss.unsqueeze(0)

接着去到CornerNet\models\py_utils\kp_utils.py中详细讲述regr_loss:

'''
输入:regr偏移量,维度[batch_size, 128, 2],gt_regr维度[batch_size, 128, 2]
mask维度[batch_size, 128]
'''

def _regr_loss(regr, gt_regr, mask):# 公式(3)的Nnum = mask.float().sum()# mask.unsqueeze(2)维度[batch_size, 128, 1]# mask的维度[batch_size, 128, 2]mask = mask.unsqueeze(2).expand_as(gt_regr)# 取出mask中1对应的位置,然后在预测的偏移量和真实的偏移量中取出这些位置的值# 此时二者的维度变为一维向量regr = regr[mask]gt_regr = gt_regr[mask]# 直接调用自带的SmoothL1Lossregr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False)# 最后除Nregr_loss = regr_loss / (num + 1e-4)return regr_loss

推荐阅读
  • 本文讨论了在Windows 8上安装gvim中插件时出现的错误加载问题。作者将EasyMotion插件放在了正确的位置,但加载时却出现了错误。作者提供了下载链接和之前放置插件的位置,并列出了出现的错误信息。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 本文介绍了解决java开源项目apache commons email简单使用报错的方法,包括使用正确的JAR包和正确的代码配置,以及相关参数的设置。详细介绍了如何使用apache commons email发送邮件。 ... [详细]
  • 现在比较流行使用静态网站生成器来搭建网站,博客产品着陆页微信转发页面等。但每次都需要对服务器进行配置,也是一个重复但繁琐的工作。使用DockerWeb,只需5分钟就能搭建一个基于D ... [详细]
  • pc电脑如何投屏到电视?DLNA主要步骤通过DLNA连接,使用WindowsMediaPlayer的流媒体播放举例:电脑和电视机都是连接的 ... [详细]
  • 微信公众号:内核小王子关注可了解更多关于数据库,JVM内核相关的知识;如果你有任何疑问也可以加我pigpdong[^1]jvm一行代码是怎么运行的首先,java代码会被编译成字 ... [详细]
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • 本文介绍了设计师伊振华受邀参与沈阳市智慧城市运行管理中心项目的整体设计,并以数字赋能和创新驱动高质量发展的理念,建设了集成、智慧、高效的一体化城市综合管理平台,促进了城市的数字化转型。该中心被称为当代城市的智能心脏,为沈阳市的智慧城市建设做出了重要贡献。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 本文讨论了一个关于正则的困惑,即为什么一个函数会获取parent下所有的节点。同时提出了问题是否是正则表达式写错了。 ... [详细]
  • 精讲代理设计模式
    代理设计模式为其他对象提供一种代理以控制对这个对象的访问。代理模式实现原理代理模式主要包含三个角色,即抽象主题角色(Subject)、委托类角色(被代理角色ÿ ... [详细]
  • docker安装到基本使用
    记录docker概念,安装及入门日常使用Docker安装查看官方文档,在"Debian上安装Docker",其他平台在"这里查 ... [详细]
author-avatar
loloyoyo555
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有