计算机视觉中的Transformer

Transformer在CV中的全面渗透
DETR CLIP SAM 目标检测 多模态 视频Transformer Mask2Former Segment Anything

一、概述:从NLP到CV的架构革命

2017年Google团队在"Attention Is All You Need"中提出Transformer架构,最初用于机器翻译。其核心创新——自注意力机制(Self-Attention)通过计算序列中每个位置与其他位置的注意力权重来捕获长距离依赖关系,彻底取代了RNN的时序递推结构。

2020年,Google提出Vision Transformer(ViT),将图像切分为固定大小的patch序列,直接用标准Transformer编码器处理,在ImageNet上取得了与SOTA CNN相当的结果,标志着Transformer正式进入CV领域。此后,Transformer在检测、分割、跟踪、视频理解、多模态等各细分方向全面开花,形成了"transformerless, no SOTA"的格局。

Transformer核心优势为何能迁移到CV?

  • 全局感受野: CNN受限于局部卷积核,Transformer的self-attention一次建模整个图像或特征图
  • 动态权重: 卷积核参数固定,注意力权重随输入动态变化,表达能力更强
  • 统一架构: 检测、分割、跟踪可以用相同的编码器-解码器范式表达,简化pipeline
  • 规模化潜力: Transformer参数量和数据量增长时性能持续提升(scaling law)

本笔记系统梳理Transformer在各CV子方向中的代表性工作,涵盖检测、跟踪、分割、视频、多模态及高效化部署六大方面,每个方向均附核心代码片段以加深理解。

二、检测Transformer:DETR系列

2.1 DETR:目标检测即集合预测

2020年Facebook提出的DETR(Detection Transformer)是第一个将Transformer应用于目标检测的工作。其核心思想是将检测视为一个集合预测(set prediction)问题:模型直接输出一个固定大小的预测集合,通过二分图匹配与GT集合对齐,完全摒弃了锚框(Anchor)、非极大值抑制(NMS)等传统检测pipeline。

核心架构组件

import torch import torch.nn as nn from torch.nn import TransformerEncoder, TransformerDecoder class DETR(nn.Module): def __init__(self, num_classes=91, num_queries=100, d_model=256): super().__init__() # CNN Backbone: 提取特征 self.backbone = nn.Sequential( nn.Conv2d(3, 64, 7, 2, 3), nn.ReLU(), nn.Conv2d(64, 128, 3, 2, 1), nn.ReLU(), nn.Conv2d(128, d_model, 3, 2, 1), nn.ReLU(), ) # Transformer Encoder encoder_layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True) self.encoder = TransformerEncoder(encoder_layer, num_layers=6) # Transformer Decoder decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=8, batch_first=True) self.decoder = TransformerDecoder(decoder_layer, num_layers=6) # 可学习的object queries self.query_embed = nn.Embedding(num_queries, d_model) # 预测头 self.class_head = nn.Linear(d_model, num_classes + 1) # +1 for no-object self.bbox_head = nn.Sequential( nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 4) ) def forward(self, x): # x: (B, 3, H, W) features = self.backbone(x) # (B, d, h, w) B, D, H, W = features.shape features = features.flatten(2).transpose(1, 2) # (B, hw, d) # Encoder memory = self.encoder(features) # (B, hw, d) # Decoder: object queries 与 encoder 特征交互 queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1) tgt = torch.zeros_like(queries) # 初始decoder输入为0 outputs = self.decoder(tgt, memory, query_pos=queries) # (B, N, d) # 预测 class_logits = self.class_head(outputs) # (B, N, 92) bbox_coords = self.bbox_head(outputs).sigmoid() # (B, N, 4) return {'pred_logits': class_logits, 'pred_boxes': bbox_coords}

2.2 匈牙利匹配损失与二分图匹配

DETR最大的创新之一是使用匈牙利算法(Hungarian Algorithm)求解预测集合与GT集合之间的最优二分图匹配。这是一个组合优化问题:给定N个预测和M个GT(M ≤ N),找到使总匹配代价最小的双射。

