论文地址:https://arxiv.org/abs/2206.09959https://arxiv.org/abs/2206.09959
代码地址:
GitHub - NVlabs/GCVit: Official PyTorch implementation of Global Context Vision TransformersOfficial PyTorch implementation of Global Context Vision Transformers - GitHub - NVlabs/GCVit: Official PyTorch implementation of Global Context Vision Transformershttps://github.com/NVlabs/GCViT
作者提出了全局上下文Vision Transformer(GCViT),这是一种提高参数和计算利用率的新架构。提出的方法利用全局上下文自注意模块,与局部自注意相结合,有效地建模长期和短期空间交互,而不需要昂贵的操作。
在这项工作中,作者引入Global Context(GC)ViT 网络。提出了一个由局部和全局自注意模块组成的分层ViT架构。在每个阶段,作者使用修改的fused inverted residual 模块来计算全局query token。作者称之为Fused-MBConv 模块,它包含来自不同图像区域的全局上下文信息。本地自注意模块负责对short-range信息进行建模,而全局query token 在所有全局自注意模块之间共享,以与本地key和value进行交互。
论文主要贡献:
- 一种新的分层Transformer模型,称为GCViT,它可以作为各种计算机视觉任务的通用主干网络,如分类、检测、实例分割;
- 一种新颖而简单的设计,由全局自注意和令牌生成模块组成,允许通过捕获全局上下文信息来建模长期依赖关系,从而消除了对高度复杂或复杂操作的需要;
- 如图1所示,实验结果SOTA。
网络结构:
GC ViT 结构
网络结构如图2所示。与之前的一些Transformer架构类似,使用一个层次框架,通过减少空间维度,同时扩大嵌入维数,分别获得几个分辨率(称为阶段)的特征表示。
首先输入图像的分辨率为H X W X 3 ,通过应用一个3×3的卷积层和适当的填充来获得重叠(overlapping)的patch。然后将patch投影到C维嵌入空间中。每个GCViT阶段都由交替的局部和全局自注意模块组成,以提取空间特征。两者都在像Swin Transformer 这样的 local windows 中运行,然而,全局自注意访问 Global Toke Generator (GTG)提取的全局特征。GTG是一个类似于cnn的模块,它在每个阶段只从整个图像中提取一次特征。每个阶段后增加一个下采样模块,空间分辨率减少一半。生成的特征通过平均池化和线性层传递,以为下游任务创建嵌入。
Downsampling
从CNN模型中借用了空间特征收缩的概念,该模型在降维的同时施加了局部性偏差和跨通道通信。
class ReduceSize(nn.Module):def __init__(self, dim,norm_layer=nn.LayerNorm,keep_dim=False):super().__init__()self.conv = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1,groups=dim, bias=False),nn.GELU(),SE(dim, dim),nn.Conv2d(dim, dim, 1, 1, 0, bias=False),)if keep_dim:dim_out = dimelse:dim_out = 2*dimself.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False)self.norm2 = norm_layer(dim_out)self.norm1 = norm_layer(dim)def forward(self, x):x = x.contiguous()x = self.norm1(x)x = x.permute(0, 3, 1, 2)x = x + self.conv(x)x = self.reduction(x).permute(0, 2, 3, 1)x = self.norm2(x)return x
Attention
多头自注意是GCViT体系结构中从图像中提取语义信息的核心计算算子。GCViT由局部和全局的自注意模块组成,如图4所示。
Global Query Generator
作者提出包含跨整个输入特征图的信息的全局查询标记(global query tokens),以便与局部键keys和值values特征进行交互。如图5所示:
class FeatExtract(nn.Module):def __init__(self, dim, keep_dim=False):super().__init__()self.conv = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1,groups=dim, bias=False),nn.GELU(),SE(dim, dim),nn.Conv2d(dim, dim, 1, 1, 0, bias=False),)if not keep_dim:self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.keep_dim = keep_dimdef forward(self, x):x = x.contiguous()x = x + self.conv(x)if not self.keep_dim:x = self.pool(x)return x
Global Self-Attention
图4展示了这篇论文主要贡献的思想。
Local MSA:
class WindowAttention(nn.Module):def __init__(self,dim,num_heads,window_size,qkv_bias=True,qk_scale=None,attn_drop=0.,proj_drop=0.,):super().__init__()window_size = (window_size,window_size)self.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))coords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten = torch.flatten(coords, 1)relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]relative_coords = relative_coords.permute(1, 2, 0).contiguous()relative_coords[:, :, 0] += self.window_size[0] - 1relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)self.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)def forward(self, x, q_global):B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()attn = attn + relative_position_bias.unsqueeze(0)attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return
Global MSA:
class WindowAttentionGlobal(nn.Module):def __init__(self, dim,num_heads,window_size,qkv_bias=True,qk_scale=None,attn_drop=0.,proj_drop=0.,):super().__init__()window_size = (window_size,window_size)self.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))coords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten = torch.flatten(coords, 1)relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]relative_coords = relative_coords.permute(1, 2, 0).contiguous()relative_coords[:, :, 0] += self.window_size[0] - 1relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)self.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)def forward(self, x, q_global):B_, N, C = x.shapeB = q_global.shape[0]kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)k, v = kv[0], kv[1]q_global = q_global.repeat(B_//B, 1, 1, 1)q = q_global.reshape(B_, self.num_heads, N, C // self.num_heads)q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()attn = attn + relative_position_bias.unsqueeze(0)attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return x
几种模型结构配置:
实验结果: