热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PointRend原理与代码解析

paper:PointRend:ImageSegmentationasRenderingcode1:https:github.comfacebookr

paper:PointRend: Image Segmentation as Rendering

code1:https://github.com/facebookresearch/detectron2/tree/main/projects/PointRend 

code2:https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend 


创新点

本文的中心思想是将图像分割视为一个渲染问题,具体做法是使用subvision策略来自适应地选择一个非均匀的点集来计算标签。说人话就是针对图像分割中边缘分割不准的情况,提出了一种新的优化方法,具体就是选取网络输出特征图上少数难分像素点,这些难分点大概率分布在物体边缘附近,然后加了一个小的子网络去学习这些难分点的特征,最终提升模型在物体轮廓边缘处的分割精度。


方法介绍

PointRend主要包含三个部分


  1. a point selection strategy
  2. a point-wise feature representation
  3. a point head

Point Selection

推理阶段 对于网络的输出特征图,挑选 \(N\) 个最难分最不确定的点(比如对于只有一类前景的分割问题,概率越接近0.5说明越难分),计算这些难分点的point-wise feature representation,然后根据这些特征去预测这些点的类别。这 \(N\) 个点之外的其它易分点就直接以coarset level即模型输出的最小feature map的预测结果为准。然后上采样,重复该步骤。

训练阶段 推理阶段采用的逐步上采样然后每次都选择 \(N\) 个最难分点的方式不适合训练阶段,因此训练阶段采用non-iterative的方法,首先基于均匀分布随机过采样 \(kN(k>1)\) 个点,然后从中选出 \(\beta N(\beta \in[0,1])\) 个最不确定的点,然后再从剩下的点中基于均匀分布挑出 \((1-\beta)N\) 个点,然后接一个point head subnetwork去学习这 \(N\) 个点的特征,point head的预测和损失计算也只针对这 \(N\) 个点。


Point-wise Representation

Fine-grained features. 为了让PointRend学习精细的分割细节,要从CNN的特征图中提取每个采样点的特征向量,并且要采用浅层的分辨率大的包含丰富细节特征的特征图。下面的实现中采用的是neck输出中分辨率最大的特征图。

Coarse prediction features. 细粒度特征包含了丰富的细节特征,但只有细粒度特征还不够,一是因为当一个点被两个物体的bounding box同时覆盖时,这两个物体在这一点有相同的细粒度特征,但这个点只能被预测为其中一个物体,因为对于实例分割,还需要额外的region-specific特征。二是因为细粒度特征只包含了低维信息,更多的上下文和语义特征可能会有帮助,这对实例分割和语义分割都有帮助。下面的实现中采用的是fpn head的最终预测输出。

将fine-grained特征和coarse prediction特征拼接到一起,就得到了这些采样点的最终特征表示。


Point head

在得到了采样点的特征表示后,PointRend采用了一个多层感知器(MLP)来进行点分割预测,预测每个点的分割类别后,根据对应的标签计算损失进行训练。


代码解析

接下来以mmsegmentation中的PointRend实现为例,讲解一下具体实现。

只有一类前景。假设batch_size=4,input_shape=(4, 3, 480, 480)。backbOne=ResNetV1c,backbone的输出为[(4, 256, 120, 120), (4, 512, 60, 60), (4, 1024, 30, 30), (4, 2048, 15, 15)]。neck=FPN,neck后的输出为[(4, 256, 120, 120), (4, 512, 60, 60), (4, 1024, 30, 30), (4, 2048, 15, 15)]。pointrend中有两个head,因此用cascade_encoder_decoder将两个head串联起来。第一个head是FPN head,借鉴了Panoptic Feature Pyramid Networks中的Semantic FPN,这里就不具体介绍了,输出为(4, 2, 120, 120),然后计算这个head的损失,loss采用的交叉熵损失。

第二个head是point_head,point_head的输入包括neck的最大分辨率输出(4, 256, 120, 120),以及FPN head的输出(4, 2, 120, 120)。

选择难分点,这里prev_output是fpn head的输出

with torch.no_grad():points = self.get_points_train(prev_output, calculate_uncertainty, cfg=train_cfg) # (4,2,120,120) -> (4,2048,2)

评价难分程度的函数如下,具体就是计算每个点top1得分和top2得分的差,差越小说明越难分。注意这里计算的是top2-top1,值为负,所以值越大说明越难分。

def calculate_uncertainty(seg_logits):top2_scores = torch.topk(seg_logits, k=2, dim=1)[0] # (4,2,6144) -> (4,2,6144)return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) # (4,6144) -> (4,1,6144)

具体选择用于训练的难分点实现如下,其中随机采样是通过mmcv中的函数point_sample实现的,而point_sample中首先将随机采样的坐标point_coords由[0, 1]转化到[-1, 1]区间,然后通过F.grid_sample根据归一化的坐标位置进行插值采样,F.grid_sample的用法见F.grid_sample 用法解读_00000cj的博客-CSDN博客。在训练过程中,如上文所述,point selection阶段 \(N=2048,k=3,\beta =0.75\)。实际挑出训练的点有2048个,首先随机采样3x2048=6144个点,然后挑选出最难分的0.75x2048=1536个点,剩下的2048-1536=512个数随机挑选。

def get_points_train(self, seg_logits, uncertainty_func, cfg):"""Sample points for training.Sample points in [0, 1] x [0, 1] coordinate space based on theiruncertainty. The uncertainties are calculated for each point using&#39;uncertainty_func&#39; function that takes point&#39;s logit prediction asinput.Args:seg_logits (Tensor): Semantic segmentation logits, shape (batch_size, num_classes, height, width).uncertainty_func (func): uncertainty calculation function.cfg (dict): Training config of point head.Returns:point_coords (Tensor): A tensor of shape (batch_size, num_points,2) that contains the coordinates of ``num_points`` sampledpoints."""num_points = cfg.num_points # 2048oversample_ratio = cfg.oversample_ratio # 3importance_sample_ratio = cfg.importance_sample_ratio # 0.75assert oversample_ratio >= 1assert 0 <= importance_sample_ratio <= 1batch_size = seg_logits.shape[0] # (4,2,120,120)num_sampled = int(num_points * oversample_ratio) # 2048x3=6144point_coords = torch.rand(batch_size, num_sampled, 2, device=seg_logits.device) # (4,6144,2)point_logits = point_sample(seg_logits, point_coords) # (4,2,6144)# It is crucial to calculate uncertainty based on the sampled# prediction value for the points. Calculating uncertainties of the# coarse predictions first and sampling them for points leads to# incorrect results. To illustrate this: assume uncertainty func(# logits)=-abs(logits), a sampled point between two coarse# predictions with -1 and 1 logits has 0 logits, and therefore 0# uncertainty value. However, if we calculate uncertainties for the# coarse predictions first, both will have -1 uncertainty,# and sampled point will get -1 uncertainty.point_uncertainties = uncertainty_func(point_logits) # (4,1,6144)num_uncertain_points = int(importance_sample_ratio * num_points) # 0.75x2048=1536num_random_points = num_points - num_uncertain_points # 512idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] # (4,1536)shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=seg_logits.device) # (4,), (0,6144,12288,18432)idx += shift[:, None] # (4,1536) += (4,1) -> (4,1536)# (4,6144,2)->(24576,2)[(4,1536)->(6144), :] -> (6144,2) -> (4,1536,2)point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2)if num_random_points > 0:rand_point_coords = torch.rand(batch_size, num_random_points, 2, device=seg_logits.device)point_coords = torch.cat((point_coords, rand_point_coords), dim=1) # (4,2048,2)return point_coords

在得到待训练点的坐标后,分别从neck最大分辨率输出(4, 256, 120, 120)和FPN head的预测结果(4, 2, 120, 120)上插值得到对应的fine feature和coarse feature。其中内部实现还是通过point_sample。

fine_grained_point_feats = self._get_fine_grained_point_feats(x, points) # (4,256,2048)
coarse_point_feats = self._get_coarse_point_feats(prev_output, points) # (4,2,2048)

然后将fine feature和coarse feature拼接起来,最终的point head是一个MLP,层数为3,最终再经过一个卷积层得到这2048个点的分类结果。

point_logits = self.forward(fine_grained_point_feats, coarse_point_feats) # (4,2,2048)def forward(self, fine_grained_point_feats, coarse_point_feats):x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1) # (4,258,2048)for fc in self.fcs:x = fc(x)if self.coarse_pred_each_layer: # Truex = torch.cat((x, coarse_point_feats), dim=1) # (4,258,2048)return self.cls_seg(x) # (4,2,2048)

在得到难分点的预测结果后,因为采样点的坐标不是整数,特征是从feature map上插值得到的,对应的标签也要插值得到,只不过特征插值时是采用的bilinear,而标签采用的是nearest。point head的loss也是交叉熵损失。 

point_label = point_sample(gt_semantic_seg.float(), # (4,1,480,480)points,mode=&#39;nearest&#39;,align_corners=self.align_corners) # (4,1,2048)
point_label = point_label.squeeze(1).long() # (4,2048)
losses = self.losses(point_logits, point_label)

实验结果 

下图是一些示例,可以看出PointRend对边缘的分割更加精细。

因为只采样少数难分点,而对于大部分易分点比如远离图像边缘的区域,coarse prediction就足够了,因此增加point head后增加的计算量有限,如下 

下面分别是在DeeplabV3和SemanticFPN中加入PointRend,精度都得到了提升 

分割的标注通常不够精确,因此实际效果的提升可能比上表中的更大。 


推荐阅读
  • 利用树莓派畅享落网电台音乐体验
    最近重新拾起了闲置已久的树莓派,这台小巧的开发板已经沉寂了半年多。上个月闲暇时间较多,我决定将其重新启用。恰逢落网电台进行了改版,回忆起之前在树莓派论坛上看到有人用它来播放豆瓣音乐,便萌生了同样的想法。通过一番调试,终于实现了在树莓派上流畅播放落网电台音乐的功能,带来了全新的音乐享受体验。 ... [详细]
  • 本题库精选了Java核心知识点的练习题,旨在帮助学习者巩固和检验对Java理论基础的掌握。其中,选择题部分涵盖了访问控制权限等关键概念,例如,Java语言中仅允许子类或同一包内的类访问的访问权限为protected。此外,题库还包括其他重要知识点,如异常处理、多线程、集合框架等,全面覆盖Java编程的核心内容。 ... [详细]
  • 本文详细介绍了Java反射机制的基本概念、获取Class对象的方法、反射的主要功能及其在实际开发中的应用。通过具体示例,帮助读者更好地理解和使用Java反射。 ... [详细]
  • 本文详细介绍了 PHP 中对象的生命周期、内存管理和魔术方法的使用,包括对象的自动销毁、析构函数的作用以及各种魔术方法的具体应用场景。 ... [详细]
  • 属性类 `Properties` 是 `Hashtable` 类的子类,用于存储键值对形式的数据。该类在 Java 中广泛应用于配置文件的读取与写入,支持字符串类型的键和值。通过 `Properties` 类,开发者可以方便地进行配置信息的管理,确保应用程序的灵活性和可维护性。此外,`Properties` 类还提供了加载和保存属性文件的方法,使其在实际开发中具有较高的实用价值。 ... [详细]
  • 在对WordPress Duplicator插件0.4.4版本的安全评估中,发现其存在跨站脚本(XSS)攻击漏洞。此漏洞可能被利用进行恶意操作,建议用户及时更新至最新版本以确保系统安全。测试方法仅限于安全研究和教学目的,使用时需自行承担风险。漏洞编号:HTB23162。 ... [详细]
  • POJ 2482 星空中的星星:利用线段树与扫描线算法解决
    在《POJ 2482 星空中的星星》问题中,通过运用线段树和扫描线算法,可以高效地解决星星在窗口内的计数问题。该方法不仅能够快速处理大规模数据,还能确保时间复杂度的最优性,适用于各种复杂的星空模拟场景。 ... [详细]
  • 在Java项目中,当两个文件进行互相调用时出现了函数错误。具体问题出现在 `MainFrame.java` 文件中,该文件位于 `cn.javass.bookmgr` 包下,并且导入了 `java.awt.BorderLayout` 和 `java.awt.Event` 等相关类。为了确保项目的正常运行,请求提供专业的解决方案,以解决函数调用中的错误。建议从类路径、依赖关系和方法签名等方面入手,进行全面排查和调试。 ... [详细]
  • 在Ubuntu上安装MySQL时解决缺少libaio.so.1错误及libaio在MySQL中的重要性分析
    在Ubuntu系统上安装MySQL时,遇到了缺少libaio.so.1的错误。本文详细介绍了如何解决这一问题,并深入探讨了libaio库在MySQL性能优化中的重要作用。对于初学者而言,理解这些依赖关系和配置步骤是成功安装和运行MySQL的关键。通过本文的指导,读者可以顺利解决相关问题,并更好地掌握MySQL在Linux环境下的部署与管理。 ... [详细]
  • DRF框架中Serializer反序列化验证机制详解:深入探讨Validators的应用与优化
    在DRF框架的反序列化验证机制中,除了基本的字段类型和长度校验外,还常常需要进行更为复杂的条件限制校验。通过引入`validators`模块,可以实现自定义校验逻辑,如唯一字段校验等。本文将详细探讨`validators`的使用方法及其优化策略,帮助开发者更好地理解和应用这一重要功能。 ... [详细]
  • 深入理解 Java 控制结构的全面指南 ... [详细]
  • 本文介绍了UUID(通用唯一标识符)的概念及其在JavaScript中生成Java兼容UUID的代码实现与优化技巧。UUID是一个128位的唯一标识符,广泛应用于分布式系统中以确保唯一性。文章详细探讨了如何利用JavaScript生成符合Java标准的UUID,并提供了多种优化方法,以提高生成效率和兼容性。 ... [详细]
  • 本文深入探讨了 hCalendar 微格式在事件与时间、地点相关活动标记中的应用。作为微格式系列文章的第四篇,前文已分别介绍了 rel 属性用于定义链接关系、XFN 微格式增强链接的人际关系描述以及 hCard 微格式对个人和组织信息的描述。本次将重点解析 hCalendar 如何通过结构化数据标记,提高事件信息的可读性和互操作性。 ... [详细]
  • 利用 Python 中的 Altair 库实现数据抖动的水平剥离分析 ... [详细]
  • C#编程指南:实现列表与WPF数据网格的高效绑定方法 ... [详细]
author-avatar
红昊子楽楽七_358
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有