热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

脸书用检测变压器(DETR)检测物体

脸书用检测变压器(DETR)检测物体原文:https://www

脸书用检测变压器(DETR)检测物体

原文:https://www . geesforgeks . org/object-detection-with-detection-transformer-dert-by-Facebook/

脸书刚刚于 2020 年 5 月 27 日发布了其最先进的物体检测模型。他们称之为 DERT 代表探测变压器,因为它使用变压器探测物体。这是第一次将变压器与卷积神经网络一起用于此类目标检测任务。还有其他物体检测模型,如 RCNN 系列、 YOLO (你只看一次)和 SSD(单镜头检测),但它们都没有使用变压器来完成这项任务。这个模型最棒的地方在于,由于它使用了一个转换器,它使得架构非常简单,不像前面提到的所有其他具有各种超参数和层的技术。所以没有进一步的告别,让我们开始吧。
什么是物体检测?
给定一张照片如果你需要确定照片是否有一个特定的物体,你可以通过分类来完成。但是如果你想得到那个物体在图像中的位置…嗯,即使这不是一个物体探测任务…这叫做分类和本地化。但是如果图像中有多个对象,并且您想要每个对象的每个位置,那么这就是对象检测。
前面的一些技术试图让一个 RPN(区域提议网络)想出可能包含对象的潜在区域,然后我们可以使用锚框、NMS(非最大抑制)和 IOU 的概念来生成相关的框并识别对象。虽然这些概念起作用,但推理需要一些时间,因此由于其复杂性,无法实现高精度的实时使用。
在高层次上,这使用 CNN,然后使用一个变压器来检测一个对象,它通过一个二分匹配训练对象来检测。这就是它如此简单的主要原因。

来源-https://arxiv.org/pdf/2005.12872.pdf

步骤 1:
我们将图像通过卷积神经网络编码器,因为 CNN 处理图像效果最好。所以通过 CNN 后,图像特征是守恒的。这是具有更多特征通道的图像的高阶表示。
步骤 2:
图像的这种丰富的特征图被提供给变压器编码器-解码器,其输出盒子预测的集合。这些框中的每一个都由一个元组组成。元组将是一个类和一个边界框。注意:这也包括空类或无类及其位置。
现在,这是一个真正的问题,因为在注释中没有对象类被注释为空。比较和处理相邻的相似对象是另一个主要问题,在本文中,使用二分匹配损失来解决这个问题。通过比较每个类和边界框与相应的类和框(包括 none 类,比如说 N 个)来比较损失,注释包括添加的不包含任何内容的部分,以使总框数为 N。预测值与实际值的分配是一对一的分配,从而使总损失最小化。有一种非常著名的算法叫做匈牙利法来计算这些最小匹配。
主要部件:

来源-https://arxiv.org/pdf/2005.12872.pdf

从卷积神经网络提取的主干–特征和位置编码被传递
变压器编码器–变压器自然是一个序列处理单元,出于同样的原因,我们输入张量被展平。它将序列转换成一个同样长的特征序列。
转换器解码器–接收对象查询,因此它是一个解码器,作为调节信息的辅助输入。
预测前馈网络(FFN)–这方面的输出通过一个分类器,该分类器输出前面讨论过的类标签和包围盒输出
评估器:
评估是在 COCO 数据集上完成的,其主要竞争对手是已经统治该类别一段时间的 RCNN 家族,被认为是最经典的对象检测技术。

来源-https://arxiv.org/pdf/2005.12872.pdf

优势:T2】


  • 这个新模型非常简单,你不需要安装任何库来使用它。

  • DETR 在大型物体上表现出明显更好的性能,而不是在可以进一步改进的小型物体上。

  • 好消息是,他们甚至在论文中提供了代码,所以现在我们也将实现它,以了解它真正能够做什么。

代码:

Python 3

# Write Python3 code here
mport torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
  super().__init__()
  # We take only convolutional layers from ResNet-50 model
  self.backbOne= nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
  self.cOnv= nn.Conv2d(2048, hidden_dim, 1)
  self.transformer = nn.Transformer(hidden_dim, heads,
  num_encoder_layers, num_decoder_layers)
  self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
  self.linear_bbox = nn.Linear(hidden_dim, 4)
  self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
  self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
  self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
  def forward(self, inputs):
  x = self.backbone(inputs)
  h = self.conv(x)
  H , W = h.shape[-2:]
  pos = torch.cat([
  self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
  self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
  h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
  self.query_pos.unsqueeze(1))
  return self.linear_class(h), self.linear_bbox(h).sigmoid()
detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)
Listing 1: DETR PyTorch inference code. For clarity, it uses learned positional encodings in the encoder instead of fixed, and positional encodings are added to the input
only instead of at each transformer layer. Making these changes requires going beyond
PyTorch implementation of transformers, which hampers readability. The entire code
to reproduce the experiments will be made available before the conference.

我们只从 ResNet-50 模型中提取卷积层
代码取自论文

