二、回归损失函数
回归任务的目标是预测连续值,如房价预测、温度预测等。回归损失函数衡量预测值与真实值之间的数值差距。
2.1 均方误差(MSE / L2 Loss)
均方误差(Mean Squared Error)是最常用的回归损失函数,计算预测值与真实值之差的平方的平均值。
MSE = (1/n) ∑i=1n (yi - ŷi)2
MSE 对较大误差施加平方惩罚,因此对离群值非常敏感。其梯度与误差成正比,误差越大梯度越大,有助于在初始阶段快速收敛。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
mse_loss = nn.MSELoss()
y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])
loss = mse_loss(y_pred, y_true)
print(f"MSE (nn.MSELoss): {loss.item():.4f}")
loss_f = F.mse_loss(y_pred, y_true)
print(f"MSE (F.mse_loss): {loss_f.item():.4f}")
def mse_manual(y_true, y_pred):
return torch.mean((y_true - y_pred) ** 2)
loss_m = mse_manual(y_true, y_pred)
print(f"MSE (manual): {loss_m.item():.4f}")
MSE 优缺点分析
- 优势: 处处可导,梯度计算简单;凸函数有全局最优解;对大误差惩罚大,收敛快
- 劣势: 对离群值极度敏感,单个离群点可能主导损失;误差较大时梯度爆炸风险
- 适用场景: 误差服从高斯分布、离群值较少、需要快速收敛的回归任务
2.2 平均绝对误差(MAE / L1 Loss)
平均绝对误差(Mean Absolute Error)计算预测值与真实值之差的绝对值的平均值。
MAE = (1/n) ∑i=1n |yi - ŷi|
MAE 对所有误差施加线性惩罚,对离群值更鲁棒。但其在误差为零处不可导,且对于小误差的梯度恒定,可能收敛较慢。
y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])
mae_loss = nn.L1Loss()
loss_l1 = mae_loss(y_pred, y_true)
print(f"MAE (nn.L1Loss): {loss_l1.item():.4f}")
def mae_manual(y_true, y_pred):
return torch.mean(torch.abs(y_true - y_pred))
mse_val = torch.mean((y_true - y_pred) ** 2)
mae_val = torch.mean(torch.abs(y_true - y_pred))
print(f"MSE={mse_val:.4f}, MAE={mae_val:.4f}")
y_true_outlier = torch.tensor([3.0, -0.5, 2.0, 7.0, 100.0])
y_pred_outlier = torch.tensor([2.5, 0.0, 2.0, 8.0, 5.0])
print(f"含离群值 - MSE={F.mse_loss(y_pred_outlier, y_true_outlier):.2f}")
print(f"含离群值 - MAE={F.l1_loss(y_pred_outlier, y_true_outlier):.2f}")
2.3 Huber Loss
Huber Loss 结合了 MSE 和 MAE 的优点,通过一个阈值 δ 来切换两种损失的特性。当误差小于 δ 时使用 MSE(平滑),误差大于 δ 时使用 MAE(鲁棒)。
Lδ(y, ŷ) =
½(y - ŷ)2, 当 |y - ŷ| ≤ δ
δ · |y - ŷ| - ½δ2, 当 |y - ŷ| > δ
class HuberLoss(nn.Module):
def __init__(self, delta=1.0):
super().__init__()
self.delta = delta
def forward(self, y_pred, y_true):
error = y_true - y_pred
abs_error = torch.abs(error)
quadratic = torch.clamp(abs_error, max=self.delta)
linear = abs_error - quadratic
return torch.mean(
0.5 * quadratic ** 2 + self.delta * linear
)
y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])
for delta in [0.5, 1.0, 2.0]:
loss_fn = HuberLoss(delta=delta)
loss_val = loss_fn(y_pred, y_true)
print(f"Huber(delta={delta}): {loss_val.item():.4f}")
huber_loss = nn.HuberLoss(delta=1.0)
loss_h = huber_loss(y_pred, y_true)
print(f"内置 Huber: {loss_h.item():.4f}")
Huber 的最佳实践:δ 是超参数,通常设为 1.0。若离群值较多,可增大 δ;若希望更接近 MSE 行为,可减小 δ。实际应用中 δ 常通过交叉验证选择。
2.4 Log-Cosh Loss
Log-Cosh Loss 是另一个平滑的回归损失,计算方式为 log(cosh(y - ŷ))。它具备 Huber Loss 的优点,且处处二阶可导,优化更稳定。
L(y, ŷ) = log(cosh(y - ŷ))
def log_cosh_loss(y_pred, y_true):
error = y_pred - y_true
return torch.mean(torch.log(torch.cosh(error)))
y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])
loss_lc = log_cosh_loss(y_pred, y_true)
print(f"Log-Cosh Loss: {loss_lc.item():.4f}")
def log_cosh_stable(y_pred, y_true):
error = y_pred - y_true
return torch.mean(error + F.softplus(-2.0 * error)
- torch.log(torch.tensor(2.0)))
三、二分类损失函数
二分类任务的目标是将样本分为两个类别(正类/负类),如垃圾邮件检测、疾病筛查等。
3.1 二元交叉熵(BCE Loss)
二元交叉熵(Binary Cross-Entropy)是二分类任务的标准损失函数,基于信息论中的交叉熵概念。
BCE = -(1/n) ∑i=1n [yi · log(ŷi) + (1 - yi) · log(1 - ŷi)]
bce_loss = nn.BCEWithLogitsLoss()
logits = torch.randn(4, requires_grad=True)
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])
loss = bce_loss(logits, targets)
print(f"BCEWithLogitsLoss: {loss.item():.4f}")
def bce_manual(logits, targets):
probs = torch.sigmoid(logits)
eps = 1e-12
probs = torch.clamp(probs, eps, 1.0 - eps)
return -torch.mean(
targets * torch.log(probs) +
(1.0 - targets) * torch.log(1.0 - probs)
)
pos_weight = torch.tensor([2.0])
weighted_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss_w = weighted_bce(logits, targets)
print(f"加权 BCE: {loss_w.item():.4f}")
BCE 的关键理解:
- 当真实标签 y=1 时,损失为 -log(ŷ),预测概率越接近 1 损失越小
- 当真实标签 y=0 时,损失为 -log(1-ŷ),预测概率越接近 0 损失越小
- 对错误预测的惩罚是对数级的 —— 非常自信的错误预测会受到极大惩罚
- 始终使用 BCEWithLogitsLoss 而非手动组合 Sigmoid + BCELoss,以避免数值不稳定
3.2 Hinge Loss(合页损失)
Hinge Loss 是 SVM(支持向量机)使用的损失函数,要求正确类别的分数至少比错误类别高出一个"边界"(margin)。
L(y, ŷ) = max(0, 1 - y · ŷ)
def hinge_loss(y_pred, y_true):
"""
y_true 应为 {-1, +1}
"""
return torch.mean(torch.clamp(1 - y_true * y_pred, min=0))
y_pred = torch.tensor([0.8, -0.2, 1.5, -0.7])
y_true = torch.tensor([1.0, -1.0, 1.0, -1.0])
loss_h = hinge_loss(y_pred, y_true)
print(f"Hinge Loss: {loss_h.item():.4f}")
def squared_hinge_loss(y_pred, y_true):
return torch.mean(torch.clamp(1 - y_true * y_pred, min=0) ** 2)
loss_pt = nn.HingeEmbeddingLoss(margin=1.0)
Hinge vs BCE:Hinge Loss 不仅要求分类正确,还要求正确类别的分数高于错误类别至少一个 margin。这使得 Hinge Loss 倾向于学习出"更大间隔"的决策边界,从而提高泛化能力。但 Hinge Loss 在正确分类且超过 margin 时梯度为零,可能导致"死神经元"问题。
3.3 指数损失(Exponential Loss)
指数损失是 AdaBoost 算法使用的损失函数,对错误分类施加指数级惩罚。
L(y, ŷ) = exp(-y · ŷ)
def exponential_loss(y_pred, y_true):
"""
y_true 应为 {-1, +1}
"""
return torch.mean(torch.exp(-y_true * y_pred))
y_pred = torch.tensor([2.0, -1.0, 0.5, -2.0])
y_true = torch.tensor([1.0, -1.0, 1.0, -1.0])
loss_e = exponential_loss(y_pred, y_true)
print(f"指数损失: {loss_e.item():.4f}")
注意事项:指数损失对离群值和错误标签极度敏感,一个错误标注的数据点可能导致模型严重偏离。在实际应用中,如果数据质量不佳,建议改用 Hinge Loss 或 BCE。
四、多分类损失函数
多分类任务需要将样本分到多个类别之一,如图像识别(猫/狗/鸟)、手写数字识别(0-9)等。
4.1 交叉熵损失(Cross-Entropy Loss)
交叉熵损失是多分类任务的标准损失函数,结合了 Softmax 激活函数和负对数似然。
CE = -∑c=1C yc · log(ŷc)
其中 ŷc = exp(zc) / ∑j exp(zj)
ce_loss = nn.CrossEntropyLoss()
logits = torch.randn(4, 5)
targets = torch.tensor([0, 2, 1, 3])
loss = ce_loss(logits, targets)
print(f"CrossEntropyLoss: {loss.item():.4f}")
def cross_entropy_manual(logits, targets):
exp_logits = torch.exp(logits - torch.max(logits, dim=1,
keepdim=True).values)
probs = exp_logits / torch.sum(exp_logits, dim=1, keepdim=True)
batch_size = logits.size(0)
return -torch.mean(torch.log(
probs[torch.arange(batch_size), targets] + 1e-10
))
loss_m = cross_entropy_manual(logits, targets)
print(f"手动实现: {loss_m.item():.4f}")
class_weights = torch.tensor([0.5, 1.0, 2.0, 1.0, 0.8])
weighted_ce = nn.CrossEntropyLoss(weight=class_weights)
loss_w = weighted_ce(logits, targets)
print(f"加权交叉熵: {loss_w.item():.4f}")
为什么交叉熵比 MSE 更适合分类?
- 梯度饱和:MSE + Softmax 在预测完全错误时梯度反而很小,导致学习缓慢;交叉熵在预测错误时梯度大,学习快
- 概率解释:最小化交叉熵等价于最大化似然估计,具有坚实的统计学基础
- 信息论视角:交叉熵衡量两个分布之间的差异,当预测分布完全匹配真实分布时为零
4.2 KL 散度(KL Divergence)
KL 散度(Kullback-Leibler Divergence)衡量两个概率分布 P 和 Q 之间的差异,常用于知识蒸馏和变分自编码器(VAE)。
DKL(P || Q) = ∑i P(i) · log(P(i) / Q(i))
kl_loss = nn.KLDivLoss(reduction='batchmean')
input_log_probs = F.log_softmax(torch.randn(4, 5), dim=1)
target_probs = F.softmax(torch.randn(4, 5), dim=1)
loss_k = kl_loss(input_log_probs, target_probs)
print(f"KLDivLoss: {loss_k.item():.4f}")
def kl_divergence(p, q, eps=1e-10):
"""
p: 真实分布,q: 近似分布
返回 D_KL(P || Q)
"""
p = torch.clamp(p, eps, 1.0)
q = torch.clamp(q, eps, 1.0)
return torch.sum(p * torch.log(p / q))
def distillation_loss(student_logits, teacher_logits, temperature=4.0):
"""
知识蒸馏损失:
学生模型通过 KL 散度模仿教师模型的软标签
"""
soft_student = F.log_softmax(student_logits / temperature, dim=1)
soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
kd_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
return kd_loss * (temperature ** 2)
4.3 Categorical Hinge Loss
Categorical Hinge Loss 是将 Hinge Loss 扩展到多分类的变体,核心思想是让正确类别的分数高出所有错误类别一个 margin。
cat_hinge = nn.MultiMarginLoss(margin=1.0)
logits = torch.randn(4, 5)
targets = torch.tensor([0, 2, 1, 3])
loss_ch = cat_hinge(logits, targets)
print(f"Categorical Hinge: {loss_ch.item():.4f}")
def categorical_hinge_manual(logits, targets, margin=1.0):
batch_size = logits.size(0)
correct_scores = logits[torch.arange(batch_size), targets].unsqueeze(1)
margins = logits - correct_scores + margin
margins[torch.arange(batch_size), targets] = 0
return torch.mean(torch.clamp(margins, min=0))
loss_ch_m = categorical_hinge_manual(logits, targets)
print(f"手动 Categorical Hinge: {loss_ch_m.item():.4f}")
六、自定义损失函数
在实际项目中,标准损失函数往往无法完全满足业务需求。自定义损失函数可以融入领域知识、业务约束和特定优化目标。
自定义损失函数的设计原则
- 可微性: 损失函数必须几乎处处可导(允许有限个不可导点,如 Huber)
- 数值稳定性: 避免 exp 溢出、log(0) 等情况,使用 clamp 或数值稳定技巧
- 梯度合理: 梯度不应过大(梯度爆炸)或过小(梯度消失)
- 凸性优先: 凸损失函数更容易优化(非凸损失需要更多调参技巧)
6.1 Focal Loss(处理类别不平衡)
Focal Loss 在交叉熵基础上引入了调节因子,降低易分类样本的权重,迫使模型关注难分类样本。特别适合目标检测等类别极度不平衡的场景。
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2.0):
"""
alpha: 类别权重,平衡正负样本
gamma: 聚焦参数,gamma=0 时退化为交叉熵
"""
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, logits, targets):
ce_loss = F.binary_cross_entropy_with_logits(
logits, targets, reduction='none'
)
probs = torch.sigmoid(logits)
p_t = probs * targets + (1 - probs) * (1 - targets)
focal_weight = (1 - p_t) ** self.gamma
if self.alpha is not None:
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
focal_weight = focal_weight * alpha_t
return torch.mean(focal_weight * ce_loss)
focal = FocalLoss(alpha=0.25, gamma=2.0)
logits = torch.randn(10, requires_grad=True)
targets = torch.where(torch.rand(10) > 0.9,
torch.ones(10), torch.zeros(10))
loss = focal(logits, targets)
print(f"Focal Loss: {loss.item():.4f}")
6.2 Dice Loss(图像分割常用)
Dice Loss 基于 Dice 系数(F1 Score 的集合版本),广泛用于医学图像分割任务,能有效处理前景背景极度不平衡的问题。
class DiceLoss(nn.Module):
def __init__(self, smooth=1e-6):
super().__init__()
self.smooth = smooth
def forward(self, y_pred, y_true):
y_pred: 预测概率图 (B, C, H, W)
y_true: 真实标签图 (B, C, H, W)
"""
y_pred = torch.sigmoid(y_pred)
intersection = torch.sum(y_pred * y_true, dim=(2, 3))
union = torch.sum(y_pred, dim=(2, 3)) + torch.sum(y_true, dim=(2, 3))
dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
return 1.0 - torch.mean(dice)
class ComboLoss(nn.Module):
def __init__(self, dice_weight=0.5, bce_weight=0.5):
super().__init__()
self.dice = DiceLoss()
self.bce = nn.BCEWithLogitsLoss()
self.dice_weight = dice_weight
self.bce_weight = bce_weight
def forward(self, y_pred, y_true):
return (self.dice_weight * self.dice(y_pred, y_true) +
self.bce_weight * self.bce(y_pred, y_true))
6.3 分位数损失(Quantile Loss)
分位数损失用于分位数回归,可以预测目标变量的条件分位数,为预测提供不确定性估计。
def quantile_loss(y_pred, y_true, quantile=0.5):
"""
分位数损失,quantile=0.5 时退化为 MAE
quantile=0.9 时学习 90% 分位数
"""
error = y_true - y_pred
loss = torch.where(
error > 0,
quantile * error,
(quantile - 1) * error
)
return torch.mean(loss)
class MultiQuantileLoss(nn.Module):
def __init__(self, quantiles=[0.1, 0.5, 0.9]):
super().__init__()
self.quantiles = quantiles
def forward(self, y_pred, y_true):
total_loss = 0.0
for i, q in enumerate(self.quantiles):
total_loss += quantile_loss(y_pred[:, i], y_true, q)
return total_loss / len(self.quantiles)