一共分为5个步骤,
在介绍具体的代码之前,有几个重要的变量解释如下:
变量名 | 含义 | Shape |
conv_features | Backbone最后一层特征图 | [1,2048,25,34] |
enc_attn_weights | 编码器最后一层的self_attn weights | [1,850,850] |
dec_attn_weights | 解码器最后一层的cross_attn weights | [1,100,850] |
memory | 编码器的输出/解码器的输入特征 | [850,1,256] |
cq | 解码器最后一层self_attn的输出 | [100,1,256] |
pk | 位置编码 | [1,256,25,34] |
pq | 训练好的object queries,即query_embed | [100,256] |
in_proj_weight | 解码器最后一层cross_attn中q和k的线性权重 | [768,256] |
in_proj_bias | 解码器最后一层cross_attn中q和k的偏置 | [768] |
每个步骤的代码如下:
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import requests
import matplotlib.pyplot as pltimport torch
import torchvision.transforms as T
from torch.nn.functional import linear,softmax
torch.set_grad_enabled(False)def box_cxcywh_to_xyxy(x):x_c, y_c, w, h = x.unbind(1)b = [(x_c - 0.5 * w), (y_c - 0.5 * h),(x_c + 0.5 * w), (y_c + 0.5 * h)]return torch.stack(b, dim=1)def rescale_bboxes(out_bbox, size):img_w, img_h = sizeb = box_cxcywh_to_xyxy(out_bbox)b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)return b# COCO classes
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A','stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack','umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis','snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove','skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass','cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich','orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake','chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A','N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard','cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A','book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier','toothbrush'
]
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]# standard PyTorch mean-std input image normalization
transform = T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# ----------------------------------------------1. 加载模型及获取训练好的参数---------------------------------------------------
# 加载线上的模型
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()
# 获取训练好的参数
for name, parameters in model.named_parameters():# 获取训练好的object queries,即pq:[100,256]if name == 'query_embed.weight':pq = parameters# 获取解码器的最后一层的交叉注意力模块中q和k的线性权重和偏置:[256*3,256],[768]if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_weight':in_proj_weight = parametersif name == 'transformer.decoder.layers.5.multihead_attn.in_proj_bias':in_proj_bias = parameters
# --------------------------------------------2.下载图像并进行预处理和前馈过程--------------------------------------------------
# 线上下载图像
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)
# img_path = '/home/wujian/000000039769.jpg'
# im = Image.open(img_path)# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)# propagate through the model
outputs = model(img)# keep only predictions with 0.7+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
# ------------------------------------------------3. 准备存储前馈该图片时的值---------------------------------------------------
# use lists to store the outputs via up-values
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
cq = [] # 存储detr中的 cq
pk = [] # 存储detr中的 encoder pos
memory = [] # 编码器最后一层的输入/解码器的输入特征# 注册hook
hooks = [# 获取resnet最后一层特征图model.backbone[-2].register_forward_hook(lambda self, input, output: conv_features.append(output)),# 获取encoder的图像特征图memorymodel.transformer.encoder.register_forward_hook(lambda self, input, output: memory.append(output)),# 获取encoder的最后一层layer的self-attn weightsmodel.transformer.encoder.layers[-1].self_attn.register_forward_hook(lambda self, input, output: enc_attn_weights.append(output[1])),# 获取decoder的最后一层layer中交叉注意力的 weightsmodel.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(lambda self, input, output: dec_attn_weights.append(output[1])),# 获取decoder最后一层self-attn的输出cqmodel.transformer.decoder.layers[-1].norm1.register_forward_hook(lambda self, input, output: cq.append(output)),# 获取图像特征图的位置编码pkmodel.backbone[-1].register_forward_hook(lambda self, input, output: pk.append(output)),
]# propagate through the model
outputs = model(img)# 用完的hook后删除
for hook in hooks:hook.remove()# don't need the list anymore
conv_features = conv_features[0] # [1,2048,25,34]
enc_attn_weights = enc_attn_weights[0] # [1,850,850] : [N,L,S]
dec_attn_weights = dec_attn_weights[0] # [1,100,850] : [N,L,S] --> [batch, tgt_len, src_len]
memory = memory[0] # [850,1,256] # 编码器最后一层的输入/解码器的输入特征cq = cq[0] # decoder的self_attn:最后一层输出[100,1,256]
pk = pk[0] # [1,256,25,34]
这里求attn_output_weigths的关键步骤为:
q=cq+pq
k=pk
q=linear(q, in_proj_weight, in_proj_bias)
k=linear(k, in_proj_weight, in_proj_bias)
attn_ouput_weights=torch.bmm(q,k) #[1,8,100,850]分别为8个head的注意力值
# ----------------------------------------4, 求attn_output_weights以绘制各个head的注意力权重------------------------------------
pk = pk.flatten(-2).permute(2,0,1) # [1,256,850] --> [850,1,256]
pq = pq.unsqueeze(1).repeat(1,1,1) # [100,1,256]
q = pq + cqk = pk# 将q和k完成线性层的映射,代码参考自nn.MultiHeadAttn()
_b = in_proj_bias
_start = 0
_end = 256
_w = in_proj_weight[_start:_end, :]
if _b is not None:_b = _b[_start:_end]
q = linear(q, _w, _b)_b = in_proj_bias
_start = 256
_end = 256 * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:_b = _b[_start:_end]
k = linear(k, _w, _b)scaling = float(256) ** -0.5
q = q * scaling
q = q.contiguous().view(100, 8, 32).transpose(0, 1)
k = k.contiguous().view(-1, 8, 32).transpose(0, 1)
attn_output_weights = torch.bmm(q, k.transpose(1, 2))attn_output_weights = attn_output_weights.view(1, 8, 100, 850)
attn_output_weights = attn_output_weights.view(1 * 8, 100, 850)
attn_output_weights = softmax(attn_output_weights, dim=-1)
attn_output_weights = attn_output_weights.view(1, 8, 100, 850)# 后续可视化各个头
attn_every_heads = attn_output_weights # [1,8,100,850]
attn_output_weights = attn_output_weights.sum(dim=1) / 8 # [1,100,850]
# ----------------------------------------------------------5. 画图---------------------------------------------------------
h, w = conv_features['0'].tensors.shape[-2:]fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=10, figsize=(22, 28)) # [11,2]
colors = COLORS * 100# 可视化
for idx, ax_i, (xmin, ymin, xmax, ymax) in zip(keep.nonzero(), axs.T, bboxes_scaled):# 可视化decoder的注意力权重ax = ax_i[0]ax.imshow(dec_attn_weights[0, idx].view(h, w))ax.axis('off')ax.set_title(f'query id: {idx.item()}',fontsize = 30)# 可视化框和类别ax = ax_i[1]ax.imshow(im)ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color='blue', linewidth=3))ax.axis('off')ax.set_title(CLASSES[probas[idx].argmax()],fontsize = 30)# 分别可视化8个头部的位置特征图for head in range(2, 2 + 8):ax = ax_i[head]ax.imshow(attn_every_heads[0, head-2, idx].view(h,w))ax.axis('off')ax.set_title(f'head:{head-2}',fontsize = 30)fig.tight_layout() # 自动调整子图来使其填充整个画布
plt.show()
[注]:以上代码来自网络
可视化结果:
其中第一行的图就是用dec_attn_weights画出来的
下面是8个head的可视化结果图,由attn_ouput_weights绘制