news 2025/12/28 10:05:52

Day50 - 预训练模型与CBAM集成

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day50 - 预训练模型与CBAM集成

1. 核心挑战

在深度学习实践中,我们经常遇到这样的问题:如何优化一个已经训练好的模型(如ResNet18)?

直接修改模型结构(如插入注意力模块)是否会破坏原有的特征提取能力?

如何制定训练策略,既能利用预训练权重,又能让新插入的模块快速学习?

2. 为什么可以将CBAM插入预训练模型?

通常认为,修改预训练模型的结构会导致权重失效。但在 ResNet 中插入 CBAM 模块是可行的,核心原因在于 CBAM 的初始化特性:

  1. 初始状态接近“直通”

CBAM 模块的最终输出计算公式为:Output = Input * Sigmoid(Attention)

在初始化阶段,卷积层和全连接层的权重接近于 0,导致 Attention 图的值也接近 0。

由于Sigmoid(0) = 0.5,因此在训练初期,CBAM 模块的操作近似于Input * 0.5

  1. 保留特征结构

Input * 0.5仅仅是对特征数值的线性缩放,完整保留了特征图的空间结构和相对关系

这保证了下游的预训练层接收到的输入仍然是结构完好的特征,而不是混乱的噪声。

因此,我们可以将 CBAM “无缝注入”到预训练 ResNet 中,而不破坏其核心能力。

3. 模型架构:ResNet18 + CBAM

我们将 CBAM 模块插入到 ResNet18 的每一个layer(残差块组)之后。

3.1 代码实现

import torch import torch.nn as nn from torchvision import models class ResNet18_CBAM(nn.Module): def __init__(self, num_classes=10, pretrained=True, cbam_ratio=16, cbam_kernel=7): super().__init__() # 1. 加载预训练ResNet18 self.backbone = models.resnet18(pretrained=pretrained) # 2. 修改首层卷积以适应 CIFAR-10 的小尺寸输入 (32x32) # 原版 ResNet 针对 ImageNet (224x224),首层卷积核大且步长为2,会导致小图特征丢失 self.backbone.conv1 = nn.Conv2d( in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False ) self.backbone.maxpool = nn.Identity() # 移除 MaxPool 层,保留更多空间信息 # 3. 在每个残差块组后添加 CBAM 模块 self.cbam_layer1 = CBAM(in_channels=64, ratio=cbam_ratio, kernel_size=cbam_kernel) self.cbam_layer2 = CBAM(in_channels=128, ratio=cbam_ratio, kernel_size=cbam_kernel) self.cbam_layer3 = CBAM(in_channels=256, ratio=cbam_ratio, kernel_size=cbam_kernel) self.cbam_layer4 = CBAM(in_channels=512, ratio=cbam_ratio, kernel_size=cbam_kernel) # 4. 修改分类头 self.backbone.fc = nn.Linear(in_features=512, out_features=num_classes) def forward(self, x): # Stem 层 x = self.backbone.conv1(x) x = self.backbone.bn1(x) x = self.backbone.relu(x) # Layer 1 + CBAM x = self.backbone.layer1(x) x = self.cbam_layer1(x) # Layer 2 + CBAM x = self.backbone.layer2(x) x = self.cbam_layer2(x) # Layer 3 + CBAM x = self.backbone.layer3(x) x = self.cbam_layer3(x) # Layer 4 + CBAM x = self.backbone.layer4(x) x = self.cbam_layer4(x) # 分类头 x = self.backbone.avgpool(x) x = torch.flatten(x, 1) x = self.backbone.fc(x) return x

4. 训练策略:三阶段渐进式微调

为了平衡“保留预训练知识”和“学习新模块”的需求,我们采用了差异化学习率分阶段解冻的策略。可以将模型看作一个公司团队:

  • 预训练层(ResNet):资深专家,经验丰富,只需微调。
  • 新模块(CBAM/FC):新入职实习生,一张白纸,需要快速学习。

4.1 阶段一:预热实习生 (Epoch 1-5)

  • 解冻对象:仅CBAM模块和分类头 (fc)
  • 冻结对象:所有 ResNet 主干层。
  • 学习率1e-3(高学习率)。
  • 目标:让新模块快速学习如何配合预训练特征工作,建立初步的分类边界和注意力机制。

4.2 阶段二:唤醒高层专家 (Epoch 6-20)

  • 解冻对象:增加解冻layer3layer4(高层卷积)。
  • 冻结对象conv1,layer1,layer2(底层卷积)保持冻结。
  • 学习率1e-4(中等学习率)。
  • 原理高层网络学习“构图和概念”(如车轮、猫脸),与具体任务强相关,需要重新适应 CIFAR-10。而底层网络学习“笔触和纹理”(如边缘、颜色),是通用的,暂时不需要变动。

