1. 半监督学习实战:从理论到代码实现
在计算机视觉领域,数据标注一直是制约模型性能提升的瓶颈。传统监督学习需要大量标注数据,而完全无监督学习又难以达到理想的分类精度。半监督学习恰好在这两者之间找到了平衡点——它能够同时利用少量标注数据和大量未标注数据来提升模型性能。
我最近在一个食品分类项目中使用半监督学习方法,仅用30%的标注数据就达到了接近全监督学习的准确率。这种方法特别适合标注成本高的场景,比如医疗影像分析、工业质检等领域。下面我将分享整个实现过程的关键细节和实战经验。
2. 半监督学习核心原理
2.1 三种学习范式对比
监督学习、无监督学习和半监督学习构成了机器学习的三驾马车,它们各有特点:
监督学习:需要完整的(X,Y)配对数据,模型学习从输入到输出的映射函数。在图像分类中,这意味着每张训练图片都必须有准确的类别标签。
无监督学习:仅使用未标注数据,通过聚类、降维等方法发现数据内在结构。典型的如K-means聚类,但它不直接解决分类问题。
半监督学习:混合使用少量标注数据和大量未标注数据。其核心假设是:
- 平滑性假设:相似样本应有相同标签
- 流形假设:数据分布在高维空间的低维流形上
- 聚类假设:同一聚类中的样本更可能共享标签
2.2 自训练算法详解
自训练(Self-training)是最直观的半监督学习方法,我的实现主要基于这种范式。其算法流程如下:
- 使用标注数据训练初始模型
- 用该模型预测未标注数据的伪标签(pseudo-label)
- 筛选高置信度预测结果加入训练集
- 用扩增后的数据集重新训练模型
- 重复2-4步直到满足停止条件
关键点在于伪标签的筛选策略。在我的代码中,只有当模型对某样本的预测置信度超过0.99时,才会将其加入训练集。这个阈值需要谨慎选择:
- 过高会导致可用数据太少
- 过低会引入噪声标签损害模型
3. 代码实现深度解析
3.1 数据准备与增强
class food_Dataset(Dataset): def __init__(self, path, mode="train"): self.mode = mode if mode == "semi": # 半监督模式只读取图像数据 self.X = self.read_file(path) else: self.X, self.Y = self.read_file(path) self.Y = torch.LongTensor(self.Y) # 数据增强策略 self.transform = train_transform if mode == "train" else val_transform数据增强是提升模型泛化能力的关键。对于训练数据,我采用了三种增强方式:
- 随机裁剪(RandomResizedCrop):模拟物体在不同位置的拍摄情况
- 随机旋转(RandomRotation):增强旋转不变性
- 标准化(ToTensor):将像素值归一化到[0,1]范围
验证集和测试集则只进行最简单的ToTensor转换,因为我们希望评估模型在真实场景下的表现。
3.2 伪标签生成机制
class semiDataset(Dataset): def get_label(self, no_label_loder, model, device, thres): model.eval() soft = nn.Softmax(dim=1) x, y = [], [] with torch.no_grad(): for bat_x, _ in no_label_loder: bat_x = bat_x.to(device) pred = model(bat_x) pred_soft = soft(pred) pred_max, pred_value = pred_soft.max(1) # 筛选高置信度样本 mask = pred_max > thres if mask.any(): x.extend([orig_img for orig_img, m in zip(batch_orig_imgs, mask) if m]) y.extend(pred_value[mask].cpu().numpy()) return x, y伪标签生成有几个技术细节值得注意:
- 必须在
torch.no_grad()上下文管理器中操作,避免不必要的梯度计算 - 使用Softmax将输出转换为概率分布,而非直接使用原始logits
- 批量处理数据而非单张处理,充分利用GPU并行计算能力
- 保留原始图像而非转换后的图像,因为后续训练会重新应用数据增强
3.3 训练流程控制
def train_val(model, train_loader, val_loader, no_label_loader, optimizer, loss, epochs, device, thres, save_path): for epoch in range(epochs): # 常规监督训练阶段 model.train() for batch_x, batch_y in train_loader: # 前向传播、损失计算、反向传播... # 半监督训练触发条件 if epoch % 3 == 0 and val_acc > 0.6: semi_loader = get_semi_loader(no_label_loader, model, device, thres) if semi_loader: # 如果有高置信度样本 for semi_x, semi_y in semi_loader: # 用伪标签数据进行训练...训练过程采用分阶段策略:
- 初始阶段:仅使用标注数据训练,直到验证准确率超过阈值(0.6)
- 半监督阶段:每3个epoch生成一次伪标签数据集
- 混合训练:同时使用原始标注数据和伪标签数据训练
这种渐进式的方法比一开始就使用伪标签更稳定,避免了早期模型预测不准导致的噪声累积问题。
4. 关键参数与调优经验
4.1 置信度阈值选择
置信度阈值thres是最敏感的参数之一。经过多次实验,我发现:
- 对于11类食品分类,0.99的阈值比较合适
- 类别数较少时(如5类),可降低到0.95
- 类别数较多时(如50类),可能需要提高到0.995
太低的阈值会导致:
- 准确率下降约15-20%
- 训练过程不稳定,损失值波动大
4.2 学习率设置
由于要同时训练标注数据和伪标签数据,学习率需要比纯监督学习更小:
- 初始学习率设为0.001
- 每5个epoch衰减为原来的0.8倍
- 使用AdamW优化器,权重衰减1e-4防止过拟合
4.3 数据比例影响
在我的实验中,标注数据和未标注数据的比例对结果影响显著:
| 标注数据比例 | 最终准确率 | 提升幅度 |
|---|---|---|
| 10% | 68.2% | +22.5% |
| 30% | 82.7% | +15.3% |
| 50% | 88.1% | +8.9% |
| 100% | 91.4% | - |
可以看到,当标注数据较少时,半监督学习带来的提升更为明显。
5. 常见问题与解决方案
5.1 伪标签质量不稳定
现象:随着训练进行,伪标签的准确率波动较大
解决方案:
- 实现动态阈值调整:初始阶段使用较高阈值,随着模型变强逐步降低
- 添加标签平滑(Label Smoothing):将硬标签转为软标签
- 使用集成学习:结合多个epoch的预测结果决定最终伪标签
5.2 类别不平衡加剧
现象:模型对某些类的预测偏好导致伪标签分布失衡
解决方法:
# 在伪标签生成阶段添加类别平衡控制 class_distribution = calculate_original_class_distribution() max_samples_per_class = int(len(unlabeled_data) * 1.5 / num_classes) for class_id in range(num_classes): class_samples = [s for s in selected_samples if s[1] == class_id] if len(class_samples) > max_samples_per_class: # 对该类样本进行随机下采样 selected_samples = [s for s in selected_samples if s not in random.sample(class_samples, len(class_samples)-max_samples_per_class)]5.3 训练时间大幅增加
现象:半监督训练比纯监督训练耗时多2-3倍
优化策略:
- 实现伪标签缓存机制,每3-5个epoch生成一次
- 使用更大的batch size进行伪标签预测(如128)
- 对未标注数据预先生成特征向量,减少重复计算
6. 进阶技巧与扩展思路
6.1 一致性正则化(Consistency Regularization)
除了自训练,还可以引入一致性损失:
def consistency_loss(model, x_unlabeled): # 对同一输入应用不同数据增强 aug1 = strong_augmentation(x_unlabeled) aug2 = strong_augmentation(x_unlabeled) # 获取两个预测结果 pred1 = model(aug1) pred2 = model(aug2) # 计算KL散度 return F.kl_div(F.log_softmax(pred1, dim=1), F.softmax(pred2, dim=1), reduction='batchmean')这种方法不依赖高置信度阈值,而是强制模型对同一图像的不同增强版本产生一致的预测。
6.2 混合专家模型(MoE)
对于复杂场景,可以训练多个专家模型:
- 用不同子集标注数据训练多个基础模型
- 每个模型为未标注数据生成伪标签
- 只保留多个模型一致预测的样本
- 用扩展数据集训练最终模型
这种方法虽然计算成本更高,但能显著提高伪标签质量。
6.3 主动学习结合
将半监督学习与主动学习结合形成闭环:
- 初始阶段使用少量标注数据训练模型
- 用模型预测未标注数据并选择最有价值的样本进行人工标注
- 将新标注数据加入训练集
- 重复2-3步直到达到性能要求
最有价值样本的选择标准可以是:
- 预测熵最大的样本(模型最不确定)
- 代表性样本(聚类中心附近)
- 多样性样本(不同聚类中的样本)
在实际项目中,我通常会先用半监督学习方法快速提升基线性能,然后再针对性地标注一些困难样本进行微调。这种组合策略能在有限标注预算下获得最佳性价比。