PyTorch DataLoader与训练循环

深度学习专题 · 数据加载与训练流程完整实现

专题:PyTorch数据加载与训练管道

关键词:深度学习, DataLoader, Dataset, Sampler, transforms, 训练循环, num_workers, collate_fn, PyTorch

一、概述

在PyTorch深度学习项目中,数据加载训练循环是最基础也是最关键的组成部分。无论任务类型是图像分类、目标检测、NLP还是语音识别,整个深度学习流程都遵循相同的高层范式:定义数据集、构建数据加载管道、编写训练循环。本文将系统性地从底层到高层逐一剖析PyTorch的DataLoader机制和训练循环的完整实现,涵盖从Dataset接口设计、Sampler采样策略、数据变换与增强,到完整的训练、验证、保存与加载等全链路内容。

PyTorch的数据加载系统在设计上遵循模块化可组合的原则。Dataset负责数据的读取与标注,Sampler负责决定样本的采样顺序,DataLoader负责将前两者组合起来并提供多进程预取、批量化、内存优化等功能。这种松耦合架构允许开发者灵活地替换任意组件而不影响其他部分,极大地增强了代码的可复用性和可维护性。

Dataset (原始数据)
↓ 经过 transforms 变换
Sampler → DataLoader (batch_size, num_workers, collate_fn)

训练循环 (epoch → batch → forward → loss → backward → optimizer.step)

版本说明:本文基于 PyTorch 2.x 版本,所有代码示例均与 PyTorch 1.8+ 兼容。推荐使用 Python 3.8+ 和 PyTorch 2.0+ 以获得最佳性能体验。

二、Dataset 类详解

Dataset是PyTorch数据加载体系中的基础抽象。它定义了如何从原始存储介质(磁盘文件、数据库、网络等)中读取单个样本及其对应的标签。PyTorch提供了两种Dataset风格:Map-StyleIterable-Style,二者分别适用于不同的数据访问模式。

2.1 Map-Style Dataset

Map-Style Dataset是最常用的数据集类型,它实现了__len____getitem__两个核心方法,行为类似于Python的列表或字典——通过索引进行随机访问。DataLoader在Map-Style数据集上会根据Sampler生成的索引列表调用dataset[idx]来获取样本,这使得数据打乱(shuffle)和分布式采样(distributed sampling)的实现变得非常自然。

以下是一个完整的Map-Style Dataset实现示例,用于加载图像分类数据集:

import torch from torch.utils.data import Dataset from PIL import Image import os import json class ImageClassificationDataset(Dataset): """图像分类数据集(Map-Style 实现)""" def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.samples = [] self.classes = sorted(os.listdir(root_dir)) self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)} for cls in self.classes: cls_dir = os.path.join(root_dir, cls) for fname in os.listdir(cls_dir): self.samples.append((os.path.join(cls_dir, fname), self.class_to_idx[cls])) def __len__(self): return len(self.samples) def __getitem__(self, idx): img_path, label = self.samples[idx] image = Image.open(img_path).convert('RGB') if self.transform: image = self.transform(image) return image, label

核心方法说明:

__len__:返回数据集的总样本数,DataLoader据此计算每个epoch的步数(batches per epoch)。

__getitem__(idx):根据索引idx返回单个样本(通常为(image, label)元组)。DataLoader的内部索引机制会调用此方法,然后由collate_fn将N个单个样本堆叠为batch。

2.2 Iterable-Style Dataset

Iterable-Style Dataset适用于无法预先计算长度数据流式到达的场景,例如从数据库中流式读取、读取超大文件无法全部加载到内存、或在强化学习中动态生成训练数据。与Map-Style不同,Iterable-Style只需实现__iter__方法,返回一个迭代器对象。

from torch.utils.data import IterableDataset class StreamTextDataset(IterableDataset): """流式文本数据集(Iterable-Style 实现)""" def __init__(self, file_path, chunk_size=1024): self.file_path = file_path self.chunk_size = chunk_size def __iter__(self): with open(self.file_path, 'r') as f: while True: chunk = f.read(self.chunk_size) if not chunk: break yield chunk

