热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【损失函数合集】ECCV2016CenterLoss

前言感觉一年下来了解了很多损失函数,准备做一个【损失函数合集】不定期更新一些损失函数的解读,相信这个系列会成为工程派非常感兴趣的一个专题,
前言

感觉一年下来了解了很多损失函数,准备做一个【损失函数合集】不定期更新一些损失函数的解读,相信这个系列会成为工程派非常感兴趣的一个专题,毕竟这可能是工程量最小同时可能会在某个任务上取得较大提升的最简单有效的方法了。话不多说,今天来介绍一下ECCV 2016的Center Loss。论文链接见附录。

介绍

我们知道在利用卷积神经网络(CNN)进行特征提取然后做分类任务的时候,最常用的就是Softmax Loss了。而Softmax Loss对CNN的分类特征没有什么强约束,只要最后可以把样本分开就可以了,没有明确的定义某个分类特征xxx和它的类别yyy需要有什么关系,也即是对xxx的约束并不强。当给CNN传入一些对抗样本的时候,CNN会表现出分类困难。原因就是CNN对对抗样本提取的特征基本处于最后一层分类特征的边界上,使得类别的区分度低。论文在Figure2为我们展示了Mnist数据集的特征分布图(这是降维可视化之后的),如下所示:

在这里插入图片描述
Figure2(a)表示训练集的特征,可以看到不同的类别之间的特征是有明显的间隔的,所以分类器很容易区分样本的类别。而Figure2(b)表示的是测试集的特征,可以看到,有些测试集样本的特征位于分类的边界上,这使得分类器非常难以判断这些样本的具体类别。这种现象用更专业的话来说就是分类器的特征类间相似度小,而类内相似度大。这和聚类里面的思想有一点像,聚类的时候需要尽可能的让同一个簇里面的相似度更大大,而簇之间的相似度更小,这样聚类的结果才更好。

更正式地,论文指出CNN学习到的特征应当具有“类内紧凑性”和“类间可分离性”,意思上面大概理解过了,下面就用这两个名词来描述了。

相关工作

为了对CNN学习到的特征的“类内紧凑性”进行约束,Contrastive Loss和Triple Loss被提出,这两个Loss这几天会接着介绍,今天先介绍Center Loss,不懂这两个个Loss也不太不影响理解Center Loss。简单来说,Contrastive Loss的思想是最小化一对同类样本特征之间的距离,最大化一对不同样本特征之间的距离。而Triplets loss则把一对变成3个。这两个Loss均要把样本进行组合来进行优化,导致样本的数量急剧增加(O(n2),O(n3)O(n^2), O(n^3)O(n2),O(n3)),因此加长了训练时间并且训练难度也大大增加,为了改善这一缺点Center Loss被提出。

Center Loss

Center Loss可以直接对样本特征之间的距离进行约束。Center Loss添加的约束是,特征与相同类别的平均特征(有"Center"的味道了)的距离要足够小,即要求同类特征要更加接近它们的中心点,用公式来表达:

LC=12∑i=1m∣∣xi−cyi∣∣22L_{C}=\frac{1}{2}\sum_{i=1}^m||x_i-c_{y_i}||_2^2LC=21i=1mxicyi22

其中xix_ixi表示第iii个样本被CNN提取的特征,cyic_{y_i}cyi表示第iii类样本的平均特征,mmm表示样本数。从公式可以看到如果我们要计算同一类别所有样本的特征然后求平均值的话是不太可能的,因为数据量多了之后这个计算复杂度非常高。因为论文提出用mini-batch中每个类别的平均特征近似不同类别所有样本的平均特征,有点类似于BN层。

关于LcL_cLc的梯度和cyic_{y_i}cyi的更新公式如下:

在这里插入图片描述其中δ(x)\delta(x)δ(x)函数代表的是当x为真时返回111,否则为000。而分母的111是防止min-batch中没有类别jjj的样本而导致除000异常。然后论文还设置了一个cjc_jcj的更新速率α\alphaα,控制cjc_jcj的更新速度。最后训练的总损失函数为:
在这里插入图片描述最后,Center Loss的算法表示如下:

