← 返回深度学习目录
← 返回学习笔记首页
生成对抗网络(GAN)
深度学习专题 · 对抗训练驱动的生成模型
专题: 深度学习系统学习
关键词: 深度学习, GAN, 生成对抗网络, DCGAN, 模式崩溃, CycleGAN, StyleGAN, 图像生成, 对抗训练
一、生成对抗网络概述
生成对抗网络(Generative Adversarial Network, GAN)由 Ian Goodfellow 等人于 2014 年提出,被誉为"深度学习领域最令人振奋的想法之一"。GAN 的核心思想是通过两个神经网络的对抗训练来学习数据分布:一个生成器(Generator)负责从随机噪声中合成逼真样本,一个判别器(Discriminator)负责区分真实样本和生成样本。二者在零和博弈中相互竞争、共同进化,最终生成器能够产生以假乱真的数据。
GAN 的出现从根本上改变了生成式模型的研究范式。在此之前,生成模型主要依赖变分自编码器(VAE)和自回归模型,这些模型虽然理论基础扎实,但生成的样本往往不够清晰锐利。GAN 通过对抗训练机制,能够生成视觉质量极高的图像,在图像生成、超分辨率、风格迁移、数据增强等领域取得了突破性成果。
核心思想: 生成器 G 接收随机噪声 z 生成假样本 G(z),判别器 D 同时接收真实样本 x 和假样本 G(z),输出样本为"真实"的概率。G 的目标是最大化 D 判断错误的概率,D 的目标是正确区分真假样本。这种 Min-Max 博弈最终达到纳什均衡。
1.1 GAN 的基本架构
GAN 由两个独立的神经网络构成:生成器(Generator)和判别器(Discriminator)。生成器以随机噪声向量 z 作为输入,通过多层神经网络将其映射到数据空间,输出与真实数据形状相同的张量。判别器则是一个二分类网络,接收图像输入并输出一个介于 0 到 1 之间的标量,表示输入为真实样本的概率。两个网络在训练过程中交替优化,形成对抗关系。
从博弈论的角度看,GAN 的训练过程是一个二人零和博弈。生成器试图最小化价值函数 V(D,G),判别器试图最大化 V(D,G)。当双方都达到最优策略时,博弈处于纳什均衡状态——此时生成器产生的数据分布完美拟合真实数据分布,判别器无法区分真假,输出恒为 0.5。
"GAN 是监督学习与无监督学习的桥梁——它用监督学习的损失函数(真假判别)来实现无监督学习的目标(学习数据分布)。" —— Ian Goodfellow
二、GAN 数学原理
2.1 Minimax 博弈与价值函数
GAN 的目标函数可以形式化为以下 Minimax 博弈:
min_G max_D V(D,G) = E_x~pdata[log D(x)] + E_z~pz[log(1 - D(G(z)))]
其中 pdata 是真实数据分布,pz 是先验噪声分布(通常为高斯分布或均匀分布)。判别器 D(x) 输出 x 来自真实数据的概率。从判别器的角度看,它希望最大化正确分类真实样本(log D(x))和正确拒绝假样本(log(1-D(G(z))))的对数似然。从生成器的角度看,它希望最小化判别器正确拒绝假样本的概率,即最小化 log(1-D(G(z)))。
2.2 纳什均衡的理论分析
当生成器 G 固定时,最优判别器为:
D*(x) = pdata(x) / (pdata(x) + pg(x))
其中 pg 是生成器产生的数据分布。将 D* 代入价值函数,可以得到:
C(G) = -log(4) + 2 * JSD(pdata || pg)
其中 JSD 是 Jensen-Shannon 散度。因此 GAN 等价于最小化真实分布和生成分布之间的 JS 散度。当 pdata = pg 时,JSD 为 0,价值函数达到全局最小值 -log(4)。此时 D*(x) = 0.5,即判别器完全无法区分真假样本。
理论洞见: GAN 的对抗训练本质是在隐式地最小化真实分布与生成分布之间的 JS 散度。与 VAE 优化 ELBO(证据下界)不同,GAN 不依赖显式的变分下界,从而避免了 VAE 中常见的生成模糊问题。
2.3 非饱和损失函数(Non-Saturating Loss)
原始 GAN 中使用 min log(1-D(G(z))) 作为生成器损失,但在实际训练中,当判别器表现良好时,log(1-D(G(z))) 的梯度会趋于饱和,导致生成器学习缓慢。Goodfellow 提出了非饱和损失函数,将生成器的优化目标改为 max log(D(G(z))),即最大化判别器对假样本判断为真的概率。这一改进显著缓解了早期训练中的梯度问题,成为后续 GAN 训练的标准做法。
# GAN 的两种生成器损失对比
# 原始损失(饱和梯度):
L_G_saturating = -E_z[log(1 - D(G(z)))]
# 非饱和损失(改进版):
L_G_nonsaturating = -E_z[log(D(G(z)))]
三、GAN 训练挑战
GAN 的训练以其不稳定性和敏感性而闻名,被誉为"深度学习中最难训练的模型之一"。以下是主要的训练挑战及其解决方案。
3.1 模式崩溃(Mode Collapse)
模式崩溃是 GAN 训练中最常见的失败模式之一。当生成器发现某些样本更容易"欺骗"判别器时,它就会持续输出这些样本,导致生成数据的多样性严重不足。例如,在 MNIST 数据集上训练时,生成器可能只会生成数字"1"而忽略其他数字。从分布学习的角度看,模式崩溃意味着生成分布 pg 只覆盖了真实分布 pdata 的部分支撑集。
缓解模式崩溃的常用策略包括:使用小批量判别(Mini-batch Discrimination)让判别器同时观察多个样本;引入经验重放(Experience Replay)使生成器面对过去版本的判别器;使用多个生成器(MAD-GAN)强制覆盖不同的模式;以及采用谱归一化(Spectral Normalization)稳定判别器的 Lipschitz 常数。
模式崩溃的特征: 训练过程中生成器损失持续下降而判别器损失持续上升、同一批次中生成的样本高度相似、不同训练轮次生成的图片风格单一。如果发现这些信号,应立即采取缓解措施。
3.2 梯度消失
当判别器过于强大时,它可以轻松区分真假样本,导致 D(G(z)) 趋近于 0,log(1-D(G(z))) 趋近于 0,生成器的梯度非常小,几乎无法学习。这通常发生在训练初期判别器收敛过快,或者生成器与判别器的能力差距过大时。解决方案包括:确保使用非饱和损失、调节生成器和判别器的学习率比例(通常判别器学习率更低)、使用标签平滑(Label Smoothing)防止判别器过度自信、为判别器添加梯度惩罚(Gradient Penalty)。
3.3 不收敛与训练震荡
GAN 的对抗训练本质上是一个动态博弈过程,两个网络的参数更新相互依赖,容易陷入循环震荡而非收敛到纳什均衡。这表现为判别器和生成器的损失值交替上升下降,无法稳定。常用的稳定技巧包括:使用 Adam 优化器替代 SGD、为判别器添加 Dropout、使用 RMSProp 或 Adam 的不同 β 参数组合、采用 WGAN-GP 使用 Wasserstein 距离替代 JS 散度。
3.4 训练技巧总结
技巧 描述 适用场景
非饱和损失 G 使用 max log(D(G(z))) 所有 GAN 训练
标签平滑 真实标签使用 0.9 而非 1.0 防止 D 过度自信
梯度惩罚 约束 D 的 Lipschitz 常数 WGAN-GP
谱归一化 对 D 每层权重做谱范数归一化 SNGAN
特征匹配 G 的损失匹配 D 中间层特征 训练早期稳定
小批量判别 D 观察批统计量 缓解模式崩溃
学习率比例 D 学习率 < G 学习率 平衡对抗
# GAN 训练的核心循环(伪代码)
for epoch in range(num_epochs):
for batch_idx, (real_imgs, _) in enumerate(dataloader):
batch_size = real_imgs.size(0)
z = torch.randn(batch_size, latent_dim)
fake_imgs = generator(z)
# 训练判别器:最大化 log(D(x)) + log(1-D(G(z)))
d_real = discriminator(real_imgs)
d_fake = discriminator(fake_imgs.detach())
d_loss = -(torch.log(d_real + 1e-8).mean()
+ torch.log(1 - d_fake + 1e-8).mean())
d_loss.backward()
optimizer_d.step()
# 训练生成器:最大化 log(D(G(z)))(非饱和损失)
z = torch.randn(batch_size, latent_dim)
fake_imgs = generator(z)
g_loss = -torch.log(discriminator(fake_imgs) + 1e-8).mean()
g_loss.backward()
optimizer_g.step()
实用建议: 训练 GAN 时遵循"1:1 或 1:2"的更新比例(每更新一次 G 更新 1-2 次 D),使用梯度裁剪(gradient clipping)防止梯度爆炸,监控 D 的输出分布——如果 D 的输出长期接近 0 或 1 则说明训练异常。
四、DCGAN:深度卷积生成对抗网络
DCGAN(Deep Convolutional GAN)由 Radford 等人于 2015 年提出,是将卷积神经网络引入 GAN 框架的里程碑式工作。DCGAN 首次证明了精心设计的卷积架构能够稳定地训练 GAN,并生成高质量的图像。DCGAN 的架构设计原则至今仍是大多数图像生成 GAN 的基础。
4.1 转置卷积上采样(Transposed Convolution)
生成器需要将低维的噪声向量(通常为 100 维)逐步上采样到高分辨率图像(如 64x64)。DCGAN 采用转置卷积(也称为反卷积 Deconvolution 或分数步长卷积 Fractionally-Strided Convolution)实现上采样。转置卷积通过将输入特征图的每个像素与可学习的卷积核相乘并叠加,实现特征图尺寸的倍增。
具体来说,DCGAN 生成器从 100 维噪声开始,经过全连接层重塑为 4x4x1024 的特征图,然后通过 4 层转置卷积依次上采样至 8x8、16x32、32x32、64x64,通道数从 1024 逐步减少到 3(RGB 图像)。每层转置卷积后接 BatchNorm 和 ReLU 激活函数,最后一层使用 Tanh 将像素值映射到 [-1, 1] 范围。
# DCGAN 生成器架构(PyTorch 实现)
class DCGenerator(nn.Module):
def __init__(self, latent_dim=100, channels=3, feature_maps=64):
super().__init__()
self.model = nn.Sequential(
# 输入: (latent_dim, 1, 1) -> 全连接 -> 重塑
nn.ConvTranspose2d(latent_dim, feature_maps*8, 4, 1, 0, bias=False),
nn.BatchNorm2d(feature_maps*8),
nn.ReLU(True),
# 状态: (feature_maps*8, 4, 4)
nn.ConvTranspose2d(feature_maps*8, feature_maps*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps*4),
nn.ReLU(True),
# 状态: (feature_maps*4, 8, 8)
nn.ConvTranspose2d(feature_maps*4, feature_maps*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps*2),
nn.ReLU(True),
# 状态: (feature_maps*2, 16, 16)
nn.ConvTranspose2d(feature_maps*2, feature_maps, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps),
nn.ReLU(True),
# 状态: (feature_maps, 32, 32)
nn.ConvTranspose2d(feature_maps, channels, 4, 2, 1, bias=False),
nn.Tanh()
# 输出: (channels, 64, 64)
)
def forward(self, z):
z = z.view(z.size(0), -1, 1, 1)
return self.model(z)
4.2 判别器架构设计
DCGAN 判别器是生成器的镜像结构,采用步长卷积(Strided Convolution)替代池化层进行下采样。与生成器相反,判别器使用 LeakyReLU 激活函数(α=0.2),取消 BatchNorm 在输入层的应用。最终通过全连接层和 Sigmoid 输出单个标量。
# DCGAN 判别器架构(PyTorch 实现)
class DCDiscriminator(nn.Module):
def __init__(self, channels=3, feature_maps=64):
super().__init__()
self.model = nn.Sequential(
# 输入: (channels, 64, 64)
nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 状态: (feature_maps, 32, 32)
nn.Conv2d(feature_maps, feature_maps*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps*2),
nn.LeakyReLU(0.2, inplace=True),
# 状态: (feature_maps*2, 16, 16)
nn.Conv2d(feature_maps*2, feature_maps*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps*4),
nn.LeakyReLU(0.2, inplace=True),
# 状态: (feature_maps*4, 8, 8)
nn.Conv2d(feature_maps*4, feature_maps*8, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps*8),
nn.LeakyReLU(0.2, inplace=True),
# 状态: (feature_maps*8, 4, 4)
nn.Conv2d(feature_maps*8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x).view(-1, 1).squeeze(1)
4.3 DCGAN 的五大设计原则
用步长卷积替代池化层: 判别器使用步长卷积(stride=2)进行下采样,生成器使用转置卷积进行上采样,让网络自主学习空间下/上采样。
在生成器和判别器中广泛使用 BatchNorm: BN 将每层输入归一化为零均值单位方差,稳定训练过程,防止梯度消失/爆炸。但生成器输出层和判别器输入层不使用 BN。
移除全连接层: 除生成器输入层外,全部使用卷积层,减少参数量并利用卷积的局部连接特性。
生成器使用 ReLU,判别器使用 LeakyReLU: ReLU 在生成器中提供稀疏梯度,LeakyReLU 防止判别器梯度全部消失。
生成器输出使用 Tanh: 将像素值约束在 [-1, 1] 范围,比 Sigmoid 提供更强的梯度信号。
五、GAN 改进与变体
自原始 GAN 提出以来,研究者们从多个方向对其进行了改进,催生了大量优秀的变体模型。以下是其中最具代表性的工作。
5.1 CGAN:条件生成对抗网络
条件 GAN(Conditional GAN)由 Mirza 和 Osindero 于 2014 年提出,是最早的 GAN 变体之一。CGAN 的核心思想是在生成器和判别器的输入中加入条件信息 y(如类别标签、文本描述或图像),使生成过程可控。通过将随机噪声 z 和条件 y 拼接后送入生成器,CGAN 能够生成符合特定条件的数据。
# CGAN 生成器接收噪声和条件标签
class ConditionalGenerator(nn.Module):
def __init__(self, latent_dim=100, num_classes=10, img_channels=1,
img_size=28):
super().__init__()
self.label_embed = nn.Embedding(num_classes, latent_dim)
self.model = nn.Sequential(
nn.Linear(latent_dim * 2, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(1024),
nn.Linear(1024, img_channels * img_size * img_size),
nn.Tanh()
)
def forward(self, z, labels):
label_emb = self.label_embed(labels)
combined = torch.cat([z, label_emb], dim=1)
img = self.model(combined)
return img.view(img.size(0), -1, 28, 28)
CGAN 的提出开启了"可控生成"的研究方向,为此后的 Pix2Pix、CycleGAN、StyleGAN 等工作奠定了基础。
5.2 Pix2Pix:图像到图像翻译
Pix2Pix 由 Phillip Isola 等人于 2016 年提出,是条件 GAN 在图像翻译任务中的经典应用。它使用 U-Net 作为生成器(而非简单的编码器-解码器),PatchGAN 作为判别器(对图像局部块进行判别而非整图)。损失函数结合了 CGAN 的对抗损失和 L1 损失,其中 L1 损失鼓励生成结果与真实标签在像素级上接近。
# Pix2Pix 损失函数
L_G = L_cGAN(G,D) + lambda * L1(G)
# 其中 L1(G) = E_x,y[||y - G(x)||_1]
# lambda 通常取 100
Pix2Pix 开创了有监督图像翻译范式,被广泛应用于草图到照片、分割图到真实图、黑白图像上色、白天到夜晚等任务。
5.3 CycleGAN:循环一致生成对抗网络
CycleGAN 由 Zhu 等人于 2017 年提出,解决了无配对图像翻译问题。其核心创新是循环一致性损失(Cycle Consistency Loss):假设我们要将 X 域图像转换为 Y 域,除了前向生成器 G: X→Y 和判别器 D_Y 之外,还有一个反向生成器 F: Y→X 和判别器 D_X。前向循环一致性要求 F(G(x)) ≈ x,反向循环一致性要求 G(F(y)) ≈ y。
# CycleGAN 总损失
L(G,F,D_X,D_Y) = L_GAN(G,D_Y,X,Y) # 前向 GAN 损失
+ L_GAN(F,D_X,Y,X) # 反向 GAN 损失
+ lambda * L_cyc(G,F) # 循环一致性损失
# 循环一致性损失函数
L_cyc(G,F) = E_x[||F(G(x)) - x||_1]
+ E_y[||G(F(y)) - y||_1]
CycleGAN 的突破在于仅使用无配对数据即可实现风格迁移,消除了 Pix2Pix 对配对数据的依赖。它被广泛应用于照片风格化(照片→莫奈画作、照片→梵高风格)、季节转换(夏天→冬天)、动物品种转换(马→斑马)等任务。
5.4 StyleGAN:风格化生成对抗网络
StyleGAN 由 NVIDIA 的 Karras 等人于 2018 年提出,是 GAN 图像生成质量的重要里程碑。它引入了几项关键创新:映射网络(Mapping Network)、自适应实例归一化(AdaIN)、噪声注入和渐进增长训练。
映射网络: 将随机噪声 z 通过 8 层 MLP 映射为中间潜码 w,解耦了潜码的线性特征,使生成过程更加可控。通过改变 w 的不同维度,可以独立控制面部朝向、年龄、发型、肤色等视觉属性。
自适应实例归一化(AdaIN): 在合成网络的每个卷积层后,将特征图的均值和方差与风格向量对齐:
# AdaIN 公式
AdaIN(x_i, y) = sigma(y) * (x_i - mu(x_i)) / sigma(x_i) + mu(y)
# 其中 x_i 是特征图的第 i 个通道
# mu(x_i) 和 sigma(x_i) 是 x_i 的均值和标准差
# mu(y) 和 sigma(y) 是由风格向量 y 通过仿射变换得到的
噪声注入: 在每层卷积后添加独立的高斯噪声图,为生成图像增加细微随机变化(如皮肤毛孔、头发纹理、背景细节),增强生成逼真度。
渐进增长: 训练从 4x4 低分辨率开始,逐步向网络中添加更高分辨率的层(8x8 → 16x16 → ... → 1024x1024),使训练过程从全局结构逐步过渡到局部细节,显著提高了训练稳定性和生成质量。
5.5 SRGAN:超分辨率生成对抗网络
SRGAN(Super-Resolution GAN)由 Ledig 等人于 2017 年提出,将 GAN 引入图像超分辨率任务。SRGAN 的核心创新是感知损失(Perceptual Loss),它结合了内容损失(Content Loss)和对抗损失(Adversarial Loss),并使用预训练的 VGG 网络提取高层特征计算内容相似度。
# SRGAN 感知损失
L_perceptual = L_content + 1e-3 * L_GAN
# VGG 内容损失(使用预训练 VGG19 的 relu5_4 层)
L_content = E[||VGG(I_HR) - VGG(G(I_LR))||^2]
# 对抗损失
L_GAN = -E[log(D(G(I_LR)))]
与传统的基于像素级 MSE 损失的 SR 方法相比,SRGAN 生成的超分辨率图像在感知质量上显著优于传统方法,虽然 PSNR 指标可能略低,但人眼观感更加锐利自然。
六、GAN 应用场景
6.1 图像生成与编辑
GAN 最成熟的应用领域是图像生成。StyleGAN 系列能够生成人眼无法分辨的高分辨率人脸图像。在图像编辑方面,GAN 支持语义编辑——通过修改潜码的特定维度来改变图像的特定属性(如修改人脸年龄、表情、发型),而保持其他内容不变。代表性的工具包括 StyleCLIP、InterfaceGAN 等。
6.2 超分辨率重建
SRGAN、ESRGAN、Real-ESRGAN 等模型广泛应用于图像和视频的超分辨率重建。GAN 的对抗训练使超分辨率结果保留更多的高频细节(纹理、边缘),克服了传统 MSE 损失导致的结果过于平滑的问题。实际应用包括老照片修复、监控视频清晰化、医疗影像增强等。
6.3 风格迁移
基于 CycleGAN 的神经风格迁移无需配对训练数据,可将任意图像转换为目标风格。除了艺术风格迁移(照片→油画)外,还被应用于图像色调映射、纹理合成、跨季节图像转换等。相比传统的基于优化的风格迁移方法,GAN 方法推理速度快(单次前向传播),适合实时应用。
6.4 数据增强
在医学影像、工业缺陷检测等标注数据稀缺的场景,GAN 被用于生成逼真的合成数据扩充训练集。通过条件 GAN(如控制病变类型和位置),可以生成特定类别和分布的合成样本,有效提升下游分类或检测模型的泛化性能。研究表明,GAN 增强在样本量极少的场景(如每类不足 100 张)中效果尤为显著。
6.5 异常检测
基于 AnoGAN 和 Efficient GAN 的异常检测方法利用 GAN 学习正常数据的分布,将测试样本映射到潜空间后重建,通过比较原图与重建图的差异来检测异常区域。该方法在工业缺陷检测、医疗影像异常筛查等领域展现出良好的效果,尤其适用于异常样本难以收集的场景。
GAN 的优势
生成质量高,图像清晰锐利
隐空间具有语义可解释性
适用于无监督和半监督学习
支持任意分辨率的多模态生成
GAN 的局限性
训练不稳定,调参困难
模式崩溃问题
缺乏对生成多样性的显式度量
评估指标(FID/IS)不能完美反映质量
易出现梯度消失
七、PyTorch 完整 GAN 实现
以下是一个完整的 DCGAN 训练脚本,包含数据加载、模型定义、训练循环和可视化组件。
# 完整的 DCGAN 训练实现(PyTorch)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
import os
import numpy as np
# ============ 超参数 ============
LATENT_DIM = 100
CHANNELS = 1
IMG_SIZE = 64
BATCH_SIZE = 128
LR = 0.0002
BETA1 = 0.5
BETA2 = 0.999
NUM_EPOCHS = 50
FEATURE_MAPS_G = 64
FEATURE_MAPS_D = 64
SAMPLE_INTERVAL = 500
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ============ 生成器 ============
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
# (latent_dim, 1, 1)
nn.ConvTranspose2d(LATENT_DIM, FEATURE_MAPS_G*8, 4, 1, 0, bias=False),
nn.BatchNorm2d(FEATURE_MAPS_G*8),
nn.ReLU(True),
# (feature_maps_g*8, 4, 4)
nn.ConvTranspose2d(FEATURE_MAPS_G*8, FEATURE_MAPS_G*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(FEATURE_MAPS_G*4),
nn.ReLU(True),
# (feature_maps_g*4, 8, 8)
nn.ConvTranspose2d(FEATURE_MAPS_G*4, FEATURE_MAPS_G*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(FEATURE_MAPS_G*2),
nn.ReLU(True),
# (feature_maps_g*2, 16, 16)
nn.ConvTranspose2d(FEATURE_MAPS_G*2, FEATURE_MAPS_G, 4, 2, 1, bias=False),
nn.BatchNorm2d(FEATURE_MAPS_G),
nn.ReLU(True),
# (feature_maps_g, 32, 32)
nn.ConvTranspose2d(FEATURE_MAPS_G, CHANNELS, 4, 2, 1, bias=False),
nn.Tanh()
# (channels, 64, 64)
)
def forward(self, z):
return self.model(z.view(z.size(0), -1, 1, 1))
# ============ 判别器 ============
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(CHANNELS, FEATURE_MAPS_D, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(FEATURE_MAPS_D, FEATURE_MAPS_D*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(FEATURE_MAPS_D*2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(FEATURE_MAPS_D*2, FEATURE_MAPS_D*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(FEATURE_MAPS_D*4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(FEATURE_MAPS_D*4, FEATURE_MAPS_D*8, 4, 2, 1, bias=False),
nn.BatchNorm2d(FEATURE_MAPS_D*8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(FEATURE_MAPS_D*8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x).view(-1, 1).squeeze(1)
# ============ 权重初始化 ============
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
# ============ 数据加载 ============
transform = transforms.Compose([
transforms.Resize(IMG_SIZE),
transforms.CenterCrop(IMG_SIZE),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # 映射到 [-1, 1]
])
dataset = datasets.MNIST(root='./data', train=True,
transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE,
shuffle=True, num_workers=2)
# ============ 模型初始化 ============
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)
generator.apply(weights_init)
discriminator.apply(weights_init)
# ============ 损失函数与优化器 ============
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=LR,
betas=(BETA1, BETA2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=LR,
betas=(BETA1, BETA2))
# 固定噪声用于训练过程可视化
fixed_noise = torch.randn(64, LATENT_DIM, device=DEVICE)
real_label = 1.0
fake_label = 0.0
# ============ 训练循环 ============
os.makedirs("output", exist_ok=True)
G_losses = []
D_losses = []
for epoch in range(NUM_EPOCHS):
for batch_idx, (real_imgs, _) in enumerate(dataloader):
batch_size = real_imgs.size(0)
real_imgs = real_imgs.to(DEVICE)
# ---------- 训练判别器 ----------
discriminator.zero_grad()
label_real = torch.full((batch_size,), real_label,
dtype=torch.float, device=DEVICE)
label_fake = torch.full((batch_size,), fake_label,
dtype=torch.float, device=DEVICE)
# 真实样本损失
output_real = discriminator(real_imgs)
d_loss_real = criterion(output_real, label_real)
# 假样本损失
noise = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
fake_imgs = generator(noise)
output_fake = discriminator(fake_imgs.detach())
d_loss_fake = criterion(output_fake, label_fake)
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
optimizer_D.step()
# ---------- 训练生成器 ----------
generator.zero_grad()
output_fake = discriminator(fake_imgs)
g_loss = criterion(output_fake, label_real) # G 希望 D 认为假图是真的
g_loss.backward()
optimizer_G.step()
# ---------- 记录 ----------
G_losses.append(g_loss.item())
D_losses.append(d_loss.item())
if batch_idx % SAMPLE_INTERVAL == 0:
print(f"[Epoch {epoch+1}/{NUM_EPOCHS}] "
f"[Batch {batch_idx}/{len(dataloader)}] "
f"[D loss: {d_loss.item():.4f}] "
f"[G loss: {g_loss.item():.4f}]")
# 保存生成的样本图像
with torch.no_grad():
fake = generator(fixed_noise).detach().cpu()
grid = make_grid(fake, nrow=8, normalize=True,
value_range=(-1, 1))
save_image(grid,
f"output/epoch_{epoch+1}_batch_{batch_idx}.png")
# ============ 保存模型 ============
torch.save(generator.state_dict(), "output/generator_final.pth")
torch.save(discriminator.state_dict(), "output/discriminator_final.pth")
print("训练完成!模型已保存至 output/ 目录。")
运行提示: 以上代码在 MNIST 数据集上训练。如需训练 RGB 图像(如 CelebA),将 CHANNELS 改为 3 并更新 Dataset 路径。GPU 训练推荐 batch_size >= 64,使用 torch.cuda.amp 混合精度训练可提速 30-50%。
7.1 GAN 评估指标
GAN 的评估是一个开放研究问题,目前最常用的两个指标为:
FID(Fréchet Inception Distance): 使用预训练 Inception v3 网络提取特征,计算真实分布和生成分布特征向量的 Frechet 距离。FID 越低表示生成质量越好。FID 对模式崩溃较敏感,是目前最主流的评估指标。
IS(Inception Score): 基于 Inception v3 的分类置信度计算。高 IS 要求生成样本类别清晰且多样化。但 IS 对分布内样本有偏(在 ImageNet 上训练的网络对非 ImageNet 类别的评估不可靠)。
# FID 计算(简化版,使用 torchmetrics)
# pip install torchmetrics
from torchmetrics.image.fid import FrechetInceptionDistance
import torchvision.transforms as T
def compute_fid(real_loader, generator, device, num_samples=10000):
fid = FrechetInceptionDistance(feature=2048).to(device)
transform = T.Compose([
T.Resize((299, 299)),
T.Lambda(lambda x: x.expand(3, -1, -1)), # 灰度图转3通道
])
count = 0
for real_imgs, _ in real_loader:
real_imgs = transform(real_imgs).to(device)
fid.update(real_imgs, real=True)
count += real_imgs.size(0)
if count >= num_samples:
break
count = 0
with torch.no_grad():
while count < num_samples:
z = torch.randn(BATCH_SIZE, LATENT_DIM, device=device)
fake = generator(z)
fake = transform(fake)
fid.update(fake, real=False)
count += fake.size(0)
return float(fid.compute())
八、核心要点总结
1. GAN 的基本原理: 生成器 G 与判别器 D 通过 Minimax 博弈进行对抗训练,生成器学习数据分布,判别器区分真假样本,最终达到纳什均衡。原始的 GAN 价值函数为 min_G max_D V(D,G) = E[log D(x)] + E[log(1-D(G(z)))]。
2. 训练挑战与解决方案: GAN 面临模式崩溃(生成缺乏多样性)、梯度消失(D 过于强大)、不收敛(损失震荡)三大主要挑战。对策包括非饱和损失、标签平滑、梯度惩罚、谱归一化、学习率调节等技巧。
3. DCGAN 架构原则: 用步长卷积/转置卷积替代池化、广泛使用 BatchNorm、移除全连接层、G 使用 ReLU/D 使用 LeakyReLU、输出使用 Tanh。这五大原则奠定了图像生成 GAN 的架构基础。
4. 重要变体与改进: CGAN(条件控制生成)、Pix2Pix(有监督图像翻译)、CycleGAN(无监督风格迁移/循环一致性)、StyleGAN(映射网络/AdaIN/渐进增长)、SRGAN(感知损失的超分辨率)。
5. 应用场景: GAN 广泛应用于图像生成(StyleGAN)、超分辨率(ESRGAN)、风格迁移(CycleGAN)、数据增强(合成样本扩充训练集)、异常检测(AnoGAN)等领域。
6. 评估指标: FID(Frechet Inception Distance)是目前最主流的 GAN 评估方法,FID 越低越好。IS(Inception Score)在特定场景也有使用,但存在分布内偏差。此外还有 Precision-and-Recall 等指标可各自评估保真度和多样性。
九、进一步思考
9.1 GAN 与扩散模型的对比
近年来,扩散模型(Diffusion Models,如 DDPM、Stable Diffusion)在图像生成质量上超越了 GAN。扩散模型通过逐渐向数据添加噪声再反向去噪来生成样本,其训练更稳定,模式覆盖更全面。然而,扩散模型的推理速度较慢(需要数十到数百步迭代),而 GAN 只需一次前向传播即可生成样本。在延迟敏感的应用(如实时图像编辑、视频生成)中,GAN 仍有不可替代的优势。
9.2 未来发展方向
GAN 与扩散模型的融合: 结合 GAN 的快速推理能力与扩散模型的高质量输出,如采用扩散过程作为 GAN 的隐空间先验。
3D 生成: 将 GAN 扩展到体素、点云、NeRF 等 3D 表示,实现三维物体的可控生成(如 EG3D、PanoGAN)。
视频生成: 将 GAN 扩展到时间维度,生成连贯的视频序列(如 MoCoGAN、VideoGAN)。
多模态生成: 结合文本、图像、音频等多种模态的条件 GAN,实现跨模态生成与对齐。
可解释性与可控性: 进一步揭示 GAN 隐空间的语义结构,实现细粒度、可解释的图像编辑。
"GAN 可能是过去十年中深度学习领域最具创造性的想法。它让机器不再只是观察世界,而是开始创造世界。" —— Yann LeCun