热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

脸书用检测变压器(DETR)检测物体

脸书用检测变压器(DETR)检测物体原文:https://www

脸书用检测变压器(DETR)检测物体

原文:https://www . geesforgeks . org/object-detection-with-detection-transformer-dert-by-Facebook/

脸书刚刚于 2020 年 5 月 27 日发布了其最先进的物体检测模型。他们称之为 DERT 代表探测变压器,因为它使用变压器探测物体。这是第一次将变压器与卷积神经网络一起用于此类目标检测任务。还有其他物体检测模型,如 RCNN 系列、 YOLO (你只看一次)和 SSD(单镜头检测),但它们都没有使用变压器来完成这项任务。这个模型最棒的地方在于,由于它使用了一个转换器,它使得架构非常简单,不像前面提到的所有其他具有各种超参数和层的技术。所以没有进一步的告别,让我们开始吧。
什么是物体检测?
给定一张照片如果你需要确定照片是否有一个特定的物体,你可以通过分类来完成。但是如果你想得到那个物体在图像中的位置…嗯,即使这不是一个物体探测任务…这叫做分类和本地化。但是如果图像中有多个对象,并且您想要每个对象的每个位置,那么这就是对象检测。
前面的一些技术试图让一个 RPN(区域提议网络)想出可能包含对象的潜在区域,然后我们可以使用锚框、NMS(非最大抑制)和 IOU 的概念来生成相关的框并识别对象。虽然这些概念起作用,但推理需要一些时间,因此由于其复杂性,无法实现高精度的实时使用。
在高层次上,这使用 CNN,然后使用一个变压器来检测一个对象,它通过一个二分匹配训练对象来检测。这就是它如此简单的主要原因。

来源-https://arxiv.org/pdf/2005.12872.pdf

步骤 1:
我们将图像通过卷积神经网络编码器,因为 CNN 处理图像效果最好。所以通过 CNN 后,图像特征是守恒的。这是具有更多特征通道的图像的高阶表示。
步骤 2:
图像的这种丰富的特征图被提供给变压器编码器-解码器,其输出盒子预测的集合。这些框中的每一个都由一个元组组成。元组将是一个类和一个边界框。注意:这也包括空类或无类及其位置。
现在,这是一个真正的问题,因为在注释中没有对象类被注释为空。比较和处理相邻的相似对象是另一个主要问题,在本文中,使用二分匹配损失来解决这个问题。通过比较每个类和边界框与相应的类和框(包括 none 类,比如说 N 个)来比较损失,注释包括添加的不包含任何内容的部分,以使总框数为 N。预测值与实际值的分配是一对一的分配,从而使总损失最小化。有一种非常著名的算法叫做匈牙利法来计算这些最小匹配。
主要部件:

来源-https://arxiv.org/pdf/2005.12872.pdf

从卷积神经网络提取的主干–特征和位置编码被传递
变压器编码器–变压器自然是一个序列处理单元,出于同样的原因,我们输入张量被展平。它将序列转换成一个同样长的特征序列。
转换器解码器–接收对象查询,因此它是一个解码器,作为调节信息的辅助输入。
预测前馈网络(FFN)–这方面的输出通过一个分类器,该分类器输出前面讨论过的类标签和包围盒输出
评估器:
评估是在 COCO 数据集上完成的,其主要竞争对手是已经统治该类别一段时间的 RCNN 家族,被认为是最经典的对象检测技术。

来源-https://arxiv.org/pdf/2005.12872.pdf

优势:T2】


  • 这个新模型非常简单,你不需要安装任何库来使用它。

  • DETR 在大型物体上表现出明显更好的性能,而不是在可以进一步改进的小型物体上。

  • 好消息是,他们甚至在论文中提供了代码,所以现在我们也将实现它,以了解它真正能够做什么。

代码:

Python 3

# Write Python3 code here
mport torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
  super().__init__()
  # We take only convolutional layers from ResNet-50 model
  self.backbOne= nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
  self.cOnv= nn.Conv2d(2048, hidden_dim, 1)
  self.transformer = nn.Transformer(hidden_dim, heads,
  num_encoder_layers, num_decoder_layers)
  self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
  self.linear_bbox = nn.Linear(hidden_dim, 4)
  self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
  self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
  self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
  def forward(self, inputs):
  x = self.backbone(inputs)
  h = self.conv(x)
  H , W = h.shape[-2:]
  pos = torch.cat([
  self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
  self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
  h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
  self.query_pos.unsqueeze(1))
  return self.linear_class(h), self.linear_bbox(h).sigmoid()
detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)
Listing 1: DETR PyTorch inference code. For clarity, it uses learned positional encodings in the encoder instead of fixed, and positional encodings are added to the input
only instead of at each transformer layer. Making these changes requires going beyond
PyTorch implementation of transformers, which hampers readability. The entire code
to reproduce the experiments will be made available before the conference.

我们只从 ResNet-50 模型中提取卷积层
代码取自论文

代码:试着在 colab 上运行这个代码,或者直接进入这个链接,复制并运行完整的文件。

Python 3

import torch as th
import torchvision.transforms as T
import requests
from PIL import Image, ImageDraw, ImageFont