重要注意事项:Iterable-Style Dataset在使用num_workers > 1时,每个worker都会独立地重新创建迭代器,导致数据重复读取。开发者需要在__iter__中自行处理worker的信息(通过torch.utils.data.get_worker_info())来实现数据切分,确保每个worker读取不同的数据子集,避免数据重复。

三、DataLoader 核心参数详解

DataLoader是连接Dataset与训练循环之间的桥梁。它会采样器(Sampler)决定样本顺序,使用多进程(num_workers)并行加载数据,并通过collate_fn将多个单样本组合为一个batch。深入理解DataLoader的每个参数对编写高性能训练管道至关重要。

3.1 参数速查表

参数 类型 默认值 作用描述
datasetDataset(必填)数据集实例,Map-Style或Iterable-Style
batch_sizeint1每个batch的样本数。设置为1时禁用自动批量化
shuffleboolFalse是否在每个epoch开始时打乱数据(仅对Map-Style生效)
samplerSampler / IterableNone自定义采样器。指定后shuffle参数将被忽略
batch_samplerSamplerNone返回整个batch索引的采样器。指定后batch_size/shuffle将被忽略
num_workersint0子进程数。0表示在主进程中加载,>0启用多进程预取
collate_fncallableNone将单个样本列表组合为batch的函数。默认实现堆叠tensor
pin_memoryboolFalse是否将tensor锁页到CPU内存,加速CPU→GPU传输
drop_lastboolFalse当样本数不能被batch_size整除时是否丢弃最后一个不完整batch
timeoutnumeric0获取一个batch的超时时间(秒),0表示无超时
worker_init_fncallableNone每个worker进程初始化时调用的函数,用于设置seeds等
prefetch_factorint2每个worker预取的batch数(num_workers>0时生效)
persistent_workersboolFalse是否在epoch间保持worker进程存活(避免反复创建销毁开销)

3.2 batch_size

batch_size决定每个迭代步送入模型的样本数量。较大的batch_size可以利用GPU并行计算的优势加速训练,但也需要更多的显存。实际使用中,batch_size的选择往往受限于GPU显存容量。一个经验法则是:从较小的batch_size(如16或32)开始逐渐增加,直至出现OOM错误为止。

需要注意的是,batch_size还会间接影响模型的泛化性能。研究表明,过大的batch_size可能导致模型收敛到尖锐极小值(sharp minima),降低泛化能力。实践中常用的做法是使用学习率预热(warmup)线性缩放法则(linear scaling rule)来应对大batch_size训练。

# 常见的 batch_size 配置示例 from torch.utils.data import DataLoader # 小 batch 训练(适合小显存或在线学习) loader_small = DataLoader(dataset, batch_size=16, shuffle=True) # 中等 batch(大多数任务的默认选择) loader_medium = DataLoader(dataset, batch_size=64, shuffle=True) # 大 batch 训练(需配合 learning rate 调整) loader_large = DataLoader(dataset, batch_size=256, shuffle=True)

3.3 shuffle 与 sampler

shuffle=True会在每个epoch开始时使用RandomSampler重新排列数据顺序。这对于防止模型学习到数据中的顺序偏置至关重要,特别是在训练数据存在某种排序规律(如所有类别按顺序排列)的情况下。

当需要更精细的采样控制时,可以直接传入sampler参数。指定sampler后,shuffle参数将被忽略。常见的自定义采样场景包括:类别不平衡数据集中的过采样/欠采样、分布式训练中的分片采样、特定比例的验证集采样等。

# 方法一:使用 shuffle 参数(简便方式) loader = DataLoader(dataset, batch_size=32, shuffle=True) # 方法二:使用 sampler 参数(更灵活,等价于上方) from torch.utils.data import RandomSampler sampler = RandomSampler(dataset, replacement=False) loader = DataLoader(dataset, batch_size=32, sampler=sampler)

3.4 num_workers