4.3 阶段三:全员协同微调 (Epoch 21-50)

  • 解冻对象:所有层。
  • 学习率1e-5(低学习率)。
  • 目标:进行端到端的全局微调,让底层特征也做微小的调整以完美适配新任务。

5. 核心训练代码

实现上述策略的关键函数:

def set_trainable_layers(model, trainable_parts): print(f"\n---> 解冻以下部分并设为可训练: {trainable_parts}") for name, param in model.named_parameters(): param.requires_grad = False # 先默认全冻结 for part in trainable_parts: if part in name: # 如果参数名包含指定部分,则解冻 param.requires_grad = True break def train_staged_finetuning(model, criterion, train_loader, test_loader, device, epochs): # ... (省略初始化代码) for epoch in range(1, epochs + 1): # --- 动态调整策略 --- if epoch == 1: print("阶段 1:训练注意力模块和分类头") set_trainable_layers(model, ["cbam", "backbone.fc"]) # 只优化 requires_grad=True 的参数 optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3) elif epoch == 6: print("阶段 2:解冻高层卷积层") set_trainable_layers(model, ["cbam", "backbone.fc", "backbone.layer3", "backbone.layer4"]) optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4) elif epoch == 21: print("阶段 3:全局微调") for param in model.parameters(): param.requires_grad = True optimizer = optim.Adam(model.parameters(), lr=1e-5) # ... (后续标准训练与测试循环)

6. 实验结果

通过这种精细的训练策略,模型在 CIFAR-10 数据集上最终达到了90.15%的测试准确率。这证明了:

  1. CBAM 能够有效增强特征提取能力。
  2. 分阶段微调策略能够有效防止预训练权重的破坏,同时让新模块充分收敛。
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2025/12/26 9:26:05

海尔智能设备HomeAssistant集成实战指南:打造全屋智能控制中心

海尔智能设备HomeAssistant集成实战指南:打造全屋智能控制中心 【免费下载链接】haier 项目地址: https://gitcode.com/gh_mirrors/ha/haier 还在为家中不同品牌智能设备无法统一管理而困扰吗?想象一下:炎炎夏日,你希望回…

作者头像 李华
网站建设 2025/12/26 9:26:02

为什么顶尖AI团队都在抢用Open-AutoGLM插件?真相终于揭晓

第一章:为什么顶尖AI团队都在抢用Open-AutoGLM插件?真相终于揭晓近年来,Open-AutoGLM 插件在顶级人工智能研发团队中迅速走红。其核心优势在于将自然语言理解与自动化代码生成深度融合,显著提升了大模型在复杂任务中的推理效率和可…

作者头像 李华
网站建设 2025/12/26 9:24:49

ISAC技术终极指南:从零基础到实战专家的完整路径

ISAC技术终极指南:从零基础到实战专家的完整路径 【免费下载链接】Must-Reading-on-ISAC Must Reading Papers, Research Library, Open-Source Code on Integrated Sensing and Communications (aka. Joint Radar and Communications, Joint Sensing and Communica…

作者头像 李华
网站建设 2025/12/28 0:32:32

iOS自动化测试终极完整教程:从零开始掌握iOS-Tagent

iOS自动化测试终极完整教程:从零开始掌握iOS-Tagent 【免费下载链接】iOS-Tagent iOS support agent for automation 项目地址: https://gitcode.com/gh_mirrors/io/iOS-Tagent 你是否想要快速上手iOS自动化测试,却苦于复杂的配置和繁琐的步骤&am…

作者头像 李华
网站建设 2025/12/28 10:03:03

如何快速解决LangChain4j与LMStudio协议冲突:终极兼容性指南

如何快速解决LangChain4j与LMStudio协议冲突:终极兼容性指南 【免费下载链接】langchain4j langchain4j - 一个Java库,旨在简化将AI/LLM(大型语言模型)能力集成到Java应用程序中。 项目地址: https://gitcode.com/GitHub_Trendi…

作者头像 李华
网站建设 2025/12/26 9:23:27

终极SQL查询压力测试工具:SqlQueryStress完全指南

终极SQL查询压力测试工具:SqlQueryStress完全指南 【免费下载链接】SqlQueryStress SqlQueryStress 是一个用于测试 SQL Server 查询性能和负载的工具,可以生成大量的并发查询来模拟高负载场景。 通过提供连接信息和查询模板,可以执行负载测试…

作者头像 李华