← 返回深度学习目录
← 返回学习笔记首页
专题: 深度学习系统学习
关键词: Vision Transformer, ViT, Patch Embedding, Swin Transformer, DeiT, DINO, 自注意力
一、概述:ViT核心思想
Vision Transformer(ViT)由Google Research团队在2020年提出(论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》),是将Transformer架构从自然语言处理引入计算机视觉领域的里程碑式工作 。ViT的核心思想极其简洁:将一张图像分割成固定大小的Patch(图像块),将这些Patch线性投影为序列化的嵌入向量,加上位置编码后送入标准的Transformer编码器进行处理,最后通过[CLS]分类头完成图像分类任务。
在此之前,计算机视觉领域长期被卷积神经网络(CNN)主导。尽管有一些尝试将注意力机制引入CNN的工作(如SENet、CBAM、Non-local Networks),但它们都是在CNN框架内局部地使用注意力。ViT的革命性在于完全抛弃了卷积操作 ,仅依靠Transformer的自注意力机制来建模图像中任意两个Patch之间的全局依赖关系,证明了纯Transformer架构在视觉任务上的可行性。
ViT的提出打破了NLP和CV两大领域之间的方法壁垒,使得统一的多模态架构成为可能,深刻影响了后续CLIP、DALL-E、SAM等模型的架构设计。
ViT整体架构示意图:输入图像 → Patch分割 → Patch Embedding + 位置编码 → Transformer Encoder(L层)→ MLP Head → 分类输出
核心公式: 给定输入图像 x ∈ ℝ^(H×W×C),将其分割为 N = HW/P² 个大小为 P×P 的Patch,每个Patch展平后线性投影为 D 维嵌入向量,加上可学习的位置编码后,构成Transformer编码器的输入序列。
二、图像Patch嵌入原理
ViT的第一步是将图像转换为一系列"视觉词汇"(Visual Words),这一过程被称为Patch Embedding。假设输入图像尺寸为224×224×3,Patch大小设为16×16,则图像被分割为(224/16)² = 196个Patch,每个Patch的原始维度为16×16×3 = 768。
2.1 Patch Embedding实现
在PyTorch中,Patch Embedding通常通过一个卷积层高效实现:使用kernel_size=16, stride=16, out_channels=D的Conv2d卷积,等价于对每个16×16的Patch做一次全连接投影。这种做法比手动切分再展平投影效率更高,而且数学上完全等价。
import torch
import torch.nn as nn
class PatchEmbed (nn.Module):
def __init__ (self , img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super ().__init__()
self .num_patches = (img_size // patch_size) ** 2
self .patch_size = patch_size
# 使用Conv2d实现Patch Embedding
self .proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward (self , x):
# x: [B, 3, 224, 224]
x = self .proj(x) # [B, 768, 14, 14]
x = x.flatten(2) # [B, 768, 196]
x = x.transpose(1, 2) # [B, 196, 768]
return x
# 测试
patch_embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
x = torch.randn(4 , 3 , 224 , 224 ) # batch=4
out = patch_embed(x)
print (out.shape) # torch.Size([4, 196, 768])
上述代码输出张量的形状为[B, 196, 768],其中196是Patch数量,768是每个Patch嵌入后的特征维度。这与NLP中词嵌入(Word Embedding)的角色完全对应:196个Patch相当于句子中的196个Token,768维的嵌入向量相当于每个Token的词向量 。这种类比是理解ViT的关键——图像被"翻译"成了Transformer能理解的语言。
与NLP词嵌入的类比: 在BERT中,输入是"[CLS] 我 爱 自 然 语 言 处 理 [SEP]",每个词对应一个Token嵌入。在ViT中,输入是"[CLS] Patch_1 Patch_2 ... Patch_196",每个Patch对应一个视觉Token嵌入。Transformer编码器在NLP和CV两个领域的处理方式完全一致!
2.2 位置编码
由于自注意力机制本身是置换等变 (Permutation Equivariant)的——即打乱输入顺序不会影响注意力计算结果——我们需要额外添加位置编码来保留Patch之间的空间位置信息。ViT使用可学习的1D位置编码(而非Transformer原始的正余弦编码),这是因为1D位置编码更简单且在实验中被证明效果相当。
class PositionalEncoding (nn.Module):
def __init__ (self , num_patches, embed_dim):
super ().__init__()
# 可学习的位置编码参数
self .pos_embed = nn.Parameter(
torch.randn(1 , num_patches + 1 , embed_dim) * 0.02
)
# +1 是因为额外添加了一个[CLS] token
def forward (self , x):
return x + self .pos_embed
三、ViT模型架构详解
ViT的整体架构遵循标准Transformer编码器设计。输入序列由三部分拼接而成:一个可学习的[CLS] Token(用于分类)、196个Patch Token嵌入、以及可学习的位置编码。整个序列维度为[1, 197, 768](其中197 = 1个[CLS] + 196个Patch)。
3.1 Transformer编码器层
每一层Transformer编码器包含两个核心子层:多头自注意力(MSA) 和前馈网络(MLP) ,每个子层前后都使用Layer Normalization(LN)和残差连接。
class TransformerEncoderLayer (nn.Module):
def __init__ (self , embed_dim, num_heads, mlp_ratio=4.0 , dropout=0.1 ):
super ().__init__()
self .ln1 = nn.LayerNorm(embed_dim)
self .attn = nn.MultiheadAttention(embed_dim, num_heads,
dropout=dropout, batch_first=True )
self .ln2 = nn.LayerNorm(embed_dim)
mlp_hidden_dim = int (embed_dim * mlp_ratio)
self .mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden_dim, embed_dim),
nn.Dropout(dropout),
)
def forward (self , x):
# Pre-LN + 残差连接
x = x + self .attn(self .ln1(x), self .ln1(x), self .ln1(x))[0 ]
x = x + self .mlp(self .ln2(x))
return x
Pre-LN vs Post-LN: ViT使用Pre-Layer Normalization(在子层之前做归一化)而非原始Transformer的Post-LN。Pre-LN训练更稳定,允许更大的学习率,且不需要warmup阶段。
3.2 完整ViT模型
class VisionTransformer (nn.Module):
def __init__ (self , img_size=224 , patch_size=16 ,
in_chans=3 , num_classes=1000 ,
embed_dim=768 , depth=12 ,
num_heads=12 , mlp_ratio=4.0 , dropout=0.1 ):
super ().__init__()
self .patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self .patch_embed.num_patches
# [CLS] Token
self .cls_token = nn.Parameter(torch.randn(1 , 1 , embed_dim) * 0.02 )
# 位置编码(+1 对应[CLS] token)
self .pos_embed = nn.Parameter(
torch.randn(1 , num_patches + 1 , embed_dim) * 0.02 )
self .pos_drop = nn.Dropout(dropout)
# Transformer编码器堆叠
self .blocks = nn.Sequential(*[
TransformerEncoderLayer(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range (depth)
])
self .ln = nn.LayerNorm(embed_dim)
# 分类头
self .head = nn.Linear(embed_dim, num_classes)
def forward (self , x):
B = x.shape[0 ]
x = self .patch_embed(x) # [B, 196, 768]
cls_tokens = self .cls_token.expand(B, -1 , -1 ) # [B, 1, 768]
x = torch.cat([cls_tokens, x], dim=1 ) # [B, 197, 768]
x = x + self .pos_embed # 加位置编码
x = self .pos_drop(x)
x = self .blocks(x) # 经过12层Transformer
x = self .ln(x)
cls_final = x[:, 0 ] # 取[CLS]位置输出
logits = self .head(cls_final) # 分类
return logits
# 实例化ViT-Base模型
model = VisionTransformer(
img_size=224 , patch_size=16 ,
embed_dim=768 , depth=12 , num_heads=12 ,
num_classes=1000
)
total_params = sum (p.numel() for p in model.parameters())
print (f"ViT-Base参数量: {total_params / 1e6:.2f}M" ) # ~86M
ViT的三种标准配置如下表所示:
模型 embed_dim depth num_heads 参数量 FLOPs
ViT-Base 768 12 12 86M 16.8G
ViT-Large 1024 24 16 307M 61.6G
ViT-Huge 1280 32 16 632M 134G
四、ViT vs CNN对比分析
ViT与CNN在设计哲学上存在着深刻的差异。理解这些差异有助于我们把握两种范式的本质特征和适用场景。
4.1 全局感受野 vs 局部感受野
CNN通过堆叠小卷积核(如3×3)逐层扩大感受野,浅层看到局部纹理,深层看到全局语义。而ViT在第一层就能通过自注意力机制建立任意两个Patch之间的直接连接,从一开始就拥有全局感受野 。这意味着ViT更容易捕获图像中的长距离依赖关系,但也带来了更大的计算开销。
4.2 归纳偏置 vs 数据驱动
CNN内建了极强的归纳偏置 (Inductive Bias):平移等变性(卷积核在图像各处共享)、局部连接性(每个神经元只连接局部区域)、层次化特征学习。这些先验知识使得CNN在数据量不足时仍能有效学习。而ViT几乎没有视觉相关的归纳偏置——它完全依赖大规模数据驱动 来学习视觉特征。这正是ViT需要JFT-300M这样的大规模预训练数据才能超越CNN的根本原因。
ViT优势
全局感受野,擅于捕获长距离依赖
数据量充足时,性能远超CNN
架构统一,易于多模态扩展
自注意力可视化可解释性强
ViT劣势
需要大规模预训练数据(>100M)
计算复杂度为O(N²),高分辨率下昂贵
缺乏局部先验,小数据易过拟合
优化更困难,训练技巧要求高
4.3 计算复杂度分析
自注意力的计算复杂度为O(N²·D),其中N是Patch数量,D是嵌入维度。对于224×224图像和16×16 Patch,N = 196,计算量尚可接受。但随着输入分辨率提升,N呈二次增长,这使得ViT处理高分辨率图像时计算量爆炸式增长。相比之下,CNN的复杂度随输入尺寸线性增长。
def compute_complexity (img_size, patch_size, embed_dim):
N = (img_size // patch_size) ** 2
# 自注意力复杂度: O(N^2 * D)
msa_flops = 4 * N * embed_dim ** 2 + 2 * N ** 2 * embed_dim
# MLP复杂度: O(N * D^2 * 4)
mlp_flops = 8 * N * embed_dim ** 2
total_flops = msa_flops + mlp_flops
return total_flops
for resolution in [224 , 384 , 512 , 1024 ]:
flops = compute_complexity(resolution, 16 , 768 )
print (f"分辨率 {resolution}x{resolution}: {flops/1e9:.2f} GFLOPs" )
分辨率 224x224: 16.80 GFLOPs
分辨率 384x384: 118.46 GFLOPs
分辨率 512x512: 364.79 GFLOPs
分辨率 1024x1024: 5764.47 GFLOPs # 不可接受的爆炸式增长
五、ViT主要变体
ViT的成功催生了一系列重要的变体工作,这些工作在数据效率、训练策略、架构改进等方面做出了关键贡献。
5.1 DeiT:数据高效ViT
DeiT(Data-efficient Image Transformers)由Facebook AI在2021年提出,解决了ViT需要海量预训练数据的痛点。其核心创新在于知识蒸馏 训练策略:使用一个强大的CNN教师网络(如RegNet)来指导ViT学生网络的训练,同时引入了一个专门的蒸馏Token(distillation token)。
class DeiT (VisionTransformer):
def __init__ (self , *args, **kwargs):
super ().__init__(*args, **kwargs)
# 额外的蒸馏Token
self .dist_token = nn.Parameter(
torch.randn(1 , 1 , kwargs['embed_dim' ]) * 0.02 )
# 蒸馏头
self .head_dist = nn.Linear(kwargs['embed_dim' ], kwargs['num_classes' ])
def forward (self , x):
# 在forward中同时处理[CLS] token和[Dist] token
# 训练时使用软蒸馏 + 硬蒸馏的联合损失
pass # 完整实现类似ViT但包含蒸馏逻辑
DeiT的关键贡献: 仅使用ImageNet-1K(1.2M图像)就能训练出与ViT相当的性能,无需JFT-300M。通过组合数据增强(RandAugment、MixUp、CutMix、Random Erasing)和知识蒸馏,DeiT-Ti(5M参数)达到72.2% top-1精度,DeiT-B(86M参数)达到81.8%。
5.2 DINO:自监督ViT
DINO(Emerging Properties in Self-Supervised Vision Transformers)展示了ViT在自监督学习范式下的惊人潜力。DINO使用教师-学生自蒸馏 框架,不需要任何标注数据即可训练ViT。更重要的是,DINO训练出的ViT自注意力图自动学会了物体分割 ——这一语义分割能力在CNN中从未出现过。
# DINO核心:自蒸馏伪标签
def dino_loss (student_output, teacher_output, temperature=0.1 ):
# 中心化和锐化
teacher_output = teacher_output - teacher_output.mean(dim=0 )
student_out = student_output / temperature
teacher_out = teacher_output / 0.04 # teacher温度低,分布更尖锐
# 交叉熵损失
loss = - (teacher_out.softmax(dim=-1 ) *
student_out.log_softmax(dim=-1 )).sum(dim=-1 ).mean()
return loss
# DINO训练出的ViT自注意力图可直接用于无监督语义分割
# 这对CNN来说是前所未见的涌现属性(Emerging Property)
5.3 Swin Transformer
Swin Transformer(2021年CVPR最佳论文)由微软亚洲研究院提出,是目前最具影响力的ViT变体之一。它引入了分层移位窗口 (Shifted Window)机制,将自注意力计算限制在局部窗口内(窗口大小通常为7×7),并通过窗口移位实现跨窗口信息交互。
class WindowAttention (nn.Module):
def __init__ (self , dim, window_size, num_heads):
super ().__init__()
self .window_size = window_size # (7, 7)
self .num_heads = num_heads
self .scale = (dim // num_heads) ** -0.5
# 相对位置编码表 (2*7-1) x (2*7-1)
self .relative_position_bias_table = nn.Parameter(
torch.randn((2 *window_size[0 ]-1 ) * (2 *window_size[1 ]-1 ), num_heads))
def forward (self , x, mask=None ):
# 窗口内的自注意力计算
# 复杂度 O(M^2) 而非 O(N^2),其中 M=窗口Token数, N=全部Token数
pass # 详细实现涉及窗口切分、移位、掩码等复杂逻辑
# Swin Transformer的核心创新总结
# 1. 分层特征图:4x → 8x → 16x → 32x 下采样(类似CNN金字塔)
# 2. 窗口自注意力:计算复杂度从O(N²)降到O(N·M²),M<
# 3. 移位窗口:相邻层窗口错位,实现跨窗口信息流通
# 4. 相对位置编码:提供比绝对位置编码更好的平移不变性
Swin Transformer的优势: ① 计算复杂度从ViT的O(N²)降至O(N·M²),M为固定窗口尺寸(如7×7),使高分辨率输入成为可能;② 分层结构使其天然适合作为各类视觉任务的骨干网络(分类、检测、分割),可替代ResNet、ResNeXt等CNN骨干;③ 在COCO目标检测和ADE20K语义分割等下游任务上大幅超越CNN基线。
下表总结了主要ViT变体的关键特性对比:
模型 年份 核心创新 ImageNet Top-1 参数量
ViT-B/16 2020 纯Transformer架构 77.9% (JFT-300M) 86M
DeiT-B 2021 知识蒸馏 + 强数据增强 81.8% (IN-1K) 86M
Swin-B 2021 移位窗口 + 分层结构 83.5% (IN-22K) 88M
DINO 2021 自蒸馏自监督 ViT 78.2% (无监督) 86M
ConvNeXt 2022 现代化CNN + ViT设计理念 83.8% (IN-22K) 89M
EfficientViT 2023 级联分组注意力 + 轻量化 83.5% 31M
六、ViT训练技巧
ViT的训练比CNN更加敏感,需要一系列精心设计的训练策略才能发挥最佳性能。以下是经过实践验证的关键技巧。
6.1 优化器与学习率调度
ViT推荐使用AdamW 优化器(而非CNN常用的SGD),权重衰减设为0.1~0.3。学习率通常采用余弦衰减调度(Cosine Decay),配合线性热身(Linear Warmup)策略。基础学习率约为3e-3(batch size=4096时),可按照线性缩放规则调整:lr = base_lr × (batch_size / 4096)。
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
# ViT推荐优化器配置
optimizer = optim.AdamW(
model.parameters(),
lr=3e-3 ,
weight_decay=0.3 , # 大权重衰减
betas=(0.9 , 0.999 ),
eps=1e-8
)
# 余弦衰减 + 10000步热身
scheduler = CosineAnnealingLR(optimizer, T_max=300 , eta_min=1e-6 )
6.2 正则化策略
ViT的正则化组合比CNN更为激进:
随机深度(Stochastic Depth) :以概率p随机丢弃Transformer层,通常p从0到0.5线性增加。这相当于深度维度的Dropout,对ViT训练至关重要。
Dropout + DropPath :在MLP层使用Dropout(rate=0.1),同时在残差连接中使用DropPath。
标签平滑(Label Smoothing) :epsilon设为0.1,缓解过自信预测。
梯度裁剪(Gradient Clipping) :将全局梯度范数裁剪到1.0,防止训练不稳定。
# 随机深度(Stochastic Depth / DropPath)实现
class DropPath (nn.Module):
def __init__ (self , drop_prob=0.0 ):
super ().__init__()
self .drop_prob = drop_prob
def forward (self , x):
if not self .training or self .drop_prob == 0.0 :
return x
keep_prob = 1.0 - self .drop_prob
shape = (x.shape[0 ],) + (1 ,)*(x.ndim - 1 )
mask = x.new_empty(shape).bernoulli_(keep_prob)
mask.div_(keep_prob)
return x * mask
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0 )
6.3 数据增强策略
ViT论文中采用了一系列强大的数据增强技术:
RandAugment :随机选择2~3种增强操作(旋转、平移、对比度调整等),增强幅度可调。
MixUp :将两张图像按比例混合,标签也按相同比例混合。λ ∼ Beta(α, α),α通常为0.8。
CutMix :将一张图像的矩形区域裁剪并粘贴到另一张图像上,标签按面积比例混合。
Random Erasing :随机擦除图像中的矩形区域,迫使模型关注更具判别性的特征。
# MixUp 数据增强
def mixup (x, y, alpha=0.8 ):
lam = np.random.beta(alpha, alpha)
batch_size = x.size(0 )
index = torch.randperm(batch_size)
mixed_x = lam * x + (1.0 - lam) * x[index]
mixed_y = (lam * y + (1.0 - lam) * y[index])
return mixed_x, mixed_y
# CutMix 数据增强
def cutmix (x, y, alpha=1.0 ):
lam = np.random.beta(alpha, alpha)
batch_size = x.size(0 )
index = torch.randperm(batch_size)
bbx1, bby1 = np.random.randint(0 , x.size(2 )), np.random.randint(0 , x.size(3 ))
bbx2, bby2 = np.random.randint(bbx1, x.size(2 )), np.random.randint(bby1, x.size(3 ))
x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(2 ) * x.size(3 )))
return x, lam * y + (1.0 - lam) * y[index]
七、ViT性能与可扩展性
ViT展示了优异的模型可扩展性 (Scaling Property):随着模型规模和数据量的增加,其性能呈现稳定的对数线性增长,且尚未出现饱和趋势。这一特性与NLP领域的Transformer一脉相承,也是ViT被认为是"视觉基础模型"理想架构的重要原因。
7.1 模型缩放规律
通过在ImageNet上的系统性实验,ViT论文揭示了以下关键的缩放规律:
数据缩放 :在ImageNet-1K(1.3M图像)上,ViT性能略低于同等规模的ResNet。但在ImageNet-21K(14M图像)上,ViT可超越ResNet。在JFT-300M(3亿图像)上,ViT大幅领先。
模型缩放 :从ViT-B到ViT-L再到ViT-H,性能持续提升。ViT-H/14在JFT-300M预训练后,ImageNet top-1精度达到88.55%。
计算量缩放 :同等FLOPs下,ViT在足够数据上始终优于CNN。但ViT为达到SOTA所需的FLOPs通常高于CNN。
关键发现: ViT缺乏CNN的归纳偏置,因此需要更多数据来学习视觉结构。当数据量达到100M级别时,ViT的优势才能充分显现。这一发现直接推动了后续DeiT(数据高效训练)和Swin Transformer(引入局部先验)等改进工作。
7.2 下游任务迁移
ViT在ImageNet上预训练后,可以通过微调迁移到各种下游视觉任务:在VTAB基准测试(19个视觉任务)上,ViT-H/14达到了77.63%的平均准确率;在目标检测任务中,结合FPN结构,ViT可以作为强大的骨干网络;在语义分割中,采用SETR(Segmentation Transformer)结构,ViT在ADE20K上达到50.28% mIoU。
# 迁移学习微调示例
# 加载预训练ViT并替换分类头
import timm
# 使用timm库加载预训练ViT
model = timm.create_model('vit_base_patch16_224' , pretrained=True )
# 替换分类头进行下游任务微调
model.head = nn.Linear(model.embed_dim, num_classes=10 ) # CIFAR-10
# 小学习率微调,冻结前几层加速训练
for name, param in model.named_parameters():
if 'blocks.0' in name or 'blocks.1' in name:
param.requires_grad = False
optimizer = optim.AdamW(
filter (lambda p: p.requires_grad, model.parameters()),
lr=1e-4 , weight_decay=0.05
)
八、ViT推理与可视化实战
以下代码示例演示了如何使用预训练的ViT模型进行图像分类推理,并可视化自注意力图以直观理解模型关注区域。
import timm
import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# 1. 加载预训练ViT模型
model = timm.create_model('vit_base_patch16_224' , pretrained=True )
model.eval()
# 配置:注册hook获取注意力权重
attention_maps = []
def hook_fn (module, input, output):
# output[1] 包含注意力权重 [B, num_heads, N, N]
attention_maps.append(output[1 ].detach())
# 注册hook到每个Transformer层的注意力模块
handles = []
for block in model.blocks:
handles.append(block.attn.register_forward_hook(hook_fn))
# 2. 图像预处理
transform = T.Compose([
T.Resize(256 ),
T.CenterCrop(224 ),
T.ToTensor(),
T.Normalize(mean=[0.485 , 0.456 , 0.406 ],
std=[0.229 , 0.224 , 0.225 ]),
])
img = Image.open('cat.jpg' )
input_tensor = transform(img).unsqueeze(0 )
# 3. 前向推理
with torch.no_grad():
logits = model(input_tensor)
probs = torch.softmax(logits, dim=-1 )
# 获取ImageNet标签
with open('imagenet_labels.txt' ) as f:
labels = [line.strip() for line in f.readlines()]
pred_idx = logits.argmax(dim=-1 ).item()
print (f"预测类别: {labels[pred_idx]}" )
print (f"置信度: {probs[0][pred_idx].item():.4f}" )
# 4. 可视化最后一层[CLS]的注意力图
# 取最后一层,第一个head,[CLS] token对所有patch的注意力
last_attn = attention_maps[-1 ] # [1, 12, 197, 197]
cls_attn = last_attn[0 , :, 0 , 1 :] # [12, 196]
cls_attn = cls_attn.mean(dim=0 ) # [196] 在所有头上平均
# 将注意力图reshape为14x14并放大到224x224
attn_map = cls_attn.reshape(14 , 14 ).numpy()
attn_map = Image.fromarray(attn_map).resize((224 , 224 ), Image.BICUBIC)
# 可视化叠加
plt.figure(figsize=(10 , 5 ))
plt.subplot(1 , 2 , 1 )
plt.imshow(img)
plt.title('Original Image' )
plt.axis('off' )
plt.subplot(1 , 2 , 2 )
plt.imshow(img)
plt.imshow(attn_map, alpha=0.6 , cmap='jet' )
plt.title('Attention Map (CLS token)' )
plt.axis('off' )
plt.show()
# 清理hooks
for handle in handles:
handle.remove()
注意力可视化要点: ViT最后一层[CLS] token对其他Patch的注意力权重可以生成一张注意力热力图(Attention Map),热力值高的区域代表模型在分类决策时的"重点关注区域"。这一特性使得ViT天然具有比CNN更强的可解释性。
九、局限性与未来方向
尽管ViT取得了巨大成功,但仍存在明显的局限性,也指引着后续研究的演进方向。
9.1 当前局限性
计算效率 :自注意力的O(N²)复杂度限制了ViT处理高分辨率图像的能力。虽然Swin Transformer的窗口注意力在一定程度上缓解了这一问题,但在视频理解等需要处理大量Token的场景中仍然面临挑战。
数据依赖 :相较于CNN,ViT仍然需要更多的训练数据和更强的正则化,在中小规模数据集上的表现不如精心调优的CNN(如ConvNeXt)。
局部纹理建模 :ViT的全局注意力并不天然擅长捕获局部纹理和边缘信息,这在小样本学习和细粒度分类任务中可能成为瓶颈。
推理速度 :ViT的推理速度(throughput)通常慢于同等参数量的CNN,特别是在边缘设备上的部署仍存在困难。
9.2 未来演进方向
关键趋势: CV领域正从"CNN vs ViT"的范式之争走向"混合架构"的新阶段。代表性的工作包括:ConvNeXt(现代化CNN吸收ViT设计理念)、MaxViT(多轴注意力)、EfficientViT(轻量化ViT)、FastViT(高推理速度ViT)、以及Vision Mamba(状态空间模型替代自注意力)。CNN和ViT的设计思想正在深度融合,取长补短。
十、核心要点总结
ViT核心思想 :将图像分割为固定大小的Patch,线性投影为嵌入向量,加上位置编码后送入标准Transformer编码器,通过[CLS] Token完成分类。核心公式:图像 = 196个视觉Token。
Patch Embedding :通过Conv2d(kernel=16, stride=16)高效实现16×16 Patch的线性投影。输出[B, 196, 768]等价于NLP中的词嵌入序列。
位置编码 :使用可学习的1D位置编码,弥补自注意力机制置换等变特性所丢失的空间位置信息。
全局感受野 :第一层即可建立任意Patch间的直接连接,相比之下CNN需要通过堆叠层数逐步扩大感受野。
数据驱动 vs 归纳偏置 :ViT几乎没有视觉先验,完整依赖大规模数据学习;CNN内建平移等变性和局部连接等归纳偏置。
DeiT :通过知识蒸馏和强数据增强(MixUp/CutMix/RandAugment),使ViT仅在ImageNet-1K上即可达到SOTA。
Swin Transformer :引入移位窗口机制,将自注意力复杂度从O(N²)降至O(N·M²),并建立分层金字塔结构。
训练技巧 :AdamW优化器、余弦学习率衰减、随机深度、梯度裁剪、大权重衰减(0.3)是ViT训练的关键配方。
缩放规律 :ViT性能随模型规模和数据量呈稳定的对数线性增长,数据量达到100M级别时优势明显。
自注意力可视化 :ViT的[CLS] Token注意力图天然具有语义可解释性,可直接用于无监督物体定位。
十一、进一步思考
ViT的出现不仅是计算机视觉领域的一次技术革新,更代表着一种范式转变 :从手工设计的局部操作(卷积)转向数据驱动的全局建模(自注意力)。这一趋势与NLP领域从RNN/LSTM到Transformer的演进高度一致,暗示着统一的"基础模型"架构可能是人工智能发展的必然方向。
在实践中,选择ViT还是CNN应根据具体场景权衡:如果拥有充足的数据和计算资源(如100M+图像、多GPU集群),ViT及其变体通常能带来更好的性能;如果数据有限或需要边缘部署,精心设计的CNN(如ConvNeXt、EfficientNet)仍然是更务实的选择。未来,CNN和ViT的混合架构 以及基于状态空间模型(如Mamba)的新一代视觉架构,很可能成为计算机视觉领域的主流范式。