num_workers控制用于数据加载的子进程数量。当设置为大于0时,DataLoader会创建多个子进程异步地预取数据,主进程在进行模型前向/反向计算的同时,子进程在后台准备下一批数据,实现计算与IO的流水线并行

num_workers的最佳取值取决于具体场景:

经验法则:num_workers通常设置为CPU核心数或核心数的一半。从4开始,监控训练时的GPU利用率,如果GPU利用率经常低于80-90%,且CPU没有达到瓶颈,可以适当增加num_workers。如果GPU利用率持续偏低,还可以配合prefetch_factor参数增加预取数量。

3.5 collate_fn

collate_fn是DataLoader中最灵活的扩展点之一。其作用是将Dataset返回的N个单一样本(列表形式)组合成一个batch(通常为一个tensor或元组)。默认的collate_fn会将PIL Image或numpy数组自动转换为tensor,并在第0维(stack)进行堆叠。但在很多场景下,我们需要自定义collate_fn来处理不等长序列、不同尺寸的图像、或返回非tensor类型的样本。

# 自定义 collate_fn:处理变长序列(如NLP中的padding) import torch from torch.nn.utils.rnn import pad_sequence def collate_variable_length(batch): """ 对变长序列进行 padding 并返回 batch。 batch: [(tensor_seq1, label1), (tensor_seq2, label2), ...] """ sequences, labels = zip(*batch) # pad_sequence 会在序列末尾填充 0,使同一 batch 内长度一致 padded_seqs = pad_sequence(sequences, batch_first=True, padding_value=0) # 记录每条序列的实际长度(用于后续 mask) lengths = torch.tensor([len(seq) for seq in sequences]) labels = torch.tensor(labels) return padded_seqs, lengths, labels loader = DataLoader( dataset, batch_size=32, collate_fn=collate_variable_length, shuffle=True )
# 自定义 collate_fn:目标检测中的图像与边界框 def collate_detection(batch): """ 目标检测任务中,每张图像的边界框数量不同。 不能简单 stack,需保留列表格式。 """ images, targets = zip(*batch) # images 可以 stack 为 [B, C, H, W](假设尺寸相同) images = torch.stack(images, dim=0) # targets 保持为列表,每项为 [N_i, 5] 的 tensor (x1,y1,x2,y2,class) return images, list(targets)

3.6 pin_memory

pin_memory=True将数据分配在CPU的锁页内存(pinned memory)中。锁页内存不会被操作系统换出到磁盘,GPU可以直接通过DMA(Direct Memory Access)从锁页内存中读取数据,无需经过CPU临时缓冲区的中转,从而大幅提升CPU到GPU的数据传输速度。当使用GPU训练时,强烈建议开启此选项。

# GPU训练推荐配置 loader = DataLoader( dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True, # 加速 CPU→GPU 传输 persistent_workers=True # 避免每个 epoch 结束后重建 worker )

3.7 worker_init_fn

当使用多进程数据加载(num_workers > 0)时,每个worker进程会继承主进程的随机数状态,导致所有worker生成相同的随机变换结果(如相同的随机裁剪位置)。worker_init_fn允许在每个worker进程启动时执行自定义初始化代码,通常用于设置独立的随机数种子,确保数据增强的随机性在各个worker之间是互异的。

import random import numpy as np import torch def worker_init_fn(worker_id): """为每个 worker 设置独立的随机种子""" seed = torch.initial_seed() % 2**32 np.random.seed(seed + worker_id) random.seed(seed + worker_id) loader = DataLoader( dataset, batch_size=32, shuffle=True, num_workers=4, worker_init_fn=worker_init_fn )

四、Sampler 详解

Sampler定义了数据集的采样策略,即决定在每个epoch中按照什么顺序访问数据样本。PyTorch提供了丰富的Sampler类以应对不同的训练需求。理解每种Sampler的适用场景是构建高效数据管道的必备知识。

4.1 SequentialSampler

按数据集的原始顺序依次采样,等价于DataLoader(shuffle=False)的行为。在验证和测试阶段通常使用此采样器,因为此时不需要打乱数据,且希望保持结果的可复现性。

