Vision Transformer(ViT)

深度学习专题 · Transformer跨界计算机视觉

专题:深度学习系统学习

关键词:Vision Transformer, ViT, Patch Embedding, Swin Transformer, DeiT, DINO, 自注意力

一、概述:ViT核心思想

Vision Transformer(ViT)由Google Research团队在2020年提出(论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》),是将Transformer架构从自然语言处理引入计算机视觉领域的里程碑式工作。ViT的核心思想极其简洁:将一张图像分割成固定大小的Patch(图像块),将这些Patch线性投影为序列化的嵌入向量,加上位置编码后送入标准的Transformer编码器进行处理,最后通过[CLS]分类头完成图像分类任务。

在此之前,计算机视觉领域长期被卷积神经网络(CNN)主导。尽管有一些尝试将注意力机制引入CNN的工作(如SENet、CBAM、Non-local Networks),但它们都是在CNN框架内局部地使用注意力。ViT的革命性在于完全抛弃了卷积操作,仅依靠Transformer的自注意力机制来建模图像中任意两个Patch之间的全局依赖关系,证明了纯Transformer架构在视觉任务上的可行性。

ViT的提出打破了NLP和CV两大领域之间的方法壁垒,使得统一的多模态架构成为可能,深刻影响了后续CLIP、DALL-E、SAM等模型的架构设计。

ViT整体架构示意图:输入图像 → Patch分割 → Patch Embedding + 位置编码 → Transformer Encoder(L层)→ MLP Head → 分类输出

核心公式:给定输入图像 x ∈ ℝ^(H×W×C),将其分割为 N = HW/P² 个大小为 P×P 的Patch,每个Patch展平后线性投影为 D 维嵌入向量,加上可学习的位置编码后,构成Transformer编码器的输入序列。

二、图像Patch嵌入原理

ViT的第一步是将图像转换为一系列"视觉词汇"(Visual Words),这一过程被称为Patch Embedding。假设输入图像尺寸为224×224×3,Patch大小设为16×16,则图像被分割为(224/16)² = 196个Patch,每个Patch的原始维度为16×16×3 = 768。

2.1 Patch Embedding实现

在PyTorch中,Patch Embedding通常通过一个卷积层高效实现:使用kernel_size=16, stride=16, out_channels=D的Conv2d卷积,等价于对每个16×16的Patch做一次全连接投影。这种做法比手动切分再展平投影效率更高,而且数学上完全等价。

import torch import torch.nn as nn class PatchEmbed(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() self.num_patches = (img_size // patch_size) ** 2 self.patch_size = patch_size # 使用Conv2d实现Patch Embedding self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): # x: [B, 3, 224, 224] x = self.proj(x) # [B, 768, 14, 14] x = x.flatten(2) # [B, 768, 196] x = x.transpose(1, 2) # [B, 196, 768] return x # 测试 patch_embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768) x = torch.randn(4, 3, 224, 224) # batch=4 out = patch_embed(x) print(out.shape) # torch.Size([4, 196, 768])

上述代码输出张量的形状为[B, 196, 768],其中196是Patch数量,768是每个Patch嵌入后的特征维度。这与NLP中词嵌入(Word Embedding)的角色完全对应:196个Patch相当于句子中的196个Token,768维的嵌入向量相当于每个Token的词向量。这种类比是理解ViT的关键——图像被"翻译"成了Transformer能理解的语言。

与NLP词嵌入的类比:在BERT中,输入是"[CLS] 我 爱 自 然 语 言 处 理 [SEP]",每个词对应一个Token嵌入。在ViT中,输入是"[CLS] Patch_1 Patch_2 ... Patch_196",每个Patch对应一个视觉Token嵌入。Transformer编码器在NLP和CV两个领域的处理方式完全一致!

2.2 位置编码

由于自注意力机制本身是置换等变(Permutation Equivariant)的——即打乱输入顺序不会影响注意力计算结果——我们需要额外添加位置编码来保留Patch之间的空间位置信息。ViT使用可学习的1D位置编码(而非Transformer原始的正余弦编码),这是因为1D位置编码更简单且在实验中被证明效果相当。

class PositionalEncoding(nn.Module): def __init__(self, num_patches, embed_dim): super().__init__() # 可学习的位置编码参数 self.pos_embed = nn.Parameter( torch.randn(1, num_patches + 1, embed_dim) * 0.02 ) # +1 是因为额外添加了一个[CLS] token def forward(self, x): return x + self.pos_embed

三、ViT模型架构详解

ViT的整体架构遵循标准Transformer编码器设计。输入序列由三部分拼接而成:一个可学习的[CLS] Token(用于分类)、196个Patch Token嵌入、以及可学习的位置编码。整个序列维度为[1, 197, 768](其中197 = 1个[CLS] + 196个Patch)。

3.1 Transformer编码器层

每一层Transformer编码器包含两个核心子层:多头自注意力(MSA)前馈网络(MLP),每个子层前后都使用Layer Normalization(LN)和残差连接。

class TransformerEncoderLayer(nn.Module): def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1): super().__init__() self.ln1 = nn.LayerNorm(embed_dim) self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True) self.ln2 = nn.LayerNorm(embed_dim) mlp_hidden_dim = int(embed_dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(embed_dim, mlp_hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden_dim, embed_dim), nn.Dropout(dropout), ) def forward(self, x): # Pre-LN + 残差连接 x = x + self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0] x = x + self.mlp(self.ln2(x)) return x

