热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

关于pytorch语义分割二分类问题的两种做法

形式1:输出为单通道

分析

即网络的输出 output 为 [batch_size, 1, height, width] 形状。其中 batch_szie 为批量大小,1 表示输出一个通道,heightwidth 与输入图像的高和宽保持一致。

在训练时,输出通道数是 1,网络得到的 output 包含的数值是任意的数。给定的 target ,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 output 不断逼近这个标签,首先会让 output 经过一个sigmoid 函数,使其数值归一化到[0, 1],得到 output1 ,然后让这个 output1target 进行交叉熵计算,得到损失值,反向传播更新网络权重。最终,网络经过学习,会使得 output1 逼近target

训练结束后,网络已经具备让输出的 output 经过转换从而逼近 target 的能力。首先将输出的 output 通过sigmoid 函数,然后取一个阈值(一般设置为0.5),大于阈值则取1反之则取0,从而得到预测图 predict。后续则是一些评估相关的计算。

代码实现

在这个过程中,训练的损失函数为二进制交叉熵损失函数,然后根据输出是否用到了sigmoid有两种可选的pytorch实现方式:

output = net(input)  # net的最后一层没有使用sigmoid
loss_func1 = torch.nn.BCEWithLogitsLoss()
loss = loss_func1(output, target)

当网络最后一层没有使用sigmoid时,需要使用 torch.nn.BCEWithLogitsLoss() ,顾名思义,在这个函数中,拿到output首先会做一个sigmoid操作,再进行二进制交叉熵计算。上面的操作等价于

output = net(input)  # net的最后一层没有使用sigmoid
output = F.sigmoid(output)
loss_func1 = torch.nn.BCEWithLoss()
loss = loss_func1(output, target)

当然,你也可以在网络最后一层加上sigmoid操作。从而省去第二行的代码(在预测时也可以省去)。

在预测试时,可用下面的代码实现预测图的生成

output = net(input)  # net的最后一层没有使用sigmoid
output = F.sigmoid(output)
predict = torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output))
...

即大于0.5的记为1,小于0.5记为0。

形式2:输出为多通道

分析

即网络的输出 output 为 [batch_size, num_class, height, width] 形状。其中 batch_szie 为批量大小,num_class 表示输出的通道数与分类数量一致,heightwidth 与输入图像的高和宽保持一致。

在训练时,输出通道数是 num_class(这里取2),网络得到的 output 包含的数值是任意的数。给定的 target ,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 output 不断逼近这个标签,首先会让 output 经过一个 softmax 函数,使其数值归一化到[0, 1],得到 output1 ,在各通道中,这个数值加起来会等于1。对于target 他是一个单通道图,首先使用onehot编码,转换成 num_class个通道的图像,每个通道中的取值是根据单通道中的取值计算出来的,例如单通道中的第一个像素取值为1(0<= 1 <=num_class-1,这里num_class=2),那么onehot编码后,在第一个像素的位置上,两个通道的取值分别为0,1。也就是说像素的取值决定了对应序号的通道取1,其他的通道取0,这个非常关键。上面的操作执行完后得到target1,让这个 output1target1 进行交叉熵计算,得到损失值,反向传播更新网路权重。最终,网络经过学习,会使得 output1 逼近target1(在各通道层面上)。

训练结束后,网络已经具备让输出的 output 经过转换从而逼近 target 的能力。计算 output 中各通道每一个像素位置上,取值最大的那个对应的通道序号,从而得到预测图 predict。后续则是一些评估相关的计算。

代码实现

在这个过程中,则可以使用交叉熵损失函数:

output = net(input)  # net的最后一层没有使用sigmoid
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(output, target)

根据前面的分析,我们知道,正常的output是 [batch_size, num_class, height, width]形状的,而target是[batch_size, height, width]形状的,需要按照上面的分析进行转换才可以计算交叉熵,而在pytorch中,我们不需要进一步做这个处理,直接使用就可以了。

在预测试时,使用下面的代码实现预测图的生成

output = net(input)  # net的最后一层没有使用sigmoid
predict = output.argmax(dim=1)
...

即得到输出后,在通道方向上找出最大值所在的索引号。