from torch.utils.data import SequentialSampler sampler = SequentialSampler(dataset) # 等价于 DataLoader(dataset, shuffle=False) loader = DataLoader(dataset, batch_size=32, sampler=sampler)

4.2 RandomSampler

随机采样所有样本,支持有放回(replacement)和无放回两种模式。等价于DataLoader(shuffle=True)的行为。在训练阶段的标准选择。

from torch.utils.data import RandomSampler # 无放回随机采样(默认) sampler = RandomSampler(dataset, replacement=False) # 有放回随机采样(每个 epoch 的样本可能重复) sampler_with_replacement = RandomSampler( dataset, replacement=True, num_samples=1000 )

4.3 SubsetRandomSampler

从给定的索引子集中进行随机采样。在划分训练集和验证集时非常实用,无需创建两个独立的Dataset副本,只需维护不同的索引列表即可。

from torch.utils.data import SubsetRandomSampler # 手动划分训练/验证索引 dataset_size = len(dataset) indices = list(range(dataset_size)) split = int(0.8 * dataset_size) np.random.shuffle(indices) train_indices, val_indices = indices[:split], indices[split:] # 使用 SubsetRandomSampler 创建各自的 DataLoader train_sampler = SubsetRandomSampler(train_indices) val_sampler = SubsetRandomSampler(val_indices) train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler) val_loader = DataLoader(dataset, batch_size=32, sampler=val_sampler)

4.4 WeightedRandomSampler

加权随机采样,用于处理类别不平衡问题。每个样本被抽中的概率与其权重成正比,通过调整采样权重使模型在每个batch中看到更均衡的类别分布,从而避免模型偏向训练数据中的多数类。

from torch.utils.data import WeightedRandomSampler # 假设 labels 为所有样本的类别标签列表 labels = np.array([ds[1] for ds in dataset]) # 计算每个类别的样本数量 class_counts = np.bincount(labels) # 权重与样本数成反比:少数类获得更大的采样权重 class_weights = 1.0 / class_counts sample_weights = class_weights[labels] # 构建加权采样器,每个 epoch 采集 10000 个样本 sampler = WeightedRandomSampler( weights=sample_weights, num_samples=10000, replacement=True # 有放回,确保少数类能被多次采样 ) loader = DataLoader(dataset, batch_size=32, sampler=sampler)

4.5 DistributedSampler

DistributedSampler是分布式训练中的关键组件。它将数据集划分为不相交的子集,每个GPU进程只处理自己的那部分数据,确保所有GPU看到的数据没有重叠。使用DistributedSampler时需要注意在每个epoch开始时调用sampler.set_epoch(epoch)来打乱数据分区,否则所有epoch数据顺序一致。

from torch.utils.data.distributed import DistributedSampler # 在多 GPU 分布式训练中使用 sampler = DistributedSampler( dataset, num_replicas=world_size, # GPU 总数量 rank=local_rank # 当前 GPU 的编号 ) loader = DataLoader(dataset, batch_size=32, sampler=sampler) # 每个 epoch 开始时必须设置: for epoch in range(num_epochs): sampler.set_epoch(epoch) # 重要:每个 epoch 需重新打乱 for batch in loader: # 训练代码... pass
Sampler 适用场景 关键参数
SequentialSampler验证/测试,不需要打乱数据shuffle=False
RandomSampler标准训练,需要打乱数据replacement, num_samples
SubsetRandomSampler数据集划分(训练/验证子集)indices 列表
WeightedRandomSampler类别不平衡矫正weights, num_samples, replacement
DistributedSampler多GPU分布式训练num_replicas, rank, shuffle

五、数据变换与增强

数据变换(transforms)是数据预处理和增强的核心工具。torchvision.transforms提供了丰富的图像变换函数,包括标准化、尺寸调整、随机裁剪、颜色抖动、旋转翻转等。合理的数据增强策略不仅能增加训练数据的多样性,还能显著提升模型的泛化能力鲁棒性

5.1 torchvision.transforms 与 Compose