from scipy.optimize import linear_sum_assignment def hungarian_matching(pred_logits, pred_boxes, gt_labels, gt_boxes): """ pred_logits: (N, C) N=100个预测, C=类别数 pred_boxes: (N, 4) [cx, cy, w, h] 归一化 gt_labels: (M,) GT类别 gt_boxes: (M, 4) GT边框 returns: indices 匹配索引对 """ N, C = pred_logits.shape M = len(gt_labels) cost_matrix = torch.zeros((N, M), device=pred_logits.device) for i in range(N): for j in range(M): # 类别代价: 预测为gt类别的负对数概率 cls_cost = -pred_logits[i, gt_labels[j]] # L1 bbox代价 l1_cost = torch.sum(torch.abs(pred_boxes[i] - gt_boxes[j])) # GIoU代价 giou_cost = -compute_giou(pred_boxes[i], gt_boxes[j]) cost_matrix[i, j] = cls_cost + l1_cost + giou_cost # 匈牙利算法求解最优匹配 row_ind, col_ind = linear_sum_assignment(cost_matrix.cpu().numpy()) return row_ind, col_ind # 预测索引 -> GT索引

匹配完成后,损失函数由三部分组成:类别交叉熵损失(匹配上的预测)、L1边框回归损失GIoU损失。未匹配上的预测被分类为"no object"(背景类)。

2.3 DETR系列演进

模型 年份 核心改进 关键贡献
DETR 2020 集合预测+匈牙利匹配 首个端到端检测Transformer
Deformable DETR 2021 可变形注意力 收敛速度快10倍,小目标性能提升
DAB-DETR 2022 动态Anchor Boxes queries显式表示为4D anchor框
DN-DETR 2022 去噪训练 添加带噪声GT加速训练收敛
DINO 2023 混合查询选择+去噪+对比学习 SOTA性能,COCO 64.5 AP
RT-DETR 2024 实时DETR 端到端实时检测,媲美YOLO速度

DETR系列的核心设计哲学

  • Anchor-free: 不需要手动设计锚框大小和比例,object queries自动学习
  • 无NMS: 匈牙利匹配天然去重,一个GT只匹配一个预测,无需后处理
  • 编码器-解码器: 编码器负责全局上下文建模,解码器负责从全局中提取特定物体信息
  • 可学习queries: 每个query可被理解为"在图像中找某类物体"的隐式指令
# Deformable DETR 可变形注意力关键实现 class DeformableAttention(nn.Module): """只在参考点周围采样K个偏移位置, 大大降低注意力复杂度 O(HW) -> O(K)""" def __init__(self, d_model=256, n_heads=8, n_points=4): super().__init__() self.n_heads = n_heads self.n_points = n_points # 学习偏移量 self.offset_conv = nn.Linear(d_model, n_heads * n_points * 2) # 学习注意力权重 self.attn_linear = nn.Linear(d_model, n_heads * n_points) def forward(self, query, reference_points, value): # query: (B, Nq, d), reference_points: (B, Nq, 2) 归一化坐标 # value: (B, Nv, d) 通常是encoder输出 offsets = self.offset_conv(query).view(B, Nq, self.n_heads, self.n_points, 2) sampled_points = reference_points[:, :, None, None, :] + offsets # 双线性插值采样特征 sampled_features = bilinear_sample(value, sampled_points) attn_weights = self.attn_linear(query).softmax(dim=-1) output = (sampled_features * attn_weights).sum(dim=-2) return output # (B, Nq, d) 复杂度 O(Nq * K) 而非 O(Nq * Nv)

三、跟踪Transformer:TrackFormer与MOTR

3.1 从检测到跟踪:时间维度的扩展

多目标跟踪(MOT)的传统范式是检测后跟踪(tracking-by-detection):先用检测器逐帧检测物体,再用关联算法(如DeepSORT中的匈牙利匹配+卡尔曼滤波)将检测结果关联成轨迹。这种方法流水线复杂、需手动调参且错误会累积。

TrackFormerMOTR将跟踪视为时间上的集合预测,在DETR的基础上引入track queries,在帧间传递物体状态的隐式表示,实现了真正的端到端联合检测与跟踪。

3.2 TrackFormer:检测+跟踪合一

TrackFormer的核心思想是:每一帧的decoder中同时存在两类queries——检测queries(detection queries)用于发现新出现的物体,跟踪queries(track queries)从上一帧传播而来,负责跟踪已有物体。这种设计让检测和跟踪在一个统一的transformer框架内完成。

