从交叉熵到对比学习:用PyTorch代码拆解损失函数进化史
在深度学习领域,损失函数就像导航系统的指南针,决定了模型优化的方向。但很多开发者对损失函数的理解停留在"调用现成接口"的阶段,尤其是当面对对比学习中的InfoNCE时,常常感到一头雾水。今天我们不谈复杂的数学推导,而是用PyTorch代码作为显微镜,带你观察从交叉熵(CE)到噪声对比估计(NCE),再到信息噪声对比估计(InfoNCE)的演化轨迹。
1. 交叉熵:监督学习的基石
想象你正在训练一个猫狗分类器。每次预测后,模型需要知道自己错得有多离谱——这就是交叉熵的工作。本质上,它衡量的是模型预测概率分布与真实分布的差距。
import torch import torch.nn.functional as F # 假设我们有3个样本,每个样本有5个类别的预测logits logits = torch.randn(3, 5) # 未归一化的预测值 labels = torch.tensor([1, 0, 4]) # 真实类别索引 # PyTorch实现交叉熵 ce_loss = F.cross_entropy(logits, labels) print(f"Cross Entropy Loss: {ce_loss.item():.4f}")交叉熵的核心公式其实很简单:
CE = -log(exp(s_y) / ∑exp(s_i))其中s_y是目标类别的得分。这个看似简单的设计却有几个关键特性:
- 梯度友好:错误预测时梯度较大,随着预测准确度提高梯度减小
- 概率解释:通过softmax将logits转化为概率分布
- 类别竞争:每个类别的概率相互制约,总和为1
但在无监督场景下,交叉熵遇到了两个致命问题:
- 需要明确的标签定义
- 当类别数量巨大时(如语言模型中的词汇表),softmax分母计算成本过高
这就引出了我们需要讨论的下一个主角——NCE。
2. 噪声对比估计:从概率匹配到二分类判别
NCE的核心思想很巧妙:与其直接计算概率分布,不如训练模型区分"真实数据"和"噪声"。这就像教小朋友认识苹果时,不是直接定义"苹果是什么",而是通过对比苹果和非苹果(香蕉、橘子等)来建立认知。
def nce_loss(data_samples, noise_samples, model, k=1): """ data_samples: 真实数据样本的特征向量 noise_samples: 噪声样本的特征向量 model: 特征提取模型 k: 每个真实样本对应的噪声样本数 """ # 计算数据样本和噪声样本的得分 data_scores = model(data_samples) noise_scores = model(noise_samples) # 构造联合得分向量 joint_scores = torch.cat([ data_scores, noise_scores ], dim=0) # 创建标签:数据样本为1,噪声样本为0 labels = torch.cat([ torch.ones_like(data_scores), torch.zeros_like(noise_scores) ]) # 使用二元交叉熵 return F.binary_cross_entropy_with_logits(joint_scores, labels)NCE的创新点在于:
- 计算效率:将O(|V|)的softmax计算转化为O(k)的二分类问题
- 理论保证:当噪声分布接近真实分布时,模型学习到的密度比是渐进一致的
- 灵活性:可以自由设计噪声分布,适应不同场景
但NCE仍然有其局限性——它本质上还是一个判别式模型,没有充分利用样本间的相对关系。这为InfoNCE的出现埋下了伏笔。
3. InfoNCE:对比学习的灵魂
对比学习的核心思想是"物以类聚"——相似样本在特征空间中应该靠近,不相似的应该远离。InfoNCE将这个思想数学化,成为SimCLR、MoCo等经典对比学习框架的核心损失函数。
让我们用PyTorch实现一个简化版的InfoNCE:
def info_nce_loss(query, positive_key, negative_keys, temperature=0.1): """ query: 查询样本特征 [batch_size, feature_dim] positive_key: 正样本特征 [batch_size, feature_dim] negative_keys: 负样本特征 [num_negatives, feature_dim] temperature: 温度系数 """ batch_size = query.size(0) feature_dim = query.size(1) num_negatives = negative_keys.size(0) # 归一化特征向量 query = F.normalize(query, dim=1) positive_key = F.normalize(positive_key, dim=1) negative_keys = F.normalize(negative_keys, dim=1) # 计算正样本相似度 pos_sim = torch.sum(query * positive_key, dim=1) # [batch_size] # 计算负样本相似度 neg_sim = torch.mm(query, negative_keys.t()) # [batch_size, num_negatives] # 合并相似度 logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) / temperature # 标签:第一个位置是正样本 labels = torch.zeros(batch_size, dtype=torch.long).to(query.device) return F.cross_entropy(logits, labels)这个实现揭示了InfoNCE的几个关键设计:
温度系数τ:控制相似度分布的尖锐程度
- τ越小,分布越尖锐,对困难负样本关注越多
- τ越大,分布越平缓,训练更稳定但区分度降低
负样本数量:更多的负样本提供更强的梯度信号,但也增加计算成本
特征归一化:强制特征分布在单位球面上,避免特征范数影响相似度计算
4. 从理论到实践:对比学习框架中的InfoNCE
理解了InfoNCE的基本原理后,我们来看它如何在真实对比学习框架中发挥作用。以SimCLR为例:
class SimCLR(nn.Module): def __init__(self, base_encoder, feature_dim=128, temperature=0.5): super().__init__() self.temperature = temperature self.encoder = base_encoder self.projector = nn.Sequential( nn.Linear(feature_dim, feature_dim), nn.ReLU(), nn.Linear(feature_dim, feature_dim) ) def forward(self, x1, x2): # 两个增强视图的特征 h1 = self.encoder(x1) h2 = self.encoder(x2) # 投影头 z1 = self.projector(h1) z2 = self.projector(h2) # 计算InfoNCE损失 loss = 0.5 * (self.info_nce(z1, z2) + self.info_nce(z2, z1)) return loss def info_nce(self, anchor, targets): batch_size = anchor.size(0) labels = torch.arange(batch_size).to(anchor.device) # 计算所有样本间的相似度 logits = torch.mm(anchor, targets.t()) / self.temperature # 对角线元素是正样本对 return F.cross_entropy(logits, labels)在这个实现中,有几个值得注意的工程细节:
- 投影头设计:在编码器后添加小型MLP,将特征映射到更适合对比学习的空间
- 对称损失:计算两个增强视图互为锚点的损失并取平均
- 批量负采样:同一批次中的其他样本自然作为负样本
温度系数τ的选择对SimCLR性能有显著影响。实践中,τ通常设置在0.05到0.2之间,需要根据具体任务进行调整。