transforms.Compose将多个变换操作串联成一个管道,Dataset在返回样本时自动依次执行。这种声明式组合的方式使代码清晰简洁,便于实验不同增强策略的排列组合。

from torchvision import transforms # 训练集:包含数据增强 train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(p=0.5), transforms.ColorJitter(brightness=0.2, contrast=0.2), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) # 验证集:仅做必要的尺寸调整和标准化(不包含随机增强) val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ])

5.2 ToTensor 与 Normalize

transforms.ToTensor()将PIL Image或numpy数组转换为PyTorch tensor,并自动执行两个关键操作:将H×W×C的通道顺序转换为C×H×W(PyTorch卷积层要求的格式),以及将像素值从0~255缩放到0.0~1.0。

transforms.Normalize(mean, std)执行标准化操作:output = (input - mean) / std。通常使用的mean和std值来源于ImageNet数据集的统计值。标准化使各通道数据具有零均值和单位方差,有助于加速模型收敛。

# 标准化的数学含义 # output[channel] = (input[channel] - mean[channel]) / std[channel] # ImageNet 标准化参数 normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], # RGB 三通道均值 std=[0.229, 0.224, 0.225] # RGB 三通道标准差 )

5.3 数据增强技术详解

RandomCrop

从图像中随机裁剪一个子区域。RandomResizedCrop是RandomCrop的升级版,它会先随机缩放图像再裁剪,从而模拟不同尺度的目标。这是提升模型尺度不变性的有效手段。

ColorJitter

随机调整图像的亮度(brightness)对比度(contrast)饱和度(saturation)色调(hue)。增强模型对光照条件变化的鲁棒性。参数取值范围通常是0~1之间的浮点数,表示变化的幅度。

# ColorJitter 参数详解 color_jitter = transforms.ColorJitter( brightness=0.2, # 亮度变化范围: [1-0.2, 1+0.2] = [0.8, 1.2] contrast=0.2, # 对比度变化范围: [0.8, 1.2] saturation=0.2, # 饱和度变化范围: [0.8, 1.2] hue=0.1 # 色调变化范围: [-0.1, 0.1](不超过0.5) )

RandomHorizontalFlip

以给定概率(通常为0.5)对图像进行水平翻转。这是图像分类中最基础、最常用的增强方法之一,适用于大多数非对称性不强的场景(如自然图像、医学影像等)。但不适用于文本OCR、人脸识别等对方向敏感的任务。

# 常见的数据增强组合 train_transform = transforms.Compose([ # 1. 几何变换 transforms.RandomResizedCrop(224, scale=(0.8, 1.0)), transforms.RandomHorizontalFlip(), transforms.RandomRotation(degrees=15), # 2. 颜色变换 transforms.ColorJitter(brightness=0.1, contrast=0.1), # 3. 高级增强(可选) # transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), # transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)), # 4. 张量变换 transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])

增强选择原则:

  • 保留语义不变性:增强后的图像不应改变其语义类别。垂直翻转会使"汽车"变"倒置汽车",通常不推荐。
  • 强度适中:过强的增强会使模型难以学习(如过度的颜色变化使猫看起来像狗)。
  • 下游任务特性:检测任务可使用更丰富的增强(如Mosaic、MixUp),分类任务则需谨慎。
  • 验证集不做增强:验证/测试时只进行必要的Resize、CenterCrop和Normalize。

六、训练循环完整实现

训练循环是整个深度学习流程的核心引擎。一个标准的PyTorch训练循环包含以下关键步骤:数据迭代、前向传播、损失计算、反向传播、参数更新。下面将从最简单的训练循环起步,逐步增加功能和最佳实践,最终形成一个完备的训练管道。

6.1 基础训练循环