在这里插入图片描述

实验结果

对于不同的λ\lambdaλ,网络提取的特征具有不同的区分度,具体如Figure3所示,可以看到随着λ\lambdaλ的增加,特征的“类内紧凑性”越高。

在这里插入图片描述
同时作者还研究了不同的λ\lambdaλα\alphaα对人脸识别任务的影响,如Figure5所示,可以看出当λ=0.003\lambda=0.003λ=0.003α=0.5\alpha=0.5α=0.5的时候对人脸识别是效果最好的。

在这里插入图片描述

代码实现

import torch
import torch.nn as nn
from torch.autograd.function import Functionclass CenterLoss(nn.Module):def __init__(self, num_classes, feat_dim, size_average=True):super(CenterLoss, self).__init__()self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))self.centerlossfunc = CenterlossFunc.applyself.feat_dim = feat_dimself.size_average = size_averagedef forward(self, label, feat):batch_size = feat.size(0)feat = feat.view(batch_size, -1)# To check the dim of centers and featuresif feat.size(1) != self.feat_dim:raise ValueError("Center's dim: {0} should be equal to input feature's \dim: {1}".format(self.feat_dim,feat.size(1)))batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1)loss = self.centerlossfunc(feat, label, self.centers, batch_size_tensor)return lossclass CenterlossFunc(Function):@staticmethoddef forward(ctx, feature, label, centers, batch_size):ctx.save_for_backward(feature, label, centers, batch_size)centers_batch = centers.index_select(0, label.long())return (feature - centers_batch).pow(2).sum() / 2.0 / batch_size@staticmethoddef backward(ctx, grad_output):feature, label, centers, batch_size = ctx.saved_tensorscenters_batch = centers.index_select(0, label.long())diff = centers_batch - feature# init every iterationcounts = centers.new_ones(centers.size(0))ones = centers.new_ones(label.size(0))grad_centers = centers.new_zeros(centers.size())counts = counts.scatter_add_(0, label.long(), ones)grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)grad_centers = grad_centers/counts.view(-1, 1)return - grad_output * diff / batch_size, None, grad_centers / batch_size, Nonedef main(test_cuda=False):print('-'*80)device = torch.device("cuda" if test_cuda else "cpu")ct = CenterLoss(10,2,size_average=True).to(device)y = torch.Tensor([0,0,2,1]).to(device)feat = torch.zeros(4,2).to(device).requires_grad_()print (list(ct.parameters()))print (ct.centers.grad)out = ct(y,feat)print(out.item())out.backward()print(ct.centers.grad)print(feat.grad)if __name__ == '__main__':torch.manual_seed(999)main(test_cuda=False)if torch.cuda.is_available():main(test_cuda=True)

结论

这篇推文给大家介绍了Center Loss,可以改善分类的时候的“类内紧凑性”小的问题(这同时也放大了“类间可分离性”),是一个值得工程尝试的损失函数。

附录
  • 论文原文:http://ydwen.github.io/papers/WenECCV16.pdf
  • 代码实现:https://github.com/jxgu1016/MNIST_center_loss_pytorch/blob/master/CenterLoss.py

推荐阅读
  • 目标检测算法之RetinaNet(引入Focal Loss)
  • 目标检测算法之AAAI2019 Oral论文GHM Loss
  • 目标检测算法之CVPR2019 GIoU Loss
  • 目标检测算法之AAAI 2020 DIoU Loss 已开源(YOLOV3涨近3个点)


欢迎关注GiantPandaCV, 在这里你将看到独家的深度学习分享,坚持原创,每天分享我们学习到的新鲜知识。( • ̀ω•́ )✧

有对文章相关的问题,或者想要加入交流群,欢迎添加BBuf微信:

在这里插入图片描述