我们将使用 ResNet 101 作为主干架构,我们将直接从 Pytorch Hub 加载该架构。
代号:

Python 3

model = th.hub.load('facebookresearch/detr', 'detr_resnet101', pretrained=True)
model.eval()
model = model.cuda()

Python 3

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

在此输入图像的网址。我用过的是 https://i.ytimg.com/vi/vrlX3cwr3ww/maxresdefault.jpg
T1【代号:T3】

Python 3

url = input()

显示图像

Python 3

img = Image.open(requests.get(url, stream=True).raw).resize((800,600)).convert('RGB')
img

代码:

Python 3

img_tens = transform(img).unsqueeze(0).cuda()
with th.no_grad():
  output = model(img_tens)
draw = ImageDraw.Draw(img)
pred_logits=output['pred_logits'][0][:, :len(CLASSES)]
pred_boxes=output['pred_boxes'][0]
max_output = pred_logits.softmax(-1).max(-1)
topk = max_output.values.topk(15)
pred_logits = pred_logits[topk.indices]
pred_boxes = pred_boxes[topk.indices]
pred_logits.shape

代码:

Python 3

for logits, box in zip(pred_logits, pred_boxes):
  cls = logits.argmax()
  if cls >= len(CLASSES):
    continue
  label = CLASSES[cls]
  print(label)
  box = box.cpu() * th.Tensor([800, 600, 800, 600])
  x, y, w, h = box
  x0, x1 = x-w//2, x+w//2
  y0, y1 = y-h//2, y+h//2
  draw.rectangle([x0, y0, x1, y1], outline='red', white')

代码:显示检测到的图像

Python 3

img

这是 colab 笔记本和 github 代码的链接。另外,请随意查看官方 GitHub 了解相同的
缺点:
训练需要很长时间。它在 8 个图形处理器上训练了 6 天。当你把它和这种规模的语言模型进行比较时,这并不重要,因为它们使用了一个转换器,但是仍然。


推荐阅读
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 如何自行分析定位SAP BSP错误
    The“BSPtag”Imentionedintheblogtitlemeansforexamplethetagchtmlb:configCelleratorbelowwhichi ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 本文分享了一个关于在C#中使用异步代码的问题,作者在控制台中运行时代码正常工作,但在Windows窗体中却无法正常工作。作者尝试搜索局域网上的主机,但在窗体中计数器没有减少。文章提供了相关的代码和解决思路。 ... [详细]
  • 目录实现效果:实现环境实现方法一:基本思路主要代码JavaScript代码总结方法二主要代码总结方法三基本思路主要代码JavaScriptHTML总结实 ... [详细]
  • Python语法上的区别及注意事项
    本文介绍了Python2x和Python3x在语法上的区别,包括print语句的变化、除法运算结果的不同、raw_input函数的替代、class写法的变化等。同时还介绍了Python脚本的解释程序的指定方法,以及在不同版本的Python中如何执行脚本。对于想要学习Python的人来说,本文提供了一些注意事项和技巧。 ... [详细]
  • ZSI.generate.Wsdl2PythonError: unsupported local simpleType restriction ... [详细]
  • 学习SLAM的女生,很酷
    本文介绍了学习SLAM的女生的故事,她们选择SLAM作为研究方向,面临各种学习挑战,但坚持不懈,最终获得成功。文章鼓励未来想走科研道路的女生勇敢追求自己的梦想,同时提到了一位正在英国攻读硕士学位的女生与SLAM结缘的经历。 ... [详细]
  • Nginx使用(server参数配置)
    本文介绍了Nginx的使用,重点讲解了server参数配置,包括端口号、主机名、根目录等内容。同时,还介绍了Nginx的反向代理功能。 ... [详细]
  • Metasploit攻击渗透实践
    本文介绍了Metasploit攻击渗透实践的内容和要求,包括主动攻击、针对浏览器和客户端的攻击,以及成功应用辅助模块的实践过程。其中涉及使用Hydra在不知道密码的情况下攻击metsploit2靶机获取密码,以及攻击浏览器中的tomcat服务的具体步骤。同时还讲解了爆破密码的方法和设置攻击目标主机的相关参数。 ... [详细]
  • 本文介绍了Oracle数据库中tnsnames.ora文件的作用和配置方法。tnsnames.ora文件在数据库启动过程中会被读取,用于解析LOCAL_LISTENER,并且与侦听无关。文章还提供了配置LOCAL_LISTENER和1522端口的示例,并展示了listener.ora文件的内容。 ... [详细]
  • 本文介绍了多因子选股模型在实际中的构建步骤,包括风险源分析、因子筛选和体系构建,并进行了模拟实证回测。在风险源分析中,从宏观、行业、公司和特殊因素四个角度分析了影响资产价格的因素。具体包括宏观经济运行和宏经济政策对证券市场的影响,以及行业类型、行业生命周期和行业政策对股票价格的影响。 ... [详细]
  • 关于我们EMQ是一家全球领先的开源物联网基础设施软件供应商,服务新产业周期的IoT&5G、边缘计算与云计算市场,交付全球领先的开源物联网消息服务器和流处理数据 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
author-avatar
江苏经贸学院
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有