# 基础三件套:模型、损失函数、优化器 model = MyModel() criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) num_epochs = 10 for epoch in range(num_epochs): model.train() # 设置为训练模式(启用 Dropout、BatchNorm 等) running_loss = 0.0 for batch_idx, (inputs, labels) in enumerate(train_loader): # 将数据移动到 GPU inputs, labels = inputs.to(device), labels.to(device) # 1. 梯度清零(防止累加) optimizer.zero_grad() # 2. 前向传播 outputs = model(inputs) loss = criterion(outputs, labels) # 3. 反向传播 loss.backward() # 4. 参数更新 optimizer.step() # 统计 loss running_loss += loss.item() if batch_idx % 100 == 0: print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}], Loss: {loss.item():.4f}') epoch_loss = running_loss / len(train_loader) print(f'Epoch [{epoch+1}/{num_epochs}] Average Loss: {epoch_loss:.4f}')

6.2 梯度裁剪

梯度裁剪(gradient clipping)用于防止梯度爆炸问题,特别是在训练RNN、LSTM或Transformer等循环/自注意力架构时尤为关键。梯度裁剪通过设置一个最大梯度范数阈值,在反向传播后对梯度进行缩放,使梯度的L2范数不超过该阈值。

# 梯度裁剪的两种方式 # 方式一:按范数裁剪(最常用) torch.nn.utils.clip_grad_norm_( model.parameters(), max_norm=1.0, # 最大梯度范数 norm_type=2.0 # L2 范数 ) # 方式二:按值裁剪(简单粗暴) torch.nn.utils.clip_grad_value_( model.parameters(), clip_value=0.5 # 梯度值限制在 [-0.5, 0.5] 范围内 ) # 在实际训练循环中的位置 loss.backward() # ← 梯度裁剪放在 backward() 之后,optimizer.step() 之前 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step()

关键顺序提醒:梯度裁剪必须放在loss.backward()之后、optimizer.step()之前执行。因为backward()计算了梯度,clip操作修改梯度值,而step()使用修改后的梯度更新参数。将clip放在forward之前或step之后都是错误的。

6.3 学习率调度

学习率调度(learning rate scheduling)是训练过程中的重要调参手段。合理的学习率衰减策略可以使模型在训练后期更精细地收敛到最优解。PyTorch提供了多种调度器。

from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR # 方案一:余弦退火(适合 CV 任务) scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6) # 方案二:验证 loss 不下降时衰减(适合 NLP 任务) scheduler = ReduceLROnPlateau( optimizer, mode='min', factor=0.5, patience=5 ) # 方案三:固定步长衰减 scheduler = StepLR(optimizer, step_size=30, gamma=0.1) # 训练循环中使用(以 ReduceLROnPlateau 为例) for epoch in range(num_epochs): train_one_epoch() val_loss = validate() # ReduceLROnPlateau 需要传入 val_loss scheduler.step(val_loss) # 其他调度器使用 scheduler.step() 无参数

七、验证循环

验证循环用于评估模型在未见数据上的表现。与训练循环相比,验证循环有以下几个关键区别:不需要计算梯度(大幅减少显存和计算量)、不需要反向传播和参数更新、使用model.eval()切换到评估模式(禁用Dropout、固定BatchNorm统计量)、使用torch.no_grad()上下文管理器禁用梯度计算。

def validate(model, val_loader, criterion, device): """完整的验证循环实现""" model.eval() # 切换到评估模式 val_loss = 0.0 correct = 0 total = 0 # torch.no_grad() 禁用梯度计算,节省显存和加速 with torch.no_grad(): for inputs, labels in val_loader: inputs, labels = inputs.to(device), labels.to(device) # 前向传播(无梯度跟踪) outputs = model(inputs) loss = criterion(outputs, labels) # 统计 loss val_loss += loss.item() # 计算准确率 _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item() avg_loss = val_loss / len(val_loader) accuracy = 100.0 * correct / total print(f'Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%') return avg_loss, accuracy
# 训练循环中整合验证 best_acc = 0.0 for epoch in range(num_epochs): # 训练阶段 train_one_epoch(model, train_loader, criterion, optimizer, device) # 验证阶段 val_loss, val_acc = validate(model, val_loader, criterion, device) # 保存最佳模型 if val_acc > best_acc: best_acc = val_acc torch.save(model.state_dict(), 'best_model.pth') print(f'✓ 保存最佳模型,验证准确率: {val_acc:.2f}%') # 学习率调度 scheduler.step(val_loss)

