← 返回深度学习目录
← 返回学习笔记首页
变分自编码器(VAE)
深度学习专题 · 概率生成模型
专题: 深度学习系统学习
关键词: 深度学习, VAE, 变分自编码器, 重参数化, ELBO, KL散度, Beta-VAE, VQ-VAE, 生成模型
一、概述
变分自编码器(Variational Autoencoder, VAE) 由 Kingma 和 Welling 于 2013 年提出,是一种深度生成模型(Deep Generative Model),将变分推断(Variational Inference)与神经网络相结合。与传统的自编码器(Autoencoder, AE)不同,VAE 不是一个确定性模型,而是一个概率生成模型:它学习数据的潜在概率分布,从而能够从潜在空间中采样生成全新的数据样本。
VAE 的核心思想可以概括为:给定观测数据 x ,我们希望学习一个潜在变量 z 的概率分布,使得从 z 中采样后能够重构出 x 。VAE 同时学习两个网络:一个编码器(Encoder)将输入 x 映射到潜在空间中的分布参数(均值 μ 和方差 σ² ),一个解码器(Decoder)从潜在变量 z 重构出 x 。
VAE 三大核心贡献: (1)将变分推断引入深度生成模型训练;(2)提出重参数化技巧(Reparameterization Trick)解决采样不可微问题;(3)以 ELBO(证据下界)作为优化目标,平衡重构质量与潜在空间的正则化。
VAE 的数学动机源于概率图模型。假设数据 x 由未观测的隐变量 z 生成,我们希望通过最大化边缘对数似然 log p(x) 来学习生成过程。然而直接最大化 log p(x) = log ∫ p(x|z)p(z)dz 通常是不可行的,因为积分高维且复杂。VAE 通过引入一个识别模型(即编码器)q_φ(z|x) 来近似真实后验 p_θ(z|x) ,并利用变分推断优化 ELBO。
二、从自编码器到变分自编码器
2.1 传统自编码器的局限
传统自编码器由一个编码器 z = f(x) 和一个解码器 x' = g(z) 组成,通过最小化重构损失 L = ||x - x'||² 进行训练。训练完成后,编码器将输入压缩为一个确定的低维向量。这种确定性编码存在几个关键问题:潜在空间不连续,缺乏规则性,导致在潜在空间中两个相邻点解码出的结果可能完全不同;无法从潜在空间采样生成新数据;对噪声和异常值敏感。
AE 的核心缺陷: 编码器的输出是单一确定向量,而非概率分布。这意味着潜在空间没有被正则化——模型可以学会将每个输入映射到潜在空间中任意分散的点,而不关心这些点之间的空间结构。
2.2 VAE 的概率视角
VAE 的编码器不再输出一个确定的潜在向量,而是输出一个概率分布。具体来说,编码器输出潜在变量的均值 μ 和对数方差 log σ² (而非方差本身,以确保数值稳定性)。然后从 N(μ, σ²) 中采样得到 z ,再送入解码器。
VAE 假设先验分布 p(z) = N(0, I) ,即标准正态分布。通过 KL 散度迫使编码器输出的后验分布 q_φ(z|x) 逼近标准正态分布。这样做的好处是:潜在空间被"锚定"在原点附近,形成连续、光滑的流形结构,使得在潜在空间中任意两个点之间插值都能产生有意义的输出。
直觉理解: AE 将每张图片压缩成一个"代码",而 VAE 将每张图片压缩成一个"代码区域"。从同一区域中采样的不同代码解码后应得到相似的图片,这赋予了 VAE 生成新数据的能力。—— 这就是为什么 VAE 是生成模型而 AE 不是。
三、VAE 损失函数与 ELBO 推导
3.1 ELBO 证据下界
VAE 的优化目标是最大化对数似然的证据下界(Evidence Lower Bound, ELBO)。推导从边缘对数似然开始:
log p(x) = log ∫ p(x|z) p(z) dz
= log ∫ q(z|x) * p(x|z) * p(z) / q(z|x) dz
= log E_{z~q(z|x)} [ p(x|z) * p(z) / q(z|x) ]
>= E_{z~q(z|x)} [ log p(x|z) + log p(z) - log q(z|x) ] (Jensen 不等式)
= E_{z~q(z|x)} [ log p(x|z) ] - KL( q(z|x) || p(z) )
因此 ELBO 由两项组成:
重构项 E_{z~q(z|x)}[log p_θ(x|z)] :衡量解码器从潜在变量 z 重构输入 x 的质量。在连续数据中通常建模为高斯分布(对应 MSE 损失),在离散数据(如图像像素 0-255)中建模为伯努利分布(对应交叉熵损失)。
KL 散度正则项 KL(q_φ(z|x) || p(z)) :衡量编码器输出的近似后验与先验分布之间的差异。该项迫使编码器产生接近标准正态分布 N(0, I) 的潜在表示。
# ELBO 的数值形式(批量计算)
# 对于 mini-batch 中的每个样本:
# loss = reconstruction_loss + beta * KL_divergence
# 重构损失(假设高斯分布输出,即 MSE):
reconstruction_loss = ||x - x_recon||²
# KL 散度(闭式解,因为均为高斯分布):
KL_divergence = -0.5 * sum(1 + log(σ²) - μ² - σ²)
3.2 KL 散度的闭式推导
由于 q_φ(z|x) = N(μ, σ²) 和 p(z) = N(0, I) 均为高斯分布,KL 散度存在闭式解:
KL(N(μ, σ²) || N(0, I))
= ∫ N(μ, σ²) * log[ N(μ, σ²) / N(0, I) ] dz
= -0.5 * sum( 1 + log(σ²) - μ² - σ² )
推导过程(对 D 维潜在空间的求和):
KL = 0.5 * Σ_{i=1}^{D} ( σ_i² + μ_i² - 1 - log(σ_i²) )
推导的关键步骤:两个多元高斯分布的 KL 散度公式为 KL(N₁ || N₂) = 0.5 * [tr(Σ₂⁻¹Σ₁) + (μ₂ - μ₁)ᵀΣ₂⁻¹(μ₂ - μ₁) - k + ln(det Σ₂ / det Σ₁)] 。代入 Σ₁ = diag(σ²), Σ₂ = I, μ₁ = μ, μ₂ = 0 ,得到上述简洁形式。
3.3 Beta-VAE:调整正则化权重
Higgins 等人(2017)提出了 β -VAE,通过引入超参数 β 来控制 KL 正则项的权重:
L_β-VAE = E_{z~q(z|x)}[ log p(x|z) ] - β * KL( q(z|x) || p(z) )
当 β > 1 时,模型学习到更加解耦(Disentangled)的潜在表示,即潜在空间的每个维度对应数据中独立的可解释因子(如旋转、缩放、颜色等)。当 β < 1 时,模型更专注于重构质量,但潜在空间的解耦性降低。典型的 β 取值范围为 1-10。
β-VAE 的直觉: 更大的 KL 惩罚迫使潜在分布更加接近标准正态分布,这意味着模型不能依赖编码器输出的分布携带过多信息,必须迫使潜变量的每个维度独立且信息量最大化,从而促进解耦表示学习。
四、重参数化技巧
4.1 不可微的采样操作
在 VAE 的前向传播中,编码器输出 μ 和 σ 后,我们需要从 N(μ, σ²) 中采样 z 再传入解码器。然而,采样操作本身是不可微的(即 z = sample(N(μ, σ²)) 这一操作没有定义梯度),导致梯度无法从解码器反向传播到编码器。
4.2 重参数化技巧的核心思想
重参数化技巧(Reparameterization Trick)是最核心的贡献之一。它将随机采样过程分解为确定性变换加外部随机噪声:
# 不重参数化(梯度无法回传):
z = sample_from(N(μ, σ²)) # ❌ 采样操作不可微
# 重参数化(梯度可回传):
ε = sample_from(N(0, I)) # ✅ 从标准正态分布采样
z = μ + σ ⊙ ε # ✅ 线性变换,可微!
这里 ε 是独立于模型参数的随机噪声,⊙ 表示逐元素乘法。由于 ε 来自固定的标准正态分布,采样过程被移出了计算图,μ 和 σ 可以正常接收梯度:
∂z/∂μ = 1 (梯度直接通过 μ 回传)
∂z/∂σ = ε (梯度通过 ε 回传)
重参数化技巧的本质: 将"从参数分布中采样"这一概率操作,转化为"从标准分布中采样 + 确定性参数变换"的组合。这样我们就绕过了采样操作的不可微壁垒,使得随机层可以无缝接入神经网络的反向传播框架中。
4.3 梯度绕路图解
# 前向传播的计算图示意
输入 x
↓
编码器网络 (权重 W_enc)
↓
输出 μ(x) 和 log_σ²(x)
↓
σ(x) = exp(0.5 * log_σ²(x)) # 确保 σ > 0
ε = randn_like(μ) # 从 N(0, I) 采样 [不依赖模型参数]
z = μ + σ * ε # [梯度可在此回传]
↓
解码器网络 (权重 W_dec)
↓
重构输出 x_recon
# 梯度路径:
# dL/dμ = dL/dz * dz/dμ = dL/dz * 1 ← 梯度正常回传
# dL/dσ = dL/dz * dz/dσ = dL/dz * ε ← 通过 ε 回传
# dL/dW_enc = f(dL/dμ, dL/dσ) ← 编码器参数正常更新
重参数化技巧不仅适用于高斯分布,理论上可用于任何可通过位置-尺度变换(Location-Scale Family)表示的分布,包括拉普拉斯分布、柯西分布等。对于离散分布,则需要使用 Gumbel-Softmax 技巧(Jang et al., 2016)等替代方案。
五、VAE vs AE 对比分析
对比维度 自编码器 (AE) 变分自编码器 (VAE)
潜在空间 确定性(一个输入对应一个点) 概率性(一个输入对应一个分布)
潜在空间连续性 不连续、无正则化约束 连续、光滑(受 KL 正则化约束)
生成能力 不能生成新数据(仅能重构) 可从先验采样生成全新数据
插值能力 潜在空间中插值结果无意义 线性插值可产生语义平滑过渡
流形学习 不保证学到数据流形 学习数据的概率流形结构
损失函数 MSE 重构损失 ELBO = 重构 + KL 正则化
训练稳定性 相对稳定 对 β 超参数敏感
潜在空间维数 通常较低(20-50) 可高可低(典型 2-256)
对噪声鲁棒性 较差 较好(概率建模天然抗噪)
核心区别一句话总结: AE 将输入压缩为一个点 ,VAE 将输入展开为一个区域 。正是这种"点→区域"的转变,赋予了 VAE 生成能力、插值能力和流形学习能力,同时也带来了更严格的数学框架(变分推断、KL 散度、重参数化)。
5.1 插值能力对比
在 VAE 的潜在空间中,由于 KL 正则化迫使编码器输出的分布逼近标准正态分布,潜在空间被"压缩"成一个以原点为中心、各向同性的连续流形。在隐空间中取两个点 z₁ 和 z₂ 进行线性插值 z(t) = (1-t)z₁ + tz₂, t ∈ [0,1] ,解码后对应语义上的平滑过渡(如人脸从表情A渐变到表情B)。而 AE 的潜在空间可能非常不规则,插值往往产生毫无意义的中间结果。
实践意义: 插值能力是生成模型质量的重要衡量标准。一个好的生成模型应该能够生成"在训练数据之间"的样本,这表明模型学到了数据的连续流形结构而非简单的记忆。VAE 的这一能力广泛应用于图像编辑、风格迁移和数据增强。
六、VAE 的核心应用
6.1 图像生成
VAE 最直接的应用是图像生成。通过在标准正态分布中采样 z ~ N(0, I) 并输入解码器,即可生成逼真的图像。虽然早期 VAE 生成的图像比 GAN 模糊,但后续改进(如 VQ-VAE、NVAE)在图像质量上已接近 GAN 的水平,同时训练更稳定。
6.2 异常检测
VAE 可以通过重构概率进行异常检测。正常数据经过 VAE 后重构误差较低,而异常数据的重构误差显著较高。与 AE 相比,VAE 的概率性质使其重构误差具有统计意义——我们可以计算 p(x|z) 作为异常分数,而不是简单的 MSE。
# 基于 VAE 的异常检测流程
1. 在正常数据上训练 VAE
2. 对每个测试样本 x:
a. 编码得到 μ, σ
b. 多次采样 z ~ N(μ, σ²),计算重构概率
c. 异常分数 = E[||x - recon||²] (或负对数似然)
3. 设定阈值,判定异常
6.3 表示学习与解耦表示
β-VAE 可学习解耦的潜在表示,其中每个潜在维度编码数据的一个独立变化因素。例如在 CelebA 人脸数据集上,不同的维度可能独立控制姿态角度、肤色、眼镜佩戴、背景颜色等属性。这种可解释的表示在可控生成、域适应和零样本推理中极为有用。
6.4 半监督学习
VAE 可以扩展到半监督学习场景。通过将潜在变量分为类别变量和连续风格变量,模型可以利用有标签数据学习类别判别,同时利用无标签数据学习生成过程。代表性方法包括 M2 模型(Kingma et al., 2014)和 Skip-Gram VAE。
6.5 文本生成与分子设计
VAE 也广泛应用于文本生成(如优化文本潜在空间中的插值)和分子设计(用 VAE 在分子结构空间中学习连续表示,通过优化连续潜变量来生成具有特定化学属性的分子)。bowman2016 的基于 RNN 的 VAE 是文本 VAE 的开创性工作。
七、PyTorch VAE 完整实现
以下是一个完整的 PyTorch VAE 实现,包含模型定义、训练循环和生成采样。代码使用 MNIST 手写数字数据集作为示例。
7.1 模型定义
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# ---------- VAE 模型 ----------
class VAE(nn.Module):
def __init__(self, latent_dim=20):
super().__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Linear(784, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
)
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_logvar = nn.Linear(256, latent_dim)
# 解码器
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 784),
nn.Sigmoid(), # 输出归一化到 [0,1]
)
def encode(self, x):
h = self.encoder(x)
return self.fc_mu(h), self.fc_logvar(h)
def reparameterize(self, mu, logvar):
# 重参数化技巧
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std) # 从 N(0,I) 采样
return mu + eps * std # z ~ N(μ, σ²)
def decode(self, z):
return self.decoder(z)
def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
recon = self.decode(z)
return recon, mu, logvar, z
# ---------- 损失函数 ----------
@staticmethod
def loss_function(recon_x, x, mu, logvar, beta=1.0):
# 重构损失(伯努利分布假设 -> 二元交叉熵)
recon_loss = F.binary_cross_entropy(
recon_x, x.view(-1, 784), reduction='sum'
)
# KL 散度(闭式解)
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# ELBO = - (recon_loss + beta * kl_loss)
return recon_loss + beta * kl_loss, recon_loss, kl_loss
7.2 训练循环
# ---------- 数据加载 ----------
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.view(-1))
])
train_dataset = datasets.MNIST(
root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# ---------- 训练配置 ----------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(latent_dim=20).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 50
beta = 1.0 # beta-VAE 中的 β
# ---------- 训练循环 ----------
model.train()
for epoch in range(1, epochs + 1):
total_loss = 0.0
total_recon = 0.0
total_kl = 0.0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar, z = model(data)
loss, recon_loss, kl_loss = VAE.loss_function(
recon_batch, data, mu, logvar, beta
)
loss.backward()
optimizer.step()
total_loss += loss.item()
total_recon += recon_loss.item()
total_kl += kl_loss.item()
avg_loss = total_loss / len(train_loader.dataset)
avg_recon = total_recon / len(train_loader.dataset)
avg_kl = total_kl / len(train_loader.dataset)
print(f'Epoch {epoch:3d} | Loss: {avg_loss:.2f} | '
f'Recon: {avg_recon:.2f} | KL: {avg_kl:.4f}')
7.3 从先验生成样本
# ---------- 从标准正态分布采样生成 ----------
model.eval()
with torch.no_grad():
# 从先验 p(z) = N(0, I) 采样
z_sample = torch.randn(64, 20).to(device) # 64 个随机潜在向量
gen_images = model.decode(z_sample).view(-1, 1, 28, 28)
# 潜在空间插值
z1 = torch.randn(1, 20).to(device)
z2 = torch.randn(1, 20).to(device)
alphas = torch.linspace(0, 1, steps=10).to(device)
interp_images = []
for alpha in alphas:
z_interp = (1 - alpha) * z1 + alpha * z2
img = model.decode(z_interp).view(1, 28, 28)
interp_images.append(img)
# interp_images 即为从 z1 到 z2 的插值结果
7.4 监控训练:KL 散度退火
# ---------- KL 退火:解决 KL 消失问题 ----------
# 在训练文本 VAE 时,KL 项经常坍缩到 0(称为 posterior collapse)
# 解决方法:逐步增加 KL 权重
def kl_anneal(epoch, total_epochs, cycle=10):
"""周期性 KL 退火"""
progress = (epoch % cycle) / cycle
return min(1.0, progress * 2)
# 训练时:
beta = kl_anneal(epoch, epochs)
loss, recon_loss, kl_loss = VAE.loss_function(
recon_batch, data, mu, logvar, beta
)
以上代码提供了一个可独立运行的 VAE 完整实现。训练结束后,可以通过在潜空间中采样或插值来生成新的手写数字图像。
八、条件 VAE(Conditional VAE, CVAE)
条件 VAE(CVAE) 是 VAE 的条件版本(Sohn et al., 2015),在编码器和解码器中都引入条件变量 c (如类别标签、描述文本等),使得生成过程可以受条件控制。
8.1 CVAE 的核心公式
CVAE 的优化目标变为条件 ELBO:
log p(x|c) >= E_{z~q(z|x,c)}[ log p(x|z,c) ] - KL( q(z|x,c) || p(z|c) )
其中:
- 编码器变为 q_φ(z|x,c):给定 x 和条件 c 推断 z
- 解码器变为 p_θ(x|z,c):给定 z 和条件 c 生成 x
- 先验也可依赖于条件:p(z|c)(通常仍简化为 N(0,I))
8.2 CVAE 的 PyTorch 实现要点
class CVAE(nn.Module):
"""条件 VAE:以类别标签为条件"""
def __init__(self, latent_dim=20, num_classes=10):
super().__init__()
self.latent_dim = latent_dim
self.num_classes = num_classes
# 编码器:输入 x (784) + one-hot c (10) -> 潜变量参数
self.encoder = nn.Sequential(
nn.Linear(784 + num_classes, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
)
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_logvar = nn.Linear(256, latent_dim)
# 解码器:潜变量 z (20) + one-hot c (10) -> 重构 x
self.decoder = nn.Sequential(
nn.Linear(latent_dim + num_classes, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 784),
nn.Sigmoid(),
)
def forward(self, x, labels):
# 将标签转为 one-hot
c = F.one_hot(labels, num_classes=self.num_classes).float()
x_flat = x.view(-1, 784)
# 编码器:拼接 x 和 c
enc_input = torch.cat([x_flat, c], dim=1)
h = self.encoder(enc_input)
mu, logvar = self.fc_mu(h), self.fc_logvar(h)
# 重参数化
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std
# 解码器:拼接 z 和 c
dec_input = torch.cat([z, c], dim=1)
recon = self.decoder(dec_input)
return recon, mu, logvar, z
def generate(self, labels, num_per_class=10):
"""按类别生成图像"""
self.eval()
with torch.no_grad():
batch_labels = torch.repeat_interleave(
torch.arange(self.num_classes, device=next(self.parameters()).device),
num_per_class
)
c = F.one_hot(batch_labels, num_classes=self.num_classes).float()
z = torch.randn(len(batch_labels), self.latent_dim).to(c.device)
dec_input = torch.cat([z, c], dim=1)
return self.decoder(dec_input).view(-1, 1, 28, 28)
# generate() 返回按类别排列的生成图像,每类 num_per_class 张
CVAE 的一个重要优势是解决模式坍塌问题:通过条件变量,模型必须为每个类别都分配合理的生成概率,从而确保生成结果的多样性。CVAE 在条件图像生成、文本到图像合成、可控数据增强等场景中广泛应用。
九、VQ-VAE:离散潜在空间
VQ-VAE(Vector Quantized VAE) 由 van den Oord 等人于 2017 年提出,是 VAE 家族中的一个重要变体。与标准 VAE 使用连续潜在空间不同,VQ-VAE 使用离散潜在空间,通过向量量化(Vector Quantization)将编码器的输出映射到最近的码本(Codebook)向量上。
9.1 VQ-VAE 的动机
标准 VAE 的连续潜在空间假设后验是高斯分布,但许多自然数据(如文本、语音)本质上具有离散结构。VQ-VAE 通过引入离散潜在表示,解决了以下问题:
避免了"后验坍缩"(Posterior Collapse)问题
更适合与自回归模型(如 PixelCNN、Transformer)结合
生成的图像比标准 VAE 更清晰,接近 GAN 的质量
潜在表示具有自然的离散语义,更适合推理和规划
9.2 VQ-VAE 的核心架构
# VQ-VAE 的三组件:
# 1. 编码器:将输入 x 映射为连续特征图 z_e(x)
# 2. 码本 Codebook:可学习的 K 个嵌入向量 {e₁, e₂, ..., e_K}
# 3. 解码器:从量化后的 z_q(x) 重构 x
# 前向传播流程:
z_e(x) = Encoder(x) # 连续特征
z_q(x) = e_k, where k = argmin_j ||z_e(x) - e_j||₂ # 最近邻量化
x_hat = Decoder(z_q(x)) # 重构
# 损失函数:
L = ||x - x_hat||² # 重构损失
+ ||sg[z_e(x)] - e||² # 码本学习(codebook loss)
+ β * ||z_e(x) - sg[e]||² # 编码器承诺(commitment loss)
# sg = stop_gradient(停止梯度传播)
# VQ-VAE 的损失拆解:
# - 重构损失:通过解码器和编码器更新(端到端训练)
# - 码本损失:只更新码本向量 e,使其接近编码器输出
# - 承诺损失:只更新编码器,使其承诺使用码本(β 通常取 0.25)
9.3 VQ-VAE 的 PyTorch 实现要点
class VectorQuantizer(nn.Module):
"""VQ-VAE 的向量量化层"""
def __init__(self, num_embeddings=512, embedding_dim=64, beta=0.25):
super().__init__()
self.K = num_embeddings # 码本大小
self.D = embedding_dim # 码本向量维度
self.beta = beta
# 可学习的码本嵌入矩阵 [K, D]
self.embedding = nn.Embedding(self.K, self.D)
self.embedding.weight.data.uniform_(-1/self.K, 1/self.K)
def forward(self, z_e):
# z_e: [B, D, H, W] 编码器输出的特征图
# 展平为 [B*H*W, D]
z_e_flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, self.D)
# 计算与所有码本向量的距离: ||z_e - e||²
distances = torch.cdist(z_e_flat, self.embedding.weight) # [N, K]
# 最近邻编码: argmin 得到编码索引
encoding_indices = torch.argmin(distances, dim=1) # [N]
# 量化: 选择最近的码本向量
z_q_flat = self.embedding(encoding_indices) # [N, D]
# 重塑回 [B, D, H, W]
z_q = z_q_flat.view_as(z_e)
# 损失项
codebook_loss = F.mse_loss(z_q_flat.detach(), z_e_flat) # 码本损失
commitment_loss = F.mse_loss(z_q_flat, z_e_flat.detach()) # 承诺损失
vq_loss = codebook_loss + self.beta * commitment_loss
# 直通估计器(Straight-Through Estimator):
# 前向传播使用量化后的 z_q,反向传播梯度绕过量化操作
z_q = z_e + (z_q - z_e).detach()
return z_q, vq_loss, encoding_indices.view_as(z_e[:,0,:,:])
9.4 VQ-VAE-2:层级离散编码
VQ-VAE-2(Razavi et al., 2019)引入了层级架构:分成 top-level(全局结构)和 bottom-level(局部细节)两个层次的 VQ-VAE。顶部编码器捕获全局信息(如物体类别、姿态),底部编码器捕获局部纹理细节。这种层级设计使 VQ-VAE-2 能够生成长程一致的百万像素级高清图像。
VQ-VAE 的现代地位: VQ-VAE 的离散潜在空间使其天然适合与 Transformer 结合(如 DALL-E、VQGAN 等大型多模态模型)。这些模型先使用 VQ-VAE 将图像压缩为离散编码序列,然后训练 Transformer 对编码序列进行自回归建模,实现高质量的文本到图像生成。
十、VAE 与其他生成模型的比较
特性 VAE GAN Flow-based Diffusion
训练稳定性 高 低(对抗训练不稳定) 高 高
样本质量 中等 高 中等 高
多样性 高 中等(可能模式坍塌) 高 高
潜在空间可解释性 高 中等 高 低
似然计算 下界(ELBO) 无法计算 精确 下界
推断网络 显式(编码器) 无 隐式 需反向过程
推理速度 快(单前向) 快(单前向) 快(单前向) 慢(多步采样)
主要瓶颈 样本模糊 训练不稳定 架构约束大 采样速度慢
不同生成模型各有优劣。VAE 的优势在于训练稳定、具有显式的概率推断框架、潜在空间可解释性强。在实际应用中,VAE 常用于需要可解释潜在表示的场景(如医药分子设计、科学数据分析),或与其它模型结合使用(如 VQ-VAE + Transformer)。近年来最先进的图像生成方案(如 Stable Diffusion 等)本质上是 VAE(用于压缩)+ Diffusion 模型的组合架构。
十一、核心要点总结
1. VAE 的定义: 一种概率生成模型,通过编码器将输入映射为潜在概率分布,再通过解码器从潜在采样中重构输入。优化目标是最大化 ELBO(证据下界)。
2. 重参数化技巧: 将不可微的随机采样分解为 z = μ + σ·ε (其中 ε ~ N(0,I) ),使梯度可以通过 μ 和 σ 反向传播,完成编码器-解码器联合训练。
3. ELBO 分解: 损失函数由重构损失(衡量生成质量)和 KL 散度(约束潜在空间接近标准正态分布)组成。β-VAE 通过调节 β 平衡二者,促进解耦表示学习。
4. VAE vs AE: AE 是确定性模型,潜在空间无正则化,不能生成新数据;VAE 是概率模型,潜在空间连续光滑,支持从先验采样生成新数据。
5. CVAE: 在条件变量(如类别标签)控制下生成数据,广泛应用于条件图像生成和可控数据增强。
6. VQ-VAE: 引入离散潜在空间,通过向量量化和码本学习获得高质量的离散表示,是现代大型多模态生成模型(如 DALL-E、VQGAN)的基础组件。
7. 应用领域: 图像生成、异常检测、表示学习/解耦、半监督学习、分子设计、文本生成等。
十二、进一步思考
VAE 提供了将概率推断与深度学习结合的优雅框架。其核心思想——学习数据的潜在概率分布而非确定性的编码——已经在各个领域产生了深远影响。
值得深入探索的方向包括:
后验坍缩问题: 在文本 VAE 中,当解码器能力过强时,KL 项会坍缩到零,模型退化为普通语言模型。解决方案包括 KL 退火、Free Bits、InfoVAE 等。
层级 VAE: 引入多个层次的潜在变量(如 NVAE、HVAE),在高层次捕获全局结构,低层次捕获局部细节,显著提升生成质量。
VAE + Diffusion 融合: 现代扩散模型(如 Stable Diffusion)使用 VAE 将图像压缩到低维潜在空间,再在潜空间中进行扩散/去噪,获得高清生成结果。
VQ-VAE + Transformer: 将图像的离散编码序列交由 Transformer 建模(如 DALL-E、Parti),实现文本到图像的全新生成范式。
学习建议: 理解 VAE 的关键在于把握三个核心公式:ELBO 分解公式(理解优化目标)、KL 闭式解公式(理解编码器训练)、重参数化公式 z = μ + σ·ε(理解梯度流)。这三个公式理解透彻了,VAE 的大局就已掌握。下一步建议动手实现一个完整的 VAE 模型,在 MNIST 上观察重构和生成效果,然后逐步扩展到 β-VAE、CVAE 和 VQ-VAE。