小结

总的来说,我觉得第二种方式更值得推广,一方面不用考虑阈值的选取问题;另一方面,该方法同样适用于多类别的语义分割任务,通用性更强。

参考资料

[1]https://blog.csdn.net/longshaonihaoa/article/details/105253553

[2]https://cuijiahua.com/blog/2020/03/dl-16.html


推荐阅读
  • 这是原文链接:sendingformdata许多情况下,我们使用表单发送数据到服务器。服务器处理数据并返回响应给用户。这看起来很简单,但是 ... [详细]
  • 目录实现效果:实现环境实现方法一:基本思路主要代码JavaScript代码总结方法二主要代码总结方法三基本思路主要代码JavaScriptHTML总结实 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 在Docker中,将主机目录挂载到容器中作为volume使用时,常常会遇到文件权限问题。这是因为容器内外的UID不同所导致的。本文介绍了解决这个问题的方法,包括使用gosu和suexec工具以及在Dockerfile中配置volume的权限。通过这些方法,可以避免在使用Docker时出现无写权限的情况。 ... [详细]
  • EPICS Archiver Appliance存储waveform记录的尝试及资源需求分析
    本文介绍了EPICS Archiver Appliance存储waveform记录的尝试过程,并分析了其所需的资源容量。通过解决错误提示和调整内存大小,成功存储了波形数据。然后,讨论了储存环逐束团信号的意义,以及通过记录多圈的束团信号进行参数分析的可能性。波形数据的存储需求巨大,每天需要近250G,一年需要90T。然而,储存环逐束团信号具有重要意义,可以揭示出每个束团的纵向振荡频率和模式。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • 使用在线工具jsonschema2pojo根据json生成java对象
    本文介绍了使用在线工具jsonschema2pojo根据json生成java对象的方法。通过该工具,用户只需将json字符串复制到输入框中,即可自动将其转换成java对象。该工具还能解析列表式的json数据,并将嵌套在内层的对象也解析出来。本文以请求github的api为例,展示了使用该工具的步骤和效果。 ... [详细]
  • Java验证码——kaptcha的使用配置及样式
    本文介绍了如何使用kaptcha库来实现Java验证码的配置和样式设置,包括pom.xml的依赖配置和web.xml中servlet的配置。 ... [详细]
  • imx6ull开发板驱动MT7601U无线网卡的方法和步骤详解
    本文详细介绍了在imx6ull开发板上驱动MT7601U无线网卡的方法和步骤。首先介绍了开发环境和硬件平台,然后说明了MT7601U驱动已经集成在linux内核的linux-4.x.x/drivers/net/wireless/mediatek/mt7601u文件中。接着介绍了移植mt7601u驱动的过程,包括编译内核和配置设备驱动。最后,列举了关键词和相关信息供读者参考。 ... [详细]
  • 本文介绍了南邮ctf-web的writeup,包括签到题和md5 collision。在CTF比赛和渗透测试中,可以通过查看源代码、代码注释、页面隐藏元素、超链接和HTTP响应头部来寻找flag或提示信息。利用PHP弱类型,可以发现md5('QNKCDZO')='0e830400451993494058024219903391'和md5('240610708')='0e462097431906509019562988736854'。 ... [详细]
  • 闭包一直是Java社区中争论不断的话题,很多语言都支持闭包这个语言特性,闭包定义了一个依赖于外部环境的自由变量的函数,这个函数能够访问外部环境的变量。本文以JavaScript的一个闭包为例,介绍了闭包的定义和特性。 ... [详细]
  • Ihavethefollowingonhtml我在html上有以下内容<html><head><scriptsrc..3003_Tes ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 【shell】网络处理:判断IP是否在网段、两个ip是否同网段、IP地址范围、网段包含关系
    本文介绍了使用shell脚本判断IP是否在同一网段、判断IP地址是否在某个范围内、计算IP地址范围、判断网段之间的包含关系的方法和原理。通过对IP和掩码进行与计算,可以判断两个IP是否在同一网段。同时,还提供了一段用于验证IP地址的正则表达式和判断特殊IP地址的方法。 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
author-avatar
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有