Pre-LN vs Post-LN:ViT使用Pre-Layer Normalization(在子层之前做归一化)而非原始Transformer的Post-LN。Pre-LN训练更稳定,允许更大的学习率,且不需要warmup阶段。

3.2 完整ViT模型

class VisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1): super().__init__() self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) num_patches = self.patch_embed.num_patches # [CLS] Token self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02) # 位置编码(+1 对应[CLS] token) self.pos_embed = nn.Parameter( torch.randn(1, num_patches + 1, embed_dim) * 0.02) self.pos_drop = nn.Dropout(dropout) # Transformer编码器堆叠 self.blocks = nn.Sequential(*[ TransformerEncoderLayer(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth) ]) self.ln = nn.LayerNorm(embed_dim) # 分类头 self.head = nn.Linear(embed_dim, num_classes) def forward(self, x): B = x.shape[0] x = self.patch_embed(x) # [B, 196, 768] cls_tokens = self.cls_token.expand(B, -1, -1) # [B, 1, 768] x = torch.cat([cls_tokens, x], dim=1) # [B, 197, 768] x = x + self.pos_embed # 加位置编码 x = self.pos_drop(x) x = self.blocks(x) # 经过12层Transformer x = self.ln(x) cls_final = x[:, 0] # 取[CLS]位置输出 logits = self.head(cls_final) # 分类 return logits # 实例化ViT-Base模型 model = VisionTransformer( img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, num_classes=1000 ) total_params = sum(p.numel() for p in model.parameters()) print(f"ViT-Base参数量: {total_params / 1e6:.2f}M") # ~86M

ViT的三种标准配置如下表所示:

模型embed_dimdepthnum_heads参数量FLOPs
ViT-Base768121286M16.8G
ViT-Large10242416307M61.6G
ViT-Huge12803216632M134G

四、ViT vs CNN对比分析

ViT与CNN在设计哲学上存在着深刻的差异。理解这些差异有助于我们把握两种范式的本质特征和适用场景。

4.1 全局感受野 vs 局部感受野

CNN通过堆叠小卷积核(如3×3)逐层扩大感受野,浅层看到局部纹理,深层看到全局语义。而ViT在第一层就能通过自注意力机制建立任意两个Patch之间的直接连接,从一开始就拥有全局感受野。这意味着ViT更容易捕获图像中的长距离依赖关系,但也带来了更大的计算开销。

4.2 归纳偏置 vs 数据驱动

CNN内建了极强的归纳偏置(Inductive Bias):平移等变性(卷积核在图像各处共享)、局部连接性(每个神经元只连接局部区域)、层次化特征学习。这些先验知识使得CNN在数据量不足时仍能有效学习。而ViT几乎没有视觉相关的归纳偏置——它完全依赖大规模数据驱动来学习视觉特征。这正是ViT需要JFT-300M这样的大规模预训练数据才能超越CNN的根本原因。

ViT优势

  • 全局感受野,擅于捕获长距离依赖
  • 数据量充足时,性能远超CNN
  • 架构统一,易于多模态扩展
  • 自注意力可视化可解释性强

ViT劣势

  • 需要大规模预训练数据(>100M)
  • 计算复杂度为O(N²),高分辨率下昂贵
  • 缺乏局部先验,小数据易过拟合
  • 优化更困难,训练技巧要求高

4.3 计算复杂度分析

自注意力的计算复杂度为O(N²·D),其中N是Patch数量,D是嵌入维度。对于224×224图像和16×16 Patch,N = 196,计算量尚可接受。但随着输入分辨率提升,N呈二次增长,这使得ViT处理高分辨率图像时计算量爆炸式增长。相比之下,CNN的复杂度随输入尺寸线性增长。