代码:试着在 colab 上运行这个代码,或者直接进入这个链接,复制并运行完整的文件。

Python 3

import torch as th
import torchvision.transforms as T
import requests
from PIL import Image, ImageDraw, ImageFont

我们将使用 ResNet 101 作为主干架构,我们将直接从 Pytorch Hub 加载该架构。
代号:

Python 3

model = th.hub.load('facebookresearch/detr', 'detr_resnet101', pretrained=True)
model.eval()
model = model.cuda()

Python 3

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

在此输入图像的网址。我用过的是 https://i.ytimg.com/vi/vrlX3cwr3ww/maxresdefault.jpg
T1【代号:T3】

Python 3

url = input()

显示图像

Python 3

img = Image.open(requests.get(url, stream=True).raw).resize((800,600)).convert('RGB')
img

代码:

Python 3

img_tens = transform(img).unsqueeze(0).cuda()
with th.no_grad():
  output = model(img_tens)
draw = ImageDraw.Draw(img)
pred_logits=output['pred_logits'][0][:, :len(CLASSES)]
pred_boxes=output['pred_boxes'][0]
max_output = pred_logits.softmax(-1).max(-1)
topk = max_output.values.topk(15)
pred_logits = pred_logits[topk.indices]
pred_boxes = pred_boxes[topk.indices]
pred_logits.shape

代码:

Python 3

for logits, box in zip(pred_logits, pred_boxes):
  cls = logits.argmax()
  if cls >= len(CLASSES):
    continue
  label = CLASSES[cls]
  print(label)
  box = box.cpu() * th.Tensor([800, 600, 800, 600])
  x, y, w, h = box
  x0, x1 = x-w//2, x+w//2
  y0, y1 = y-h//2, y+h//2
  draw.rectangle([x0, y0, x1, y1], outline='red', white')

代码:显示检测到的图像

Python 3

img

这是 colab 笔记本和 github 代码的链接。另外,请随意查看官方 GitHub 了解相同的
缺点:
训练需要很长时间。它在 8 个图形处理器上训练了 6 天。当你把它和这种规模的语言模型进行比较时,这并不重要,因为它们使用了一个转换器,但是仍然。


推荐阅读
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • 1:有如下一段程序:packagea.b.c;publicclassTest{privatestaticinti0;publicintgetNext(){return ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • XNA 3.0 游戏编程:从 XML 文件加载数据
    本文介绍如何在 XNA 3.0 游戏项目中从 XML 文件加载数据。我们将探讨如何将 XML 数据序列化为二进制文件,并通过内容管道加载到游戏中。此外,还会涉及自定义类型读取器和写入器的实现。 ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
  • 优化ListView性能
    本文深入探讨了如何通过多种技术手段优化ListView的性能,包括视图复用、ViewHolder模式、分批加载数据、图片优化及内存管理等。这些方法能够显著提升应用的响应速度和用户体验。 ... [详细]
  • 本文探讨了Hive中内部表和外部表的区别及其在HDFS上的路径映射,详细解释了两者的创建、加载及删除操作,并提供了查看表详细信息的方法。通过对比这两种表类型,帮助读者理解如何更好地管理和保护数据。 ... [详细]
  • 本文介绍了Java并发库中的阻塞队列(BlockingQueue)及其典型应用场景。通过具体实例,展示了如何利用LinkedBlockingQueue实现线程间高效、安全的数据传递,并结合线程池和原子类优化性能。 ... [详细]
  • 数据库内核开发入门 | 搭建研发环境的初步指南
    本课程将带你从零开始,逐步掌握数据库内核开发的基础知识和实践技能,重点介绍如何搭建OceanBase的开发环境。 ... [详细]
  • 本文介绍如何使用 Python 编写程序,检查给定列表中的元素是否形成交替峰值模式。我们将探讨两种不同的方法来实现这一目标,并提供详细的代码示例。 ... [详细]
  • 探讨一个显示数字的故障计算器,它支持两种操作:将当前数字乘以2或减去1。本文将详细介绍如何用最少的操作次数将初始值X转换为目标值Y。 ... [详细]
  • 本章将深入探讨移动 UI 设计的核心原则,帮助开发者构建简洁、高效且用户友好的界面。通过学习设计规则和用户体验优化技巧,您将能够创建出既美观又实用的移动应用。 ... [详细]
  • 本文详细探讨了Netty中Future及其子类的设计与实现,包括其在并发编程中的作用和具体应用场景。我们将介绍Future的继承体系、关键方法的实现细节,并讨论如何通过监听器和回调机制来处理异步任务的结果。 ... [详细]
  • 如何在PostgreSQL中查看数据表
    本文将指导您使用pgAdmin工具连接到PostgreSQL数据库,并展示如何浏览和查找其中的数据表。通过简单的步骤,您可以轻松访问所需的表结构和数据。 ... [详细]
  • 本文探讨了如何在给定整数N的情况下,找到两个不同的整数a和b,使得它们的和最大,并且满足特定的数学条件。 ... [详细]
author-avatar
江苏经贸学院
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有