推荐阅读
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 本文将深入探讨生成对抗网络(GAN)在计算机视觉领域的应用。作为该领域的经典模型,GAN通过生成器和判别器的对抗训练,能够高效地生成高质量的图像。本文不仅回顾了GAN的基本原理,还将介绍一些最新的进展和技术优化方法,帮助读者全面掌握这一重要工具。 ... [详细]
  • php更新数据库字段的函数是,php更新数据库字段的函数是 ... [详细]
  • 在《Cocos2d-x学习笔记:基础概念解析与内存管理机制深入探讨》中,详细介绍了Cocos2d-x的基础概念,并深入分析了其内存管理机制。特别是针对Boost库引入的智能指针管理方法进行了详细的讲解,例如在处理鱼的运动过程中,可以通过编写自定义函数来动态计算角度变化,利用CallFunc回调机制实现高效的游戏逻辑控制。此外,文章还探讨了如何通过智能指针优化资源管理和避免内存泄漏,为开发者提供了实用的编程技巧和最佳实践。 ... [详细]
  • PTArchiver工作原理详解与应用分析
    PTArchiver工作原理及其应用分析本文详细解析了PTArchiver的工作机制,探讨了其在数据归档和管理中的应用。PTArchiver通过高效的压缩算法和灵活的存储策略,实现了对大规模数据的高效管理和长期保存。文章还介绍了其在企业级数据备份、历史数据迁移等场景中的实际应用案例,为用户提供了实用的操作建议和技术支持。 ... [详细]
  • 技术日志:使用 Ruby 爬虫抓取拉勾网职位数据并生成词云分析报告
    技术日志:使用 Ruby 爬虫抓取拉勾网职位数据并生成词云分析报告 ... [详细]
  • 从2019年AI顶级会议最佳论文,探索深度学习的理论根基与前沿进展 ... [详细]
  • 浅层神经网络解析:本文详细探讨了两层神经网络(即一个输入层、一个隐藏层和一个输出层)的结构与工作原理。通过吴恩达教授的课程,读者将深入了解浅层神经网络的基本概念、参数初始化方法以及前向传播和反向传播的具体实现步骤。此外,文章还介绍了如何利用这些基础知识解决实际问题,并提供了丰富的实例和代码示例。 ... [详细]
  • 深度森林算法解析:特征选择与确定能力分析
    本文深入探讨了深度森林算法在特征选择与确定方面的能力。提出了一种名为EncoderForest(简称eForest)的创新方法,作为首个基于决策树的编码器模型,它在处理高维数据时展现出卓越的性能,为特征选择提供了新的视角和工具。 ... [详细]
  • 魅族Flyme 7正式发布:全面解析与亮点介绍
    在22日晚的发布会上,魅族不仅推出了m15、15和15 Plus三款新机型,还正式发布了全新的Flyme 7系统。Flyme 7在保持流畅体验的基础上,进一步增强了功能性和实用性,为用户带来更加丰富的使用体验。首批适配包已准备就绪,将逐步推送给现有设备。 ... [详细]
  • 通过使用 `pandas` 库中的 `scatter_matrix` 函数,可以有效地绘制出多个特征之间的两两关系。该函数不仅能够生成散点图矩阵,还能通过参数如 `frame`、`alpha`、`c`、`figsize` 和 `ax` 等进行自定义设置,以满足不同的可视化需求。此外,`diagonal` 参数允许用户选择对角线上的图表类型,例如直方图或密度图,从而提供更多的数据洞察。 ... [详细]
  • 本文详细介绍了批处理技术的基本概念及其在实际应用中的重要性。首先,对简单的批处理内部命令进行了概述,重点讲解了Echo命令的功能,包括如何打开或关闭回显功能以及显示消息。如果没有指定任何参数,Echo命令会显示当前的回显设置。此外,文章还探讨了批处理技术在自动化任务执行、系统管理等领域的广泛应用,为读者提供了丰富的实践案例和技术指导。 ... [详细]
  • Vue + WangEditor 遇到 “无法读取未定义的属性 'menus'” 错误的解决方案 ... [详细]
  • 如何在C#中配置组合框的背景颜色? ... [详细]
  • 在 Vue 应用开发中,页面状态管理和跨页面数据传递是常见需求。本文将详细介绍 Vue Router 提供的两种有效方式,帮助开发者高效地实现页面间的数据交互与状态同步,同时分享一些最佳实践和注意事项。 ... [详细]
author-avatar
boy31455349
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有