class TrackFormer(nn.Module): def __init__(self, num_classes=1, num_queries=300, d_model=256): super().__init__() self.backbone = resnet50() self.transformer = Transformer(d_model=d_model) self.query_embed = nn.Embedding(num_queries, d_model) # 检测queries self.class_head = nn.Linear(d_model, num_classes + 1) self.bbox_head = MLP(d_model, d_model, 4, 3) @torch.no_grad() def init_track_queries(self, prev_outputs): """从上帧输出的预测中初始化track queries""" track_queries = prev_outputs['queries'][prev_outputs['exists']] return track_queries def forward_single_frame(self, x, track_queries=None): features = self.backbone(x) # Encoder memory = self.transformer.encoder(features) # Decoder: 检测queries + 跟踪queries 合并输入 detect_queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1) if track_queries is not None: queries = torch.cat([detect_queries, track_queries], dim=1) else: queries = detect_queries outputs = self.transformer.decoder(queries, memory) logits = self.class_head(outputs) boxes = self.bbox_head(outputs).sigmoid() return logits, boxes, outputs # outputs作为下一帧的track queries

3.3 MOTR:时间查询传递机制

MOTR进一步形式化了时间查询传递(Temporal Query Propagation)机制,提出了tracklet-aware标签分配策略。在训练时,将同一轨迹的所有帧作为一个单元处理,每帧的track queries从前一帧传播,并通过匈牙利匹配保持identity一致性。

MOTR的帧间关联核心流程

  1. 第t帧: Decoder输出包含N个预测(新检测 + 已有跟踪目标)
  2. Query更新: 对匹配到GT的预测,用其decoder输出更新对应track query
  3. Query传播: 更新后的track queries传入第t+1帧decoder
  4. 身份保持: 同一track query在不同帧对应同一物体,天然保持ID一致性
  5. 轨迹终止: 当track query连续多帧预测为"no object"时终止轨迹
class MOTRDecoder(nn.Module): def forward(self, tgt, memory, track_queries, query_pos=None): """ tgt: 当前帧decoder输入 (初始为0) memory: encoder输出的记忆特征 track_queries: 从上一帧传播过来的跟踪查询 """ # 交叉注意力层 for layer in self.layers: # Self-attention: queries之间交互,抑制重复检测 tgt = self.self_attn(tgt + query_pos) # Cross-attention: queries 查询 encoder 特征 tgt = self.cross_attn(tgt, memory, query_pos) # FFN tgt = self.ffn(tgt) # 输出tgt作为新的queries,其中匹配到物体的部分会更新track queries new_track_queries = torch.where( self.matched_mask.unsqueeze(-1), tgt[:, :num_tracks], # 匹配上的用新输出更新 track_queries # 未匹配上的保持 ) return tgt, new_track_queries

跟踪Transformer的核心优势

  • 端到端: 检测+关联在一个网络中完成,无需独立卡尔曼滤波或ReID分支
  • 隐式关联: Track queries隐式编码了物体的外观和运动信息
  • 长期记忆: Query跨帧传播天然具有记忆能力,可处理遮挡和短暂消失
  • 联合优化: 检测和跟踪损失一起反向传播,避免pipeline错误累积

四、分割Transformer:MaskFormer与Mask2Former

4.1 MaskFormer:通用分割的统一范式

传统的语义分割、实例分割和全景分割使用不同的模型架构。2021年提出的MaskFormer首次将三种分割任务统一为掩码分类(mask classification)范式:模型输出一组二进制掩码和对应的类别预测,通过简单的后处理即可支持任意分割任务。

MaskFormer架构三阶段

  1. 像素级特征提取: CNN/Transformer backbone提取多尺度特征图
  2. Transformer解码器: N个object queries通过交叉注意力从特征图中解码出N个掩码嵌入
  3. 掩码生成: 每个掩码嵌入与像素特征做点积,经sigmoid得到二进制掩码
class MaskFormer(nn.Module): def __init__(self, num_queries=100, d_model=256, num_classes=133): super().__init__() self.backbone = resnet50() self.mask_embed_head = nn.Sequential( nn.Conv2d(d_model, d_model, 3, 1, 1), nn.ReLU(), nn.Conv2d(d_model, d_model, 1) ) self.query_embed = nn.Embedding(num_queries, d_model) decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=8, batch_first=True) self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) self.class_head = nn.Linear(d_model, num_classes) def forward(self, x): # 1. 像素特征 pixel_feats = self.backbone(x) # (B, d, H, W) mask_embeds = self.mask_embed_head(pixel_feats) # (B, d, H, W) # 2. Transformer解码 B, D, H, W = pixel_feats.shape pixel_feats_flat = pixel_feats.flatten(2).permute(2, 0, 1) # (HW, B, d) queries = self.query_embed.weight.unsqueeze(1).expand(-1, B, -1) # (N, B, d) tgt = torch.zeros_like(queries) decoder_out = self.decoder(tgt, pixel_feats_flat, memory_key_padding_mask=None, query_pos=queries) # (N, B, d) # 3. 掩码和类别预测 decoder_out = decoder_out.transpose(0, 1) # (B, N, d) class_logits = self.class_head(decoder_out) # (B, N, C) # 掩码: 每个query的嵌入与像素特征点积 mask_logits = torch.einsum('bnd,bdhw->bnhw', decoder_out, mask_embeds) # (B, N, H, W) return {'class_logits': class_logits, 'mask_logits': mask_logits}

4.2 Mask2Former:掩码注意力与逐层解码

Mask2Former在MaskFormer基础上引入了掩码注意力(Masked Attention)逐层解码(per-layer decoding),显著提升了分割质量和训练效率。

掩码注意力(Masked Attention)

传统Transformer解码器的交叉注意力中,每个query关注所有像素位置。而Mask2Former的掩码注意力将query的注意力范围限制在其预测掩码的前景区域内,使query专注于其负责的物体区域,避免背景干扰。该机制使得可以使用高分辨率特征图(1/8或1/4分辨率)而不会带来过高的计算量。

class MaskedCrossAttention(nn.Module): """掩码交叉注意力:仅在前景区域计算注意力""" def forward(self, query, key, value, mask_pred): """ query: (B, N, d) N个query key: (B, HW, d) 像素特征 mask_pred: (B, N, H, W) 当前层预测的掩码logits """ B, N, d = query.shape H, W = mask_pred.shape[-2:] # 将mask_pred展平并sigmoid得到注意力掩码 attn_mask = mask_pred.view(B, N, -1).sigmoid() # (B, N, HW) # 只保留top-k位置(稀疏注意力) _, topk_idx = attn_mask.topk(k=int(0.3 * H * W), dim=-1) # 前30%位置 # 收集key和value中被选中的位置 key_sampled = gather_topk(key, topk_idx) # (B, N, K, d) value_sampled = gather_topk(value, topk_idx) # (B, N, K, d) # 标准注意力计算 attn_weights = torch.softmax(query @ key_sampled / d**0.5, dim=-1) output = attn_weights @ value_sampled # (B, N, d) return output

逐层解码与多尺度特征

Mask2Former采用多尺度特征金字塔(1/32, 1/16, 1/8分辨率),在每一层decoder中轮流使用不同尺度的特征。每个decoder层的输出不仅用于预测当前层的掩码,还作为下一层query的输入,形成逐层精化的效果。

特性 MaskFormer Mask2Former
特征分辨率 单尺度 1/32 多尺度 1/32, 1/16, 1/8
注意力机制 标准交叉注意力 掩码注意力(前30%位置)
解码器层数 6层 3层(带逐层预测)
训练效率 收敛慢 快3倍
语义分割mIoU 53.9(ADE20K) 57.8(ADE20K)
全景分割PQ 46.5(COCO) 53.6(COCO)

统一分割架构的意义

  • 语义/实例/全景统一: 同一模型、同一权重可处理所有分割任务
  • 掩码分类范式: N个二进制掩码+类别预测,比逐像素分类更灵活
  • 掩码注意力: 将分割结果反馈到注意力计算中,形成自修正闭环
  • 少样本友好: 新类别只需增加输出类数,无需修改架构

五、视频Transformer:时空建模新范式

5.1 TimeSformer:分离时空注意力

视频理解的核心挑战是同时建模空间(帧内)和时间(帧间)依赖。2021年Facebook提出的TimeSformer(Time-Space Transformer)提出分离时空注意力(Divided Space-Time Attention),将3D注意力分解为两个子层:空间注意力(同一帧内patch之间的交互)和时间注意力(不同帧同一位置patch之间的交互)。

class TimeSformerBlock(nn.Module): """分离时空注意力块""" def __init__(self, d_model=768, nhead=12): super().__init__() # 空间注意力: 在同一帧内交互 self.space_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) # 时间注意力: 在不同帧的相同空间位置交互 self.time_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.ffn = nn.Sequential( nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model) ) def forward(self, x): """ x: (B, T, N, d) T帧, 每帧N个patch """ B, T, N, d = x.shape # === 空间注意力:在N维度上计算(不同patch之间)=== x_space = x.reshape(B * T, N, d) # (B*T, N, d) x_space = self.norm1(x_space + self.space_attn(x_space, x_space, x_space)[0]) # === 时间注意力:在T维度上计算(不同帧之间)=== x_time = x_space.view(B, T, N, d) x_time = x_time.permute(0, 2, 1, 3) # (B, N, T, d) x_time = x_time.reshape(B * N, T, d) # (B*N, T, d) x_time = self.norm2(x_time + self.time_attn(x_time, x_time, x_time)[0]) # === FFN === x_time = x_time.reshape(B, N, T, d).permute(0, 2, 1, 3) x_out = x_time + self.ffn(x_time) # (B, T, N, d) return x_out

5.2 ViViT:因子分解编码器

Google的ViViT(Video Vision Transformer)探索了多种视频tokenization和因子分解策略。其核心思路是将视频的时空维度分解处理,减少自注意力的计算量。ViViT提出了四种模型变体,其中效果最好的是因子分解编码器(Factorised Encoder):先用空间Transformer独立处理每帧,再在帧间使用时间Transformer融合。

5.3 VideoMAE:视频掩码自编码器

VideoMAE(Video Masked Autoencoder)将MAE的思路扩展到视频领域。它采用极高比例的掩码率(90%-95%),只保留少量可见patch,迫使编码器学习强大的时空表征。解码器则负责从可见patch重建被掩码的patch。

class VideoMAE(nn.Module): """Video Masked Autoencoder 核心逻辑""" def __init__(self, encoder, decoder, mask_ratio=0.9): super().__init__() self.encoder = encoder # 只处理可见patch self.decoder = decoder # 处理全部位置(含掩码token) self.mask_ratio = mask_ratio def forward(self, video): """ video: (B, T, C, H, W) """ # 1. 将视频切分为时空patch patches = self.patchify(video) # (B, N, d) N = T * H' * W' N = patches.shape[1] # 2. 随机掩码: 保留 1-mask_ratio 比例的patch num_keep = int(N * (1 - self.mask_ratio)) ids_shuffle = torch.rand(N).argsort() ids_keep, ids_mask = ids_shuffle[:num_keep], ids_shuffle[num_keep:] # 3. 编码器只处理可见patch x_keep = patches[:, ids_keep, :] encoded = self.encoder(x_keep) # (B, num_keep, d_enc) # 4. 解码器: 填充掩码token + 位置编码 mask_tokens = self.mask_token.repeat(B, N - num_keep, 1) x_full = torch.cat([encoded, mask_tokens], dim=1) # (B, N, d_dec) # 按原始顺序排列 x_full = x_full.gather(1, ids_shuffle.argsort().unsqueeze(-1).expand(-1, -1, d)) reconstructed = self.decoder(x_full) # (B, N, p*p*3) # 5. 损失: 只在掩码位置计算MSE loss = F.mse_loss(reconstructed[:, ids_mask], patches[:, ids_mask]) return loss def patchify(self, video): """将视频切分为patch序列""" # 实际实现会包含tubelet embedding pass

TubeViT:Tubelet嵌入

不同于ViViT将每帧独立patch化,TubeViT沿时间维度将相邻帧的对应patch合并为一个tubelet(时空管道),作为基本处理单元。这种设计更好地保留了时空连续性,减少了token数量。例如输入16帧,tubelet size为2×16×16时,帧数维度减少到8个tubelet,计算量减半。

模型 时空建模方式 参数量 Kinetics-400 Top1
TimeSformer-L 分离时空注意力 430M 80.7%
ViViT-L (FE) 因子分解编码器 310M 81.3%
VideoMAE-H 掩码自编码器 640M 84.6%
TubeViT-L Tubelet嵌入+时空注意力 380M 82.1%

六、多模态Transformer:CLIP、BLIP-2与SAM

6.1 CLIP:对比图文预训练

2021年OpenAI提出的CLIP(Contrastive Language-Image Pre-training)是视觉-语言模型的里程碑。它使用4亿图文对进行对比学习:图像编码器和文本编码器分别将图像和文本映射到共享嵌入空间,通过对比损失拉近匹配图文对的距离、推远不匹配对的距离。

class CLIP(nn.Module): """简化版CLIP双塔结构""" def __init__(self, d_embed=512, temp=0.07): super().__init__() self.image_encoder = ViT(embed_dim=d_embed) # 图像编码器 self.text_encoder = TextTransformer(embed_dim=d_embed) # 文本编码器 self.temp = nn.Parameter(torch.tensor(temp)) # 温度系数 def forward(self, images, texts): """ images: (B, 3, H, W) texts: (B, L) tokenized文本 """ # 分别编码 I_feat = self.image_encoder(images) # (B, d) T_feat = self.text_encoder(texts) # (B, d) # L2归一化 I_feat = F.normalize(I_feat, dim=-1) T_feat = F.normalize(T_feat, dim=-1) # 对比学习:计算相似度矩阵 logits = I_feat @ T_feat.T / self.temp # (B, B) 对角线为匹配对 # 对称交叉熵损失 labels = torch.arange(B, device=images.device) loss_i = F.cross_entropy(logits, labels) # image->text loss_t = F.cross_entropy(logits.T, labels) # text->image loss = (loss_i + loss_t) / 2 return loss

零样本分类

CLIP最惊艳的能力是零样本分类(zero-shot classification):无需任何训练数据,只需构造prompt模板(如"a photo of a {类别名}"),将文本描述和图像分别编码,选择相似度最高的类别。在ImageNet上,CLIP的zero-shot准确率达到76.2%,超过有监督的ResNet-50。

def zero_shot_classify(clip_model, image, class_names, device): """CLIP零样本分类""" # 构造文本prompt prompts = [f"a photo of a {name}" for name in class_names] text_tokens = tokenize(prompts).to(device) with torch.no_grad(): # 图像编码 image_features = clip_model.encode_image(image.unsqueeze(0)) # 文本编码 text_features = clip_model.encode_text(text_tokens) # 相似度计算 image_features = F.normalize(image_features, dim=-1) text_features = F.normalize(text_features, dim=-1) logits = (image_features @ text_features.T) * 100.0 probs = logits.softmax(dim=-1) return probs # 概率分布

6.2 ViLT & Flamingo & BLIP-2

ViLT(Vision-Language Transformer)是首个将视觉和文本输入统一为token序列并用单一Transformer处理的模型。它使用线性投影将图像patch映射到与文本embedding相同的空间,使视觉-语言交互发生在transformer的每一层,而非仅在fusion层。

Flamingo(DeepMind, 2022)是面向少样本学习的大规模视觉语言模型。其核心创新是感知器重采样器(Perceiver Resampler),将可变数量的视觉特征压缩为固定数量的token,然后通过门控交叉注意力(Gated Cross-Attention)注入到冻结的语言模型中。

BLIP-2(Salesforce, 2023)提出了Q-Former(Querying Transformer)架构,用一组可学习的query tokens通过交叉注意力从冻结的图像编码器中提取视觉特征,再输入到冻结的大语言模型(如OPT、Vicuna)中。这种模块化设计使得只用少量可训练参数就能桥接视觉和语言模型。

BLIP-2 / Q-Former架构

  1. 冻结的图像编码器: ViT-L/14提取视觉特征
  2. Q-Former: 一组可学习的query tokens(通常32个)通过交叉注意力从ViT输出中提取与文本相关的视觉信息
  3. 冻结的LLM: Q-Former输出的视觉token与文本token拼接后输入大语言模型进行生成
  4. 训练阶段: 先做视觉-语言表示学习(对比+匹配+生成),再做视觉-语言生成学习
class QFormer(nn.Module): """BLIP-2中的Querying Transformer""" def __init__(self, num_queries=32, d_model=768, num_layers=6): super().__init__() self.query_tokens = nn.Parameter(torch.randn(1, num_queries, d_model)) decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=12, batch_first=True) self.decoder = nn.TransformerDecoder(decoder_layer, num_layers) def forward(self, image_features): """ image_features: (B, N_v, d) ViT输出的视觉特征 """ B = image_features.shape[0] queries = self.query_tokens.expand(B, -1, -1) # (B, 32, d) # Q-Former通过交叉注意力从视觉特征中提取信息 visual_queries = self.decoder( queries, image_features, memory_key_padding_mask=None ) # (B, 32, d) return visual_queries # 压缩后的视觉token,供LLM使用