训练模式与评估模式的区别:

model.train():启用Dropout(随机丢弃神经元)、启用BatchNorm的逐batch统计量计算、启用数据增强相关的随机操作。

model.eval():禁用Dropout(使用全部神经元)、使用BatchNorm的累积运行统计量(running mean/var)、不进行数据增强。如果在验证时忘记切换模式,模型性能可能会下降5-20%。

7.1 torch.no_grad 详解

torch.no_grad()是一个上下文管理器,在其作用域内的所有计算都不会被自动求导系统追踪。这意味着:

八、模型保存与加载

模型持久化是训练流程中的重要环节。PyTorch提供了灵活且高效的模型保存与加载机制。理解state_dict的概念以及不同保存策略的适用场景,对于构建生产级训练管道至关重要。

8.1 保存与加载的最佳实践

# ========== 保存最佳实践 ========== # 方案一:仅保存模型权重(推荐用于推理部署) torch.save(model.state_dict(), 'model_weights.pth') # 方案二:保存完整 checkpoint(推荐用于断点续训) checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, 'best_acc': best_acc, 'scheduler_state_dict': scheduler.state_dict(), # 恢复学习率状态 } torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth') # ========== 加载最佳实践 ========== # 加载权重(需先创建模型实例) model = MyModel() model.load_state_dict(torch.load('model_weights.pth', map_location=device)) # 加载 checkpoint 恢复训练 checkpoint = torch.load('checkpoint_epoch_10.pth', map_location=device) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_epoch = checkpoint['epoch'] + 1 best_acc = checkpoint['best_acc'] # 如果需要恢复学习率调度器状态 if 'scheduler_state_dict' in checkpoint: scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

8.2 选择保存策略

场景 推荐方式 理由
模型部署/推理仅保存 state_dict文件小、加载快、跨平台兼容性好
断点续训保存完整 checkpoint可恢复训练状态、学习率调度器状态等
实验对比保存 checkpoint + 配置文件记录超参数、数据路径等元信息,便于复现
分布式训练保存 model.module.state_dict()DDP包装后的模型需通过.module访问原始state_dict

注意事项:使用map_location=device参数可以在加载时自动将模型权重映射到指定设备(CPU或GPU),避免加载时的设备不匹配问题。例如在CPU机器上加载GPU训练的模型时,必须指定map_location='cpu'

九、tqdm进度条

在长时间的训练过程中,进度条不仅能提供直观的进度反馈,更能实时展示loss、学习率、准确率等关键指标的变化。tqdm是Python生态中最流行的进度条库,与PyTorch训练循环的集成非常简洁。

9.1 tqdm 基础用法

from tqdm import tqdm # 基础用法:包裹迭代器 for i in tqdm(range(100), desc='Processing'): time.sleep(0.01) # 训练循环中集成 for epoch in range(num_epochs): model.train() pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}') for batch_idx, (inputs, labels) in enumerate(pbar): inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 实时更新进度条显示的指标 pbar.set_postfix({ 'loss': f'{loss.item():.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.2e}' })

9.2 完整训练管道的最终整合

下面将前面所有内容整合成一个完整的训练管道实现