def compute_complexity(img_size, patch_size, embed_dim): N = (img_size // patch_size) ** 2 # 自注意力复杂度: O(N^2 * D) msa_flops = 4 * N * embed_dim ** 2 + 2 * N ** 2 * embed_dim # MLP复杂度: O(N * D^2 * 4) mlp_flops = 8 * N * embed_dim ** 2 total_flops = msa_flops + mlp_flops return total_flops for resolution in [224, 384, 512, 1024]: flops = compute_complexity(resolution, 16, 768) print(f"分辨率 {resolution}x{resolution}: {flops/1e9:.2f} GFLOPs")
分辨率 224x224: 16.80 GFLOPs 分辨率 384x384: 118.46 GFLOPs 分辨率 512x512: 364.79 GFLOPs 分辨率 1024x1024: 5764.47 GFLOPs # 不可接受的爆炸式增长

五、ViT主要变体

ViT的成功催生了一系列重要的变体工作,这些工作在数据效率、训练策略、架构改进等方面做出了关键贡献。

5.1 DeiT:数据高效ViT

DeiT(Data-efficient Image Transformers)由Facebook AI在2021年提出,解决了ViT需要海量预训练数据的痛点。其核心创新在于知识蒸馏训练策略:使用一个强大的CNN教师网络(如RegNet)来指导ViT学生网络的训练,同时引入了一个专门的蒸馏Token(distillation token)。

class DeiT(VisionTransformer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # 额外的蒸馏Token self.dist_token = nn.Parameter( torch.randn(1, 1, kwargs['embed_dim']) * 0.02) # 蒸馏头 self.head_dist = nn.Linear(kwargs['embed_dim'], kwargs['num_classes']) def forward(self, x): # 在forward中同时处理[CLS] token和[Dist] token # 训练时使用软蒸馏 + 硬蒸馏的联合损失 pass # 完整实现类似ViT但包含蒸馏逻辑

DeiT的关键贡献:仅使用ImageNet-1K(1.2M图像)就能训练出与ViT相当的性能,无需JFT-300M。通过组合数据增强(RandAugment、MixUp、CutMix、Random Erasing)和知识蒸馏,DeiT-Ti(5M参数)达到72.2% top-1精度,DeiT-B(86M参数)达到81.8%。

5.2 DINO:自监督ViT

DINO(Emerging Properties in Self-Supervised Vision Transformers)展示了ViT在自监督学习范式下的惊人潜力。DINO使用教师-学生自蒸馏框架,不需要任何标注数据即可训练ViT。更重要的是,DINO训练出的ViT自注意力图自动学会了物体分割——这一语义分割能力在CNN中从未出现过。

# DINO核心:自蒸馏伪标签 def dino_loss(student_output, teacher_output, temperature=0.1): # 中心化和锐化 teacher_output = teacher_output - teacher_output.mean(dim=0) student_out = student_output / temperature teacher_out = teacher_output / 0.04 # teacher温度低,分布更尖锐 # 交叉熵损失 loss = - (teacher_out.softmax(dim=-1) * student_out.log_softmax(dim=-1)).sum(dim=-1).mean() return loss # DINO训练出的ViT自注意力图可直接用于无监督语义分割 # 这对CNN来说是前所未见的涌现属性(Emerging Property)

5.3 Swin Transformer

Swin Transformer(2021年CVPR最佳论文)由微软亚洲研究院提出,是目前最具影响力的ViT变体之一。它引入了分层移位窗口(Shifted Window)机制,将自注意力计算限制在局部窗口内(窗口大小通常为7×7),并通过窗口移位实现跨窗口信息交互。

class WindowAttention(nn.Module): def __init__(self, dim, window_size, num_heads): super().__init__() self.window_size = window_size # (7, 7) self.num_heads = num_heads self.scale = (dim // num_heads) ** -0.5 # 相对位置编码表 (2*7-1) x (2*7-1) self.relative_position_bias_table = nn.Parameter( torch.randn((2*window_size[0]-1) * (2*window_size[1]-1), num_heads)) def forward(self, x, mask=None): # 窗口内的自注意力计算 # 复杂度 O(M^2) 而非 O(N^2),其中 M=窗口Token数, N=全部Token数 pass # 详细实现涉及窗口切分、移位、掩码等复杂逻辑 # Swin Transformer的核心创新总结 # 1. 分层特征图:4x → 8x → 16x → 32x 下采样(类似CNN金字塔) # 2. 窗口自注意力:计算复杂度从O(N²)降到O(N·M²),M< # 3. 移位窗口:相邻层窗口错位,实现跨窗口信息流通 # 4. 相对位置编码:提供比绝对位置编码更好的平移不变性

Swin Transformer的优势:① 计算复杂度从ViT的O(N²)降至O(N·M²),M为固定窗口尺寸(如7×7),使高分辨率输入成为可能;② 分层结构使其天然适合作为各类视觉任务的骨干网络(分类、检测、分割),可替代ResNet、ResNeXt等CNN骨干;③ 在COCO目标检测和ADE20K语义分割等下游任务上大幅超越CNN基线。

下表总结了主要ViT变体的关键特性对比:

模型年份核心创新ImageNet Top-1参数量
ViT-B/162020纯Transformer架构77.9% (JFT-300M)86M
DeiT-B2021知识蒸馏 + 强数据增强81.8% (IN-1K)86M
Swin-B2021移位窗口 + 分层结构83.5% (IN-22K)88M
DINO2021自蒸馏自监督 ViT78.2% (无监督)86M
ConvNeXt2022现代化CNN + ViT设计理念83.8% (IN-22K)89M
EfficientViT2023级联分组注意力 + 轻量化83.5%31M

六、ViT训练技巧

ViT的训练比CNN更加敏感,需要一系列精心设计的训练策略才能发挥最佳性能。以下是经过实践验证的关键技巧。

6.1 优化器与学习率调度

ViT推荐使用AdamW优化器(而非CNN常用的SGD),权重衰减设为0.1~0.3。学习率通常采用余弦衰减调度(Cosine Decay),配合线性热身(Linear Warmup)策略。基础学习率约为3e-3(batch size=4096时),可按照线性缩放规则调整:lr = base_lr × (batch_size / 4096)。

import torch.optim as optim from torch.optim.lr_scheduler import CosineAnnealingLR # ViT推荐优化器配置 optimizer = optim.AdamW( model.parameters(), lr=3e-3, weight_decay=0.3, # 大权重衰减 betas=(0.9, 0.999), eps=1e-8 ) # 余弦衰减 + 10000步热身 scheduler = CosineAnnealingLR(optimizer, T_max=300, eta_min=1e-6)

6.2 正则化策略

ViT的正则化组合比CNN更为激进:

# 随机深度(Stochastic Depth / DropPath)实现 class DropPath(nn.Module): def __init__(self, drop_prob=0.0): super().__init__() self.drop_prob = drop_prob def forward(self, x): if not self.training or self.drop_prob == 0.0: return x keep_prob = 1.0 - self.drop_prob shape = (x.shape[0],) + (1,)*(x.ndim - 1) mask = x.new_empty(shape).bernoulli_(keep_prob) mask.div_(keep_prob) return x * mask # 梯度裁剪 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

6.3 数据增强策略

ViT论文中采用了一系列强大的数据增强技术:

# MixUp 数据增强 def mixup(x, y, alpha=0.8): lam = np.random.beta(alpha, alpha) batch_size = x.size(0) index = torch.randperm(batch_size) mixed_x = lam * x + (1.0 - lam) * x[index] mixed_y = (lam * y + (1.0 - lam) * y[index]) return mixed_x, mixed_y # CutMix 数据增强 def cutmix(x, y, alpha=1.0): lam = np.random.beta(alpha, alpha) batch_size = x.size(0) index = torch.randperm(batch_size) bbx1, bby1 = np.random.randint(0, x.size(2)), np.random.randint(0, x.size(3)) bbx2, bby2 = np.random.randint(bbx1, x.size(2)), np.random.randint(bby1, x.size(3)) x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2] lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(2) * x.size(3))) return x, lam * y + (1.0 - lam) * y[index]

七、ViT性能与可扩展性

ViT展示了优异的模型可扩展性(Scaling Property):随着模型规模和数据量的增加,其性能呈现稳定的对数线性增长,且尚未出现饱和趋势。这一特性与NLP领域的Transformer一脉相承,也是ViT被认为是"视觉基础模型"理想架构的重要原因。

7.1 模型缩放规律

通过在ImageNet上的系统性实验,ViT论文揭示了以下关键的缩放规律:

关键发现:ViT缺乏CNN的归纳偏置,因此需要更多数据来学习视觉结构。当数据量达到100M级别时,ViT的优势才能充分显现。这一发现直接推动了后续DeiT(数据高效训练)和Swin Transformer(引入局部先验)等改进工作。

7.2 下游任务迁移

ViT在ImageNet上预训练后,可以通过微调迁移到各种下游视觉任务:在VTAB基准测试(19个视觉任务)上,ViT-H/14达到了77.63%的平均准确率;在目标检测任务中,结合FPN结构,ViT可以作为强大的骨干网络;在语义分割中,采用SETR(Segmentation Transformer)结构,ViT在ADE20K上达到50.28% mIoU。

# 迁移学习微调示例 # 加载预训练ViT并替换分类头 import timm # 使用timm库加载预训练ViT model = timm.create_model('vit_base_patch16_224', pretrained=True) # 替换分类头进行下游任务微调 model.head = nn.Linear(model.embed_dim, num_classes=10) # CIFAR-10 # 小学习率微调,冻结前几层加速训练 for name, param in model.named_parameters(): if 'blocks.0' in name or 'blocks.1' in name: param.requires_grad = False optimizer = optim.AdamW( filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=0.05 )

八、ViT推理与可视化实战

以下代码示例演示了如何使用预训练的ViT模型进行图像分类推理,并可视化自注意力图以直观理解模型关注区域。

import timm import torch import torchvision.transforms as T from PIL import Image import matplotlib.pyplot as plt import numpy as np # 1. 加载预训练ViT模型 model = timm.create_model('vit_base_patch16_224', pretrained=True) model.eval() # 配置:注册hook获取注意力权重 attention_maps = [] def hook_fn(module, input, output): # output[1] 包含注意力权重 [B, num_heads, N, N] attention_maps.append(output[1].detach()) # 注册hook到每个Transformer层的注意力模块 handles = [] for block in model.blocks: handles.append(block.attn.register_forward_hook(hook_fn)) # 2. 图像预处理 transform = T.Compose([ T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) img = Image.open('cat.jpg') input_tensor = transform(img).unsqueeze(0) # 3. 前向推理 with torch.no_grad(): logits = model(input_tensor) probs = torch.softmax(logits, dim=-1) # 获取ImageNet标签 with open('imagenet_labels.txt') as f: labels = [line.strip() for line in f.readlines()] pred_idx = logits.argmax(dim=-1).item() print(f"预测类别: {labels[pred_idx]}") print(f"置信度: {probs[0][pred_idx].item():.4f}") # 4. 可视化最后一层[CLS]的注意力图 # 取最后一层,第一个head,[CLS] token对所有patch的注意力 last_attn = attention_maps[-1] # [1, 12, 197, 197] cls_attn = last_attn[0, :, 0, 1:] # [12, 196] cls_attn = cls_attn.mean(dim=0) # [196] 在所有头上平均 # 将注意力图reshape为14x14并放大到224x224 attn_map = cls_attn.reshape(14, 14).numpy() attn_map = Image.fromarray(attn_map).resize((224, 224), Image.BICUBIC) # 可视化叠加 plt.figure(figsize=(10, 5)) plt.subplot(1, 2, 1) plt.imshow(img) plt.title('Original Image') plt.axis('off') plt.subplot(1, 2, 2) plt.imshow(img) plt.imshow(attn_map, alpha=0.6, cmap='jet') plt.title('Attention Map (CLS token)') plt.axis('off') plt.show() # 清理hooks for handle in handles: handle.remove()

注意力可视化要点:ViT最后一层[CLS] token对其他Patch的注意力权重可以生成一张注意力热力图(Attention Map),热力值高的区域代表模型在分类决策时的"重点关注区域"。这一特性使得ViT天然具有比CNN更强的可解释性。

九、局限性与未来方向

尽管ViT取得了巨大成功,但仍存在明显的局限性,也指引着后续研究的演进方向。

9.1 当前局限性

9.2 未来演进方向

关键趋势:CV领域正从"CNN vs ViT"的范式之争走向"混合架构"的新阶段。代表性的工作包括:ConvNeXt(现代化CNN吸收ViT设计理念)、MaxViT(多轴注意力)、EfficientViT(轻量化ViT)、FastViT(高推理速度ViT)、以及Vision Mamba(状态空间模型替代自注意力)。CNN和ViT的设计思想正在深度融合,取长补短。

十、核心要点总结

十一、进一步思考

ViT的出现不仅是计算机视觉领域的一次技术革新,更代表着一种范式转变:从手工设计的局部操作(卷积)转向数据驱动的全局建模(自注意力)。这一趋势与NLP领域从RNN/LSTM到Transformer的演进高度一致,暗示着统一的"基础模型"架构可能是人工智能发展的必然方向。

在实践中,选择ViT还是CNN应根据具体场景权衡:如果拥有充足的数据和计算资源(如100M+图像、多GPU集群),ViT及其变体通常能带来更好的性能;如果数据有限或需要边缘部署,精心设计的CNN(如ConvNeXt、EfficientNet)仍然是更务实的选择。未来,CNN和ViT的混合架构以及基于状态空间模型(如Mamba)的新一代视觉架构,很可能成为计算机视觉领域的主流范式。