6.3 SAM:Segment Anything

Meta发布的Segment Anything Model(SAM)是CV领域的"GPT-3时刻"。它基于提示分割(Promptable Segmentation)范式,通过一个强大的图像编码器(MAE预训练的ViT-H)轻量的提示编码器+掩码解码器,实现了对任意图像、任意物体的分割能力。

SAM的三组件架构

class SAM(nn.Module): """简化版Segment Anything Model""" def __init__(self): super().__init__() self.image_encoder = ViT(d_model=1280, depth=32, num_heads=16) self.prompt_encoder = PromptEncoder() self.mask_decoder = MaskDecoder() def forward(self, image, points=None, boxes=None, masks=None): """ image: (B, 3, H, W) points: (B, N, 2) 坐标点 boxes: (B, 4) 边界框 """ # 1. 图像编码(只需一次,可复用) image_embedding = self.image_encoder(image) # (B, d, H/16, W/16) # 2. 提示编码 sparse_embeddings = self.prompt_encoder(points, boxes) dense_embeddings = self.prompt_encoder(masks) # 3. 掩码解码 masks, iou_pred = self.mask_decoder( image_embedding, sparse_embeddings, dense_embeddings ) return masks, iou_pred @torch.no_grad() def segment_anything(self, image, prompts): """任意提示分割:支持点、框、文本prompt""" image_emb = self.image_encoder(image) all_masks = [] for prompt in prompts: prompt_emb = self.prompt_encoder(prompt) mask, iou = self.mask_decoder(image_emb, prompt_emb) all_masks.append(mask) return all_masks

SAM的核心突破

  • 提示分割范式: 用户通过点击、画框等方式交互式获取分割结果,无需针对特定数据集训练
  • SA-1B数据集: 1100万张图像、10亿个掩码的超级数据集,是SAM强大的基础
  • 数据标注革命: SAM+数据标注工具(如Label Studio)可大幅降低标注成本
  • 零样本泛化: 在未见过的图像和物体上表现出惊人的泛化能力
  • 模型系列: ViT-B/L/H三档,fps从3到50帧,覆盖不同部署场景

SAM 2:扩展到视频

2024年发布的SAM 2将可提示分割扩展到视频领域。其核心创新是记忆编码器(Memory Encoder)记忆注意力(Memory Attention):在逐帧处理时,将先前帧的预测结果和特征存入记忆库,后续帧通过注意力机制读取记忆,实现高效的视频目标分割。SAM 2在视频分割精度上超过此前所有SOTA方法。

七、Transformer高效化:轻量部署

7.1 高效化的挑战

标准Transformer的平方复杂度(O(N²))和大量参数使其在移动端和边缘设备上难以部署。高效化研究主要有三个方向:轻量自注意力机制(降低复杂度)、混合CNN-Transformer架构(融合两者优势)、结构重参数化(训练-部署解耦)。

7.2 EfficientViT:级联组注意力

EfficientViT(MIT, 2023)针对高分辨率视觉任务提出了级联组注意力(Cascaded Group Attention)机制。它将多头注意力的头分组,在不同阶段逐步细化特征表达,大幅减少了冗余计算。

class CascadedGroupAttention(nn.Module): """EfficientViT级联组注意力""" def __init__(self, d_model=128, n_heads=8): super().__init__() self.n_heads = n_heads self.head_dim = d_model // n_heads assert self.head_dim * n_heads == d_model # Q/K/V 线性投影 self.qkv = nn.Linear(d_model, d_model * 3) self.proj = nn.Linear(d_model, d_model) def forward(self, x, h_split): """ h_split: 每个stage的head数量列表 例如 h_split=[2, 2, 4] 表示三级级联 """ B, N, D = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim) q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] # (B, N, H, hd) # 级联处理: 逐组计算注意力并将结果传递到下一组 outputs = [] h_cum = 0 for hg in h_split: # 当前组内的qk计算 q_g = q[:, :, h_cum:h_cum+hg] k_g = k[:, :, h_cum:h_cum+hg] v_g = v[:, :, h_cum:h_cum+hg] attn = (q_g @ k_g.transpose(-2, -1)) / self.head_dim**0.5 attn = attn.softmax(dim=-1) out_g = attn @ v_g # (B, N, hg, hd) outputs.append(out_g) # 将当前组输出加到下一组的key上 if h_cum + hg < self.n_heads: k[:, :, h_cum+hg:h_cum+hg*2] += out_g.mean(dim=2, keepdim=True) h_cum += hg out = torch.cat(outputs, dim=2) out = out.reshape(B, N, D) return self.proj(out)

