news 2026/4/19 23:19:16

别再死记硬背损失函数了!从CE到InfoNCE,我用PyTorch代码带你理清对比学习的核心脉络

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再死记硬背损失函数了!从CE到InfoNCE,我用PyTorch代码带你理清对比学习的核心脉络

从交叉熵到对比学习:用PyTorch代码拆解损失函数进化史

在深度学习领域,损失函数就像导航系统的指南针,决定了模型优化的方向。但很多开发者对损失函数的理解停留在"调用现成接口"的阶段,尤其是当面对对比学习中的InfoNCE时,常常感到一头雾水。今天我们不谈复杂的数学推导,而是用PyTorch代码作为显微镜,带你观察从交叉熵(CE)到噪声对比估计(NCE),再到信息噪声对比估计(InfoNCE)的演化轨迹。

1. 交叉熵:监督学习的基石

想象你正在训练一个猫狗分类器。每次预测后,模型需要知道自己错得有多离谱——这就是交叉熵的工作。本质上,它衡量的是模型预测概率分布与真实分布的差距。

import torch import torch.nn.functional as F # 假设我们有3个样本,每个样本有5个类别的预测logits logits = torch.randn(3, 5) # 未归一化的预测值 labels = torch.tensor([1, 0, 4]) # 真实类别索引 # PyTorch实现交叉熵 ce_loss = F.cross_entropy(logits, labels) print(f"Cross Entropy Loss: {ce_loss.item():.4f}")

交叉熵的核心公式其实很简单:

CE = -log(exp(s_y) / ∑exp(s_i))

其中s_y是目标类别的得分。这个看似简单的设计却有几个关键特性:

  • 梯度友好:错误预测时梯度较大,随着预测准确度提高梯度减小
  • 概率解释:通过softmax将logits转化为概率分布
  • 类别竞争:每个类别的概率相互制约,总和为1

但在无监督场景下,交叉熵遇到了两个致命问题:

  1. 需要明确的标签定义
  2. 当类别数量巨大时(如语言模型中的词汇表),softmax分母计算成本过高

这就引出了我们需要讨论的下一个主角——NCE。

2. 噪声对比估计:从概率匹配到二分类判别

NCE的核心思想很巧妙:与其直接计算概率分布,不如训练模型区分"真实数据"和"噪声"。这就像教小朋友认识苹果时,不是直接定义"苹果是什么",而是通过对比苹果和非苹果(香蕉、橘子等)来建立认知。

def nce_loss(data_samples, noise_samples, model, k=1): """ data_samples: 真实数据样本的特征向量 noise_samples: 噪声样本的特征向量 model: 特征提取模型 k: 每个真实样本对应的噪声样本数 """ # 计算数据样本和噪声样本的得分 data_scores = model(data_samples) noise_scores = model(noise_samples) # 构造联合得分向量 joint_scores = torch.cat([ data_scores, noise_scores ], dim=0) # 创建标签:数据样本为1,噪声样本为0 labels = torch.cat([ torch.ones_like(data_scores), torch.zeros_like(noise_scores) ]) # 使用二元交叉熵 return F.binary_cross_entropy_with_logits(joint_scores, labels)

NCE的创新点在于:

  • 计算效率:将O(|V|)的softmax计算转化为O(k)的二分类问题
  • 理论保证:当噪声分布接近真实分布时,模型学习到的密度比是渐进一致的
  • 灵活性:可以自由设计噪声分布,适应不同场景

但NCE仍然有其局限性——它本质上还是一个判别式模型,没有充分利用样本间的相对关系。这为InfoNCE的出现埋下了伏笔。

3. InfoNCE:对比学习的灵魂

对比学习的核心思想是"物以类聚"——相似样本在特征空间中应该靠近,不相似的应该远离。InfoNCE将这个思想数学化,成为SimCLR、MoCo等经典对比学习框架的核心损失函数。

让我们用PyTorch实现一个简化版的InfoNCE:

def info_nce_loss(query, positive_key, negative_keys, temperature=0.1): """ query: 查询样本特征 [batch_size, feature_dim] positive_key: 正样本特征 [batch_size, feature_dim] negative_keys: 负样本特征 [num_negatives, feature_dim] temperature: 温度系数 """ batch_size = query.size(0) feature_dim = query.size(1) num_negatives = negative_keys.size(0) # 归一化特征向量 query = F.normalize(query, dim=1) positive_key = F.normalize(positive_key, dim=1) negative_keys = F.normalize(negative_keys, dim=1) # 计算正样本相似度 pos_sim = torch.sum(query * positive_key, dim=1) # [batch_size] # 计算负样本相似度 neg_sim = torch.mm(query, negative_keys.t()) # [batch_size, num_negatives] # 合并相似度 logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) / temperature # 标签:第一个位置是正样本 labels = torch.zeros(batch_size, dtype=torch.long).to(query.device) return F.cross_entropy(logits, labels)

这个实现揭示了InfoNCE的几个关键设计:

  1. 温度系数τ:控制相似度分布的尖锐程度

    • τ越小,分布越尖锐,对困难负样本关注越多
    • τ越大,分布越平缓,训练更稳定但区分度降低
  2. 负样本数量:更多的负样本提供更强的梯度信号,但也增加计算成本

  3. 特征归一化:强制特征分布在单位球面上,避免特征范数影响相似度计算

4. 从理论到实践:对比学习框架中的InfoNCE

理解了InfoNCE的基本原理后,我们来看它如何在真实对比学习框架中发挥作用。以SimCLR为例:

class SimCLR(nn.Module): def __init__(self, base_encoder, feature_dim=128, temperature=0.5): super().__init__() self.temperature = temperature self.encoder = base_encoder self.projector = nn.Sequential( nn.Linear(feature_dim, feature_dim), nn.ReLU(), nn.Linear(feature_dim, feature_dim) ) def forward(self, x1, x2): # 两个增强视图的特征 h1 = self.encoder(x1) h2 = self.encoder(x2) # 投影头 z1 = self.projector(h1) z2 = self.projector(h2) # 计算InfoNCE损失 loss = 0.5 * (self.info_nce(z1, z2) + self.info_nce(z2, z1)) return loss def info_nce(self, anchor, targets): batch_size = anchor.size(0) labels = torch.arange(batch_size).to(anchor.device) # 计算所有样本间的相似度 logits = torch.mm(anchor, targets.t()) / self.temperature # 对角线元素是正样本对 return F.cross_entropy(logits, labels)

在这个实现中,有几个值得注意的工程细节:

  1. 投影头设计:在编码器后添加小型MLP,将特征映射到更适合对比学习的空间
  2. 对称损失:计算两个增强视图互为锚点的损失并取平均
  3. 批量负采样:同一批次中的其他样本自然作为负样本

温度系数τ的选择对SimCLR性能有显著影响。实践中,τ通常设置在0.05到0.2之间,需要根据具体任务进行调整。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/19 23:16:07

【算法日记】Day 20 动态规划专题——状态压缩DP(三)

Abstract:#动态规划 #状压DP #TSP问题 1. 题目 题目:Luogu P1171 售货员的难题核心思路:状态压缩动态规划。定义dp[status][cur]表示当前已经访问过的城市集合为status,且当前位于城市cur,要访问完所有剩余城市并最终…

作者头像 李华
网站建设 2026/4/19 23:09:16

从技术黑箱到法律可溯:2026奇点大会强制推行的AGI“行为日志双签名”标准(含ISO/IEC 27001-AI附录草案)

第一章:2026奇点智能技术大会:AGI的法律框架 2026奇点智能技术大会(https://ml-summit.org) 全球AGI治理共识的里程碑 2026奇点智能技术大会首次将通用人工智能(AGI)的法律人格认定、责任归属与跨司法管辖区监管协同列为最高优先…

作者头像 李华