热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

如何用PyTorch进行语义分割?

木易发自凹非寺量子位报道|公众号QbitAI很久没给大家带来教程资源啦。正值PyTorch1.7更新,那么我们这次便给大家带来一个PyTorch简单实用的教程资源&#

木易 发自 凹非寺 
量子位 报道 | 公众号 QbitAI

很久没给大家带来教程资源啦。

正值PyTorch 1.7更新,那么我们这次便给大家带来一个PyTorch简单实用的教程资源:用PyTorch进行语义分割。

图源:stanford

该教程是基于2020年ECCV Vipriors Chalange Start Code实现了语义分割,并且添加了一些技巧。

友情提示:教程中的所有文件均可以在文末的开源地址获取。

预设置

在开始训练之前,得首先设置一下库、数据集等。

库准备

pip install -r requirements.txt

下载数据集

教程使用的是来自Cityscapes的数据集MiniCity Dataset。

数据集的简单数据分析

将各基准类别进行输入:

之后,便从0-18计数,对各类别进行像素标记:

使用deeplab v3进行基线测试,结果发现次要类别的IoU特别低,这样会导致难以跟背景进行区分。

如下图中所示的墙、栅栏、公共汽车、火车等。

分析结论:数据集存在严重的类别不平衡问题。

训练基准模型

使用来自torchvision的DeepLabV3进行训练。

硬件为4个RTX 2080 Ti GPU (11GB x 4)&#xff0c;如果只有1个GPU或较小的GPU内存&#xff0c;请使用较小的批处理大小&#xff08;<&#61; 8&#xff09;。

python baseline.py --save_path baseline_run_deeplabv3_resnet50 --crop_size 576 1152 --batch_size 8;

python baseline.py --save_path baseline_run_deeplabv3_resnet101 --model DeepLabv3_resnet101 --train_size 512 1024 --test_size 512 1024 --crop_size 384 768 --batch_size 8;

损失函数

有3种损失函数可供选择&#xff0c;分别是&#xff1a;交叉熵损失函数&#xff08;Cross-Entropy Loss&#xff09;、类别加权交叉熵损失函数&#xff08;Class-Weighted Cross Entropy Loss&#xff09;和焦点损失函数&#xff08;Focal Loss&#xff09;。

交叉熵损失函数&#xff0c;常用在大多数语义分割场景&#xff0c;但它有一个明显的缺点&#xff0c;那就是对于只用分割前景和背景的时候&#xff0c;当前景像素的数量远远小于背景像素的数量时&#xff0c;模型严重偏向背景&#xff0c;导致效果不好。

# Cross Entropy Loss
python baseline.py --save_path baseline_run_deeplabv3_resnet50 --crop_size 576 1152 --batch_size 8;

类别加权交叉熵损失函数是在交叉熵损失函数的基础上为每一个类别添加了一个权重参数&#xff0c;使其在样本数量不均衡的情况下可以获得更好的效果。

# Weighted Cross Entropy Loss
python baseline.py --save_path baseline_run_deeplabv3_resnet50_wce --crop_size 576 1152 --batch_size 8 --loss weighted_ce;

焦点损失函数则更进一步&#xff0c;用来解决难易样本数量不平衡。

# Focal Loss
python baseline.py --save_path baseline_run_deeplabv3_resnet50_focal --crop_size 576 1152 --batch_size 8 --loss focal --focal_gamma 2.0;

归一化层

有4种归一化方法&#xff1a;BN&#xff08;Batch Normalization&#xff09;、IN&#xff08;Instance Normalization&#xff09;、GN&#xff08;Group Normalization&#xff09;和EvoNorm&#xff08;Evolving Normalization&#xff09;。

BN是在batch上&#xff0c;对N、H、W做归一化&#xff0c;而保留通道 C 的维度。BN对较小的batch size效果不好。

# Batch Normalization
python baseline.py --save_path baseline_run_deeplabv3_resnet50 --crop_size 576 1152 --batch_size 8;

IN在图像像素上&#xff0c;对H、W做归一化&#xff0c;用在风格化迁移。

# Instance Normalization
python baseline.py --save_path baseline_run_deeplabv3_resnet50_instancenorm --crop_size 576 1152 --batch_size 8 --norm instance;

GN将通道分组&#xff0c;然后再做归一化。

# Group Normalization
python baseline.py --save_path baseline_run_deeplabv3_resnet50_groupnorm --crop_size 576 1152 --batch_size 8 --norm group;

EvoNorm则是4月份由谷歌和DeepMind 联合发布的一项新技术。实验证明&#xff0c;EvoNorms 在多个图像分类模型上效果显著&#xff0c;而且还能很好地迁移到 Mask R-CNN 模型和 BigGAN。

# Evolving Normalization
python baseline.py --save_path baseline_run_deeplabv3_resnet50_evonorm --crop_size 576 1152 --batch_size 8 --norm evo;

数据增强

2种数据增强技术&#xff1a;CutMix、Copy Blob。

  • CutMix

将一部分区域cut掉但不填充0像素&#xff0c;而是随机填充训练集中的其他数据的区域像素值&#xff0c;分类结果按一定的比例分配。

而在这里&#xff0c;则是在原有CutMix的基础上&#xff0c;引入了语义分割。

# CutMix Augmentation
python baseline.py --save_path baseline_run_deeplabv3_resnet50_cutmix --crop_size 576 1152 --batch_size 8 --cutmix;

  • Copy Blob

在 Blob 存储的基础上构建&#xff0c;并通过Copy的方式增强了性能。