7.3 MobileViT & EdgeViT:混合CNN-Transformer

MobileViT(Apple, 2021)是轻量混合架构的代表作。它在MobileNetV2的倒残差块基础上,插入MobileViT块——将特征图划分为局部窗口,窗口内进行Transformer建模,窗口间通过卷积融合。这种设计既有CNN的高效局部处理能力,又有Transformer的全局依赖建模能力。

EdgeViT(2022)进一步引入分解的局部-全局自注意力(Decomposed Local-Global Self-Attention):局部注意力在小窗口内计算,全局注意力通过稀疏采样实现,两者结合达到接近ViT的性能但参数量和FLOPs大幅降低。

class MobileViTBlock(nn.Module): """MobileViT:混合CNN-Transformer块""" def __init__(self, d_model=96, patch_size=2, nhead=4): super().__init__() self.patch_size = patch_size # 局部卷积编码 self.local_conv = nn.Sequential( nn.Conv2d(d_model, d_model, 3, 1, 1, groups=d_model), nn.Conv2d(d_model, d_model, 1) ) # Transformer:建模全局关系 self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), num_layers=2 ) # 融合卷积 self.fusion = nn.Conv2d(2*d_model, d_model, 1) def forward(self, x): """ x: (B, C, H, W) """ skip = x # 1. 局部编码 x = self.local_conv(x) # (B, C, H, W) B, C, H, W = x.shape P = self.patch_size # 2. 将特征图拆分成patches,在patch内展开像素序列 # 形成 (B, n_h*n_w, P*P, C) 格式 x_patches = x.unfold(2, P, P).unfold(3, P, P) # 重排为 (B, n_h*n_w, P*P, C) # 在 P*P 维度上做Transformer B, C, nh, nw, _, _ = x_patches.shape x_flat = x_patches.permute(0, 2, 3, 4, 5, 1).reshape(B, nh*nw, P*P, C) x_flat = x_flat.reshape(B * nh * nw, P*P, C) x_trans = self.transformer(x_flat) # 全局自注意力 x_trans = x_trans.reshape(B, nh, nw, P, P, C).permute(0, 5, 1, 3, 2, 4) x_trans = x_trans.reshape(B, C, H, W) # 3. 与原始特征融合 out = self.fusion(torch.cat([x_trans, skip], dim=1)) return out
模型 参数量 ImageNet Top1 延迟(iPhone 12) 核心思想
MobileViT-S 5.6M 78.4% 1.8ms CNN + 局部窗口Transformer
EdgeViT-S 5.5M 79.4% 1.6ms 分解局部-全局注意力
EfficientViT-B0 3.4M 76.2% 1.2ms 级联组注意力
EfficientFormer-L1 12.3M 82.6% 3.4ms 纯MLP+线性注意力

高效化的核心设计原则

  • 混合架构优先: 纯Transformer在移动端性价比不高,CNN+Transformer混合效果最佳
  • 局部化注意力: 限制注意力的空间范围(窗口/局部采样)降低复杂度
  • 渐进式下采样: 早期用CNN快速降低分辨率,后期用Transformer做全局建模
  • 重参数化: 训练时复杂结构、推理时合并为简单卷积(如RepVGG思路)
  • 硬件感知设计: 考虑内存带宽、矩阵乘法加速器(NPU)特性

总结与展望

Transformer在计算机视觉中的应用已经走过五个年头,从最初的ViT在图像分类上追平CNN,到今天DETR、Mask2Former、CLIP、SAM等模型在检测、分割、多模态等各个方向全面超越CNN时代的SOTA,Transformer正在重塑CV的技术版图。

展望未来,视觉通用模型(视觉基础模型)将像NLP中的GPT一样,通过大规模预训练+任务级微调或in-context learning解决几乎所有视觉任务。SAM已经展示了这一趋势,能分割一切的能力远超此前任何一个特定任务模型。同时,视觉-语言统一模型(如Flamingo、GPT-4V)正打破CV和NLP的边界,无监督/自监督预训练、多模态对齐、感知与推理的结合将是下一阶段的核心课题。

推荐阅读与参考文献