1. CSRNet密集人群检测入门指南
第一次接触密集人群检测时,我被商场监控画面中密密麻麻的人头震撼到了。传统目标检测方法在这里完全失效,而CSRNet却能准确统计出人数,这让我决定深入研究这个算法。CSRNet是2018年提出的经典人群密度估计模型,特别适合处理高度遮挡的密集场景,比如地铁站、演唱会现场等。
与普通目标检测不同,CSRNet不直接检测单个人体,而是通过生成密度图来估算人数。这种思路就像用热力图表示人群分布,颜色越深表示人越密集。实际测试中,在每平方米站6-7人的极端场景下,CSRNet仍能保持较高准确率。
准备环境时我推荐使用conda创建独立环境。最近帮同事配置时发现,python3.8+torch1.12+cuda11.6的组合兼容性最好。如果使用最新torch2.0,可能会遇到一些奇怪的报错,这时回退到稳定版本往往能省去很多调试时间。
2. 环境搭建与数据准备
2.1 避坑指南:环境配置
上周帮学弟配置环境时,我们花了3小时解决一个诡异的报错,最终发现是CUDA版本不匹配。这里分享我的标准配置清单:
- Ubuntu 20.04/22.04 LTS
- CUDA 11.6 + cuDNN 8.4
- Python 3.8.10
- PyTorch 1.12.1
安装时特别注意:
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch这个组合经过20+次实践验证最稳定。曾遇到有人用pip安装导致cudnn找不到的问题,建议全程用conda管理。
2.2 数据集处理技巧
ShanghaiTech数据集处理有三大坑点:
- 解压后目录结构不对:官方zip包解压后需要手动创建
part_A_final/test_data/images这样的层级 - JSON文件路径问题:建议用VS Code批量替换所有json中的路径分隔符为
/ - 缺失图片处理:IMG_280.jpg需要手动补到训练集
我写了个自动修复脚本:
import json import os def fix_json(path): with open(path) as f: data = json.load(f) for item in data: item['filename'] = item['filename'].replace('\\', '/') with open(path, 'w') as f: json.dump(data, f, indent=2)3. 模型训练实战
3.1 关键参数调优
初始训练时我的MAE高达120,远差于论文的68.2。经过两周调参,总结出这些黄金参数:
| 参数名 | 推荐值 | 作用说明 |
|---|---|---|
| batch_size | 8 | 显存不足可降至4 |
| lr | 1e-5 | 初始学习率 |
| steps | [50,100] | 学习率衰减时机 |
| scales | [0.1,0.01] | 衰减幅度 |
特别提醒:原代码的scales全是1等于没衰减!这是我踩过最大的坑。修改train.py中的这部分:
args.steps = [50, 100] # 在第50和100epoch调整学习率 args.scales = [0.1, 0.01] # 衰减为原来的0.1倍和0.01倍3.2 断点续训技巧
训练400轮需要近20小时,中断后继续训练要注意:
- 保存的checkpoint要完整(至少包含state_dict和optimizer)
- 恢复训练时加入--pre参数:
python train.py part_A_train.json part_A_test.json 0 0 --pre ./saved_models/checkpoint.pth.tar- 学习率需要重置:在load_checkpoint后添加:
for param_group in optimizer.param_groups: param_group['lr'] = args.lr # 恢复初始学习率4. 效果验证与可视化
4.1 量化评估指标
测试时发现两个关键点:
- 验证集MAE会虚高:如果验证图片包含在训练集
- 最佳模型选择:不要只看MAE,要结合可视化效果
我的评估脚本增加了标准差计算:
def evaluate(model, loader): model.eval() mae, mse = 0, 0 counts = [] with torch.no_grad(): for inputs, targets in loader: outputs = model(inputs) cnt = outputs.sum().item() gt_cnt = targets.sum().item() counts.append(abs(cnt - gt_cnt)) mae = np.mean(counts) std = np.std(counts) # 新增标准差计算 return mae, std4.2 可视化增强技巧
原始可视化代码显示效果较差,我改进后的方案:
- 增加颜色条刻度标签
- 添加预测人数标注
- 优化布局节省空间
关键修改点:
plt.figure(figsize=(18, 6)) # 预测图 ax1 = plt.subplot(1,3,1) im1 = ax1.imshow(pred_density, cmap='jet') plt.colorbar(im1, fraction=0.046, pad=0.04) ax1.set_title(f"Predicted\nCount: {pred_count:.0f}", fontsize=12) # 添加红色文字标注 ax1.text(0.5, -0.15, f"MAE: {mae:.2f}", transform=ax1.transAxes, ha='center', color='red')最终效果对比显示,改进后的可视化能同时展示原始图像、预测密度图和真实密度图,并突出显示关键指标,方便快速判断模型性能。