另外&#xff0c;如果要解决前面所提到的类别不平衡问题&#xff0c;则可以使用视觉归纳优先的CopyBlob进行增强。

# CopyBlob Augmentation
python baseline.py --save_path baseline_run_deeplabv3_resnet50_copyblob --crop_size 576 1152 --batch_size 8 --copyblob;

推理

训练结束后&#xff0c;对训练完成的模型进行评估。

python baseline.py --save_path baseline_run_deeplabv3_resnet50 --batch_size 4 --predict;

多尺度推断

使用[0.5&#xff0c;0.75&#xff0c;1.0&#xff0c;1.25&#xff0c;1.5&#xff0c;1.75&#xff0c;2.0&#xff0c;2.2]进行多尺度推理。另外&#xff0c;使用H-Flip&#xff0c;同时必须使用单一批次。

# Multi-Scale Inference
python baseline.py --save_path baseline_run_deeplabv3_resnet50 --batch_size 1 --predict --mst;

使用验证集计算度量

计算指标并将结果保存到results.txt中。

python evaluate.py --results baseline_run_deeplabv3_resnet50/results_val --batch_size 1 --predict --mst;

最终结果

最后的单一模型结果是0.6069831962012341&#xff0c;

如果使用了更大的模型或者更大的网络结构&#xff0c;性能可能会有所提高。

另外&#xff0c;如果使用了各种集成模型&#xff0c;性能也会有所提高。

资源地址&#xff1a;
https://github.com/hoya012/semantic-segmentation-tutorial-pytorch

本文系网易新闻•网易号特色内容激励计划签约账号【量子位】原创内容&#xff0c;未经账号授权&#xff0c;禁止随意转载。

一键三连「分享」、「点赞」和「在看」

科技前沿进展日日相见~



推荐阅读
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • 本文介绍了一个适用于PHP应用快速接入TRX和TRC20数字资产的开发包,该开发包支持使用自有Tron区块链节点的应用场景,也支持基于Tron官方公共API服务的轻量级部署场景。提供的功能包括生成地址、验证地址、查询余额、交易转账、查询最新区块和查询交易信息等。详细信息可参考tron-php的Github地址:https://github.com/Fenguoz/tron-php。 ... [详细]
  • Android日历提醒软件开源项目分享及使用教程
    本文介绍了一款名为Android日历提醒软件的开源项目,作者分享了该项目的代码和使用教程,并提供了GitHub项目地址。文章详细介绍了该软件的主界面风格、日程信息的分类查看功能,以及添加日程提醒和查看详情的界面。同时,作者还提醒了读者在使用过程中可能遇到的Android6.0权限问题,并提供了解决方法。 ... [详细]
  • Imdevelopinganappwhichneedstogetmusicfilebystreamingforplayinglive.我正在开发一个应用程序,需要通过流 ... [详细]
  • navicat生成er图_实践案例丨ACL2020 KBQA 基于查询图生成回答多跳复杂问题
    摘要:目前复杂问题包括两种:含约束的问题和多跳关系问题。本文对ACL2020KBQA基于查询图生成的方法来回答多跳复杂问题这一论文工作进行了解读 ... [详细]
  • YOLOV4 Pytorch版本训练自建数据集和预测
    1.程序下载本文程序核心部分完全参考开源代码:https:github.comWongKinYiuPyTorch_YOLOv4。只是从一种学习的角度去写了我的代码仓库,在基础上增加 ... [详细]
  • 云原生边缘计算之KubeEdge简介及功能特点
    本文介绍了云原生边缘计算中的KubeEdge系统,该系统是一个开源系统,用于将容器化应用程序编排功能扩展到Edge的主机。它基于Kubernetes构建,并为网络应用程序提供基础架构支持。同时,KubeEdge具有离线模式、基于Kubernetes的节点、群集、应用程序和设备管理、资源优化等特点。此外,KubeEdge还支持跨平台工作,在私有、公共和混合云中都可以运行。同时,KubeEdge还提供数据管理和数据分析管道引擎的支持。最后,本文还介绍了KubeEdge系统生成证书的方法。 ... [详细]
  • 本文介绍了设计师伊振华受邀参与沈阳市智慧城市运行管理中心项目的整体设计,并以数字赋能和创新驱动高质量发展的理念,建设了集成、智慧、高效的一体化城市综合管理平台,促进了城市的数字化转型。该中心被称为当代城市的智能心脏,为沈阳市的智慧城市建设做出了重要贡献。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • 关于我们EMQ是一家全球领先的开源物联网基础设施软件供应商,服务新产业周期的IoT&5G、边缘计算与云计算市场,交付全球领先的开源物联网消息服务器和流处理数据 ... [详细]
  • 本文介绍了Python语言程序设计中文件和数据格式化的操作,包括使用np.savetext保存文本文件,对文本文件和二进制文件进行统一的操作步骤,以及使用Numpy模块进行数据可视化编程的指南。同时还提供了一些关于Python的测试题。 ... [详细]
  • 详解 Python 的二元算术运算,为什么说减法只是语法糖?[Python常见问题]
    原题|UnravellingbinaryarithmeticoperationsinPython作者|BrettCannon译者|豌豆花下猫(“Python猫 ... [详细]
  • 3.5.2Calc的公式语法:使用Calc计算一个公式可用是任何能够被Emacs的calc包所识别的代数表达式.注意,在Calc中,的操作符优先级要比*低,因此ab*c会被解释为a ... [详细]
  • pytorch Dropout过拟合的操作
    这篇文章主要介绍了pytorchDropout过拟合的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完 ... [详细]
author-avatar
cang桑哥哥
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有