import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset from torch.optim.lr_scheduler import CosineAnnealingLR from tqdm import tqdm import os class Trainer: """完整训练器封装""" def __init__(self, model, train_loader, val_loader, criterion, optimizer, scheduler, device, max_grad_norm=None): self.model = model self.train_loader = train_loader self.val_loader = val_loader self.criterion = criterion self.optimizer = optimizer self.scheduler = scheduler self.device = device self.max_grad_norm = max_grad_norm self.best_acc = 0.0 self.history = {'train_loss': [], 'val_loss': [], 'val_acc': []} def train_epoch(self): self.model.train() total_loss = 0.0 pbar = tqdm(self.train_loader, desc='Train') for inputs, labels in pbar: inputs, labels = inputs.to(self.device), labels.to(self.device) self.optimizer.zero_grad() outputs = self.model(inputs) loss = self.criterion(outputs, labels) loss.backward() if self.max_grad_norm is not None: nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) self.optimizer.step() total_loss += loss.item() pbar.set_postfix({'loss': f'{loss.item():.4f}'}) return total_loss / len(self.train_loader) def validate(self): self.model.eval() total_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for inputs, labels in tqdm(self.val_loader, desc='Val'): inputs, labels = inputs.to(self.device), labels.to(self.device) outputs = self.model(inputs) loss = self.criterion(outputs, labels) total_loss += loss.item() _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item() avg_loss = total_loss / len(self.val_loader) accuracy = 100.0 * correct / total return avg_loss, accuracy def fit(self, num_epochs, save_dir='checkpoints'): os.makedirs(save_dir, exist_ok=True) for epoch in range(num_epochs): print(f'\n========== Epoch {epoch+1}/{num_epochs} ==========') train_loss = self.train_epoch() val_loss, val_acc = self.validate() self.history['train_loss'].append(train_loss) self.history['val_loss'].append(val_loss) self.history['val_acc'].append(val_acc) print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%') # 学习率调度 if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.scheduler.step(val_loss) else: self.scheduler.step() # 保存最佳模型 if val_acc > self.best_acc: self.best_acc = val_acc torch.save({ 'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'best_acc': self.best_acc, }, os.path.join(save_dir, 'best_model.pth')) print(f'✓ 新最佳模型保存,准确率: {val_acc:.2f}%')
# 使用示例 model = MyModel().to(device) train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4) val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=4) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4) scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6) trainer = Trainer( model=model, train_loader=train_loader, val_loader=val_loader, criterion=criterion, optimizer=optimizer, scheduler=scheduler, device=device, max_grad_norm=1.0 ) trainer.fit(num_epochs=50) # 训练完成后可视化 # plt.plot(trainer.history['train_loss'], label='Train Loss') # plt.plot(trainer.history['val_loss'], label='Val Loss') # plt.legend(); plt.show()

最佳实践总结:

  • 使用AdamW替代Adam,因为前者实现了正确的权重衰减解耦,泛化性能更好。
  • 启用pin_memory=True和合适的num_workers来最大化GPU利用率。
  • 对RNN/Transformer架构务必使用梯度裁剪(max_grad_norm=1.0)。
  • 验证集和训练集使用不同的数据变换(验证集不做随机增强)。
  • 使用torch.no_grad()和model.eval()确保验证结果准确可靠。
  • 保存完整checkpoint实现断点续训,仅保存state_dict用于模型部署。

十、核心要点总结

  • Dataset:Map-Style(实现__len__和__getitem__)适用于随机访问场景;Iterable-Style(实现__iter__)适用于流式数据场景。二者不可混用,DataLoader会根据Dataset类型自动选择处理方式。
  • DataLoader参数:batch_size决定batch大小;shuffle/sampler控制采样顺序;num_workers控制多进程预取并行度;collate_fn是自定义batch组装逻辑的扩展点;pin_memory加速GPU数据传输。
  • Sampler:SequentialSampler适合验证场景;RandomSampler适合训练打乱;SubsetRandomSampler适合数据集划分;WeightedRandomSampler处理类别不平衡;DistributedSampler用于分布式训练。
  • 数据增强:Compose串联多个变换;ToTensor将PIL转换为C×H×W的tensor并缩放到[0,1];Normalize执行减均值除方差标准化;RandomCrop/ColorJitter/RandomHorizontalFlip是三大经典增强手段。
  • 训练循环:zero_grad→forward→loss→backward→(clip_grad)→optimizer.step的固定顺序不可颠倒。验证循环需使用model.eval()和torch.no_grad()。
  • 模型持久化:state_dict仅保存权重(推荐部署使用);checkpoint保存完整训练状态(推荐断点续训)。使用map_location参数处理设备不匹配问题。
  • tqdm:通过set_postfix实时显示训练指标,与tqdm包裹的迭代器配合实现优雅的进度展示。建议在每个epoch的循环中使用。