热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

mmcls多标签分类实战(二):resnet多标签分类

上一章讲了如何制作数据集,接下来我们使用mmcls来实现多标签分类。

上一章讲了如何制作数据集,接下来我们使用mmcls来实现多标签分类。

Config配置
mmcls是通过config来配置整个网络结构的。如下,我使用的是resnet18,因为数据中有5个属性,所以输出的num_classes=5。需要注意的是,head要选用head=dict(type=‘MultiLabelLinearClsHead’)。这是因为多标签分类,在进入loss前,应该用sigmoid激活,将pred的值归一化。如果使用softmax,会出现属性互斥的现象(因为pred在dim=1上,sum=1)。对于Multi-label问题,应该使用F.binary_cross_entropy_with_logits损失。

model = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=18,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
# type='LinearClsHead',
type='MultiLabelLinearClsHead',
num_classes=5,
in_channels=512,
# loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
# topk=(1, 5),
))

自定义dataset
为了读取数据,并将label转变为loss可以计算的格式,我们需要重新定 def load_annotations(self):为了不增加类,定义了self.multi_label的flag来分离Multi-label与Multi-class。我们在txt中的label是一个num,比如你有5个属性类别,label可能是1,3,而BCE中label需要的格式是[1,0,1],因此我们需要转化一下。

def load_annotations(self):
"""Load image paths and gt_labels."""
if self.ann_file is None:
samples = self._find_samples()
elif isinstance(self.ann_file, str):
lines = mmcv.list_from_file(
self.ann_file, file_client_args=self.file_client_args)
samples = [x.strip().rsplit(' ', 1) for x in lines]
else:
raise TypeError('ann_file must be a str or None')
data_infos = []
for filename, gt_label in samples:
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
temp_label = np.zeros(len(self.CLASSES))

if not self.multi_label:
info['gt_label'] = np.array(gt_label, dtype=np.int64)
else:
##multi-label classify
if len(gt_label) == 1:
temp_label[np.array(gt_label, dtype=np.int64)] = 1
info['gt_label'] = temp_label
else:
for i in range(np.array(gt_label.split(','), dtype=np.int64).shape[0]):
temp_label[np.array(gt_label.split(','), dtype=np.int64)[i]] = 1
info['gt_label'] = temp_label

data_infos.append(info)
return data_infos

接下来就可以进行多标签的训练了。


推荐阅读
  • 机器学习中的相似度度量与模型优化
    本文探讨了机器学习中常见的相似度度量方法,包括余弦相似度、欧氏距离和马氏距离,并详细介绍了如何通过选择合适的模型复杂度和正则化来提高模型的泛化能力。此外,文章还涵盖了模型评估的各种方法和指标,以及不同分类器的工作原理和应用场景。 ... [详细]
  • 本文详细介绍了 Dockerfile 的编写方法及其在网络配置中的应用,涵盖基础指令、镜像构建与发布流程,并深入探讨了 Docker 的默认网络、容器互联及自定义网络的实现。 ... [详细]
  • golang常用库:配置文件解析库/管理工具viper使用
    golang常用库:配置文件解析库管理工具-viper使用-一、viper简介viper配置管理解析库,是由大神SteveFrancia开发,他在google领导着golang的 ... [详细]
  • 使用Numpy实现无外部库依赖的双线性插值图像缩放
    本文介绍如何仅使用Numpy库,通过双线性插值方法实现图像的高效缩放,避免了对OpenCV等图像处理库的依赖。文中详细解释了算法原理,并提供了完整的代码示例。 ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • 1:有如下一段程序:packagea.b.c;publicclassTest{privatestaticinti0;publicintgetNext(){return ... [详细]
  • 本文详细介绍了如何在Linux系统上安装和配置Smokeping,以实现对网络链路质量的实时监控。通过详细的步骤和必要的依赖包安装,确保用户能够顺利完成部署并优化其网络性能监控。 ... [详细]
  • C++实现经典排序算法
    本文详细介绍了七种经典的排序算法及其性能分析。每种算法的平均、最坏和最好情况的时间复杂度、辅助空间需求以及稳定性都被列出,帮助读者全面了解这些排序方法的特点。 ... [详细]
  • 深入理解 SQL 视图、存储过程与事务
    本文详细介绍了SQL中的视图、存储过程和事务的概念及应用。视图为用户提供了一种灵活的数据查询方式,存储过程则封装了复杂的SQL逻辑,而事务确保了数据库操作的完整性和一致性。 ... [详细]
  • 构建基于BERT的中文NL2SQL模型:一个简明的基准
    本文探讨了将自然语言转换为SQL语句(NL2SQL)的任务,这是人工智能领域中一项非常实用的研究方向。文章介绍了笔者在公司举办的首届中文NL2SQL挑战赛中的实践,该比赛提供了金融和通用领域的表格数据,并标注了对应的自然语言与SQL语句对,旨在训练准确的NL2SQL模型。 ... [详细]
  • 本文详细介绍了Java中org.eclipse.ui.forms.widgets.ExpandableComposite类的addExpansionListener()方法,并提供了多个实际代码示例,帮助开发者更好地理解和使用该方法。这些示例来源于多个知名开源项目,具有很高的参考价值。 ... [详细]
  • 本文详细介绍了Akka中的BackoffSupervisor机制,探讨其在处理持久化失败和Actor重启时的应用。通过具体示例,展示了如何配置和使用BackoffSupervisor以实现更细粒度的异常处理。 ... [详细]
  • 本文详细介绍了如何构建一个高效的UI管理系统,集中处理UI页面的打开、关闭、层级管理和页面跳转等问题。通过UIManager统一管理外部切换逻辑,实现功能逻辑分散化和代码复用,支持多人协作开发。 ... [详细]
  • 本文介绍了如何通过扩展 UnityGUI 创建自定义和复合控件,以满足特定的用户界面需求。内容涵盖简单和静态复合控件的实现,并展示了如何创建复杂的 RGB 滑块。 ... [详细]
  • 本文详细介绍了中央电视台电影频道的节目预告,并通过专业工具分析了其加载方式,确保用户能够获取最准确的电视节目信息。 ... [详细]
author-avatar
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有