news 2026/4/23 23:20:45

Day49 - CBAM注意力机制

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day49 - CBAM注意力机制

1. 简介

CBAM (Convolutional Block Attention Module) 是一种轻量级的注意力模块,它可以无缝集成到任何CNN架构中,通过引入额外的开销来显著提升模型的性能。

与SE (Squeeze-and-Excitation) 模块主要关注通道注意力不同,CBAM 同时结合了通道注意力 (Channel Attention)空间注意力 (Spatial Attention)

这种串联的注意力机制使得网络能够依次学习"关注什么" (What to focus on) 和 "关注哪里" (Where to focus on)。

2. 核心原理

CBAM 包含两个子模块,通常采用串联方式连接(先通道后空间):

2.1 通道注意力模块 (Channel Attention Module, CAM)

通道注意力旨在探索通道之间的依赖关系。CBAM 的通道注意力改进了 SE 模块:

  • 不仅使用全局平均池化 (Average Pooling),还引入了全局最大池化 (Max Pooling)。
  • 认为最大池化能收集到更独特的对象特征,与平均池化互补。
  • 两个池化后的特征向量共享同一个多层感知机 (MLP) 网络。
  • 最终将两个输出相加并通过 Sigmoid 激活函数生成通道权重。

2.2 空间注意力模块 (Spatial Attention Module, SAM)

空间注意力旨在探索特征图在空间维度上的重要性(即哪些区域更重要)。

  • 在通道维度上进行平均池化和最大池化,得到两个 2D 特征图。
  • 将这两个特征图在通道维度拼接 (Concat)。
  • 通过一个 7x7 的卷积层进行特征融合。
  • 通过 Sigmoid 激活函数生成空间权重图。

3. 代码实现

以下是基于 PyTorch 的 CBAM 完整实现,包括通道注意力、空间注意力及其在 CNN 中的集成。

3.1 通道注意力 (ChannelAttention)

import torch import torch.nn as nn class ChannelAttention(nn.Module): def __init__(self, in_channels, ratio=16): super().__init__() # 平均池化和最大池化 self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) # 共享的全连接层 (MLP) # 使用1x1卷积代替全连接层,减少参数量并保持输入形状 self.fc = nn.Sequential( nn.Linear(in_channels, in_channels // ratio, bias=False), nn.ReLU(), nn.Linear(in_channels // ratio, in_channels, bias=False) ) self.sigmoid = nn.Sigmoid() def forward(self, x): b, c, h, w = x.shape # 平均池化分支 avg_out = self.fc(self.avg_pool(x).view(b, c)) # 最大池化分支 max_out = self.fc(self.max_pool(x).view(b, c)) # 结果相加后经过Sigmoid attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1) # 权重作用于原特征图 return x * attention

3.2 空间注意力 (SpatialAttention)

class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super().__init__() # padding计算保证输出大小不变 padding = kernel_size // 2 # 输入通道为2 (AvgPool 1 + MaxPool 1) self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): # 在通道维度上求平均 (b, 1, h, w) avg_out = torch.mean(x, dim=1, keepdim=True) # 在通道维度上求最大 (b, 1, h, w) max_out, _ = torch.max(x, dim=1, keepdim=True) # 拼接 (b, 2, h, w) pool_out = torch.cat([avg_out, max_out], dim=1) # 卷积 + Sigmoid attention = self.conv(pool_out) return x * self.sigmoid(attention)

3.3 CBAM 模块组合

class CBAM(nn.Module): def __init__(self, in_channels, ratio=16, kernel_size=7): super().__init__() self.channel_attention = ChannelAttention(in_channels, ratio) self.spatial_attention = SpatialAttention(kernel_size) def forward(self, x): # 串联结构:先通道后空间 x = self.channel_attention(x) x = self.spatial_attention(x) return x

3.4 集成到 CNN 模型

在经典的卷积神经网络中,CBAM 模块通常被放置在卷积层和激活函数之后,或者池化层之前。以下是一个简单的 CBAM-CNN 示例:

class CBAM_CNN(nn.Module): def __init__(self): super().__init__() # 第一层卷积块 self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(32) self.relu1 = nn.ReLU() self.pool1 = nn.MaxPool2d(kernel_size=2) self.cbam1 = CBAM(in_channels=32) # 集成CBAM # ... 后续层省略 ... # 假设这里还有更多层 # 全连接层 self.fc1 = nn.Linear(128 * 4 * 4, 512) self.dropout = nn.Dropout(p=0.5) self.fc2 = nn.Linear(512, 10) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu1(x) x = self.pool1(x) x = self.cbam1(x) # 应用注意力机制 # ... 后续前向传播 ... x = x.view(-1, 128 * 4 * 4) x = self.fc1(x) x = self.relu1(x) # 注意这里应该是对应的激活函数 x = self.dropout(x) x = self.fc2(x) return x

4. 训练与实验

在 CIFAR-10 数据集上的训练过程显示,引入 CBAM 后,模型能够更有效地聚焦于图像的关键特征。

  • 优化器: 使用 Adam 优化器,自适应调整学习率。
  • 学习率调度: 使用ReduceLROnPlateau,当验证集损失不再下降时自动降低学习率,有助于模型收敛到更优解。
  • 性能: 在约 50 个 Epoch 的训练中,模型能够达到较高的准确率 (如 86% 左右),证明了注意力机制对特征提取能力的增强作用。

5. 总结

CBAM 通过结合通道注意力和空间注意力,提供了一种即插即用的性能提升方案。

  • 轻量级: 参数量和计算量增加很少。
  • 通用性: 适用于各种 CNN 架构 (ResNet, MobileNet 等)。
  • 互补性: MaxPool 和 AvgPool 的结合保留了更丰富的特征信息。
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/17 14:28:03

GeoJSON.io地理数据编辑工具:完整免费使用指南

GeoJSON.io地理数据编辑工具:完整免费使用指南 【免费下载链接】geojson.io A quick, simple tool for creating, viewing, and sharing spatial data 项目地址: https://gitcode.com/gh_mirrors/ge/geojson.io 还在寻找简单高效的在线地理数据处理方案吗&am…

作者头像 李华
网站建设 2026/4/19 1:29:42

如何用AI工具,把文献综述从“耗时费力”变成“高效产出”?

如果你是一名研究生,大概率对“文献综述”这四个字有着复杂的感情。它既是开启研究课题的基石,又是学术道路上第一道令人望而生畏的关卡。从茫茫文献海中确定方向、梳理脉络、归纳观点,再到组织成文,这个过程往往意味着数周甚至数…

作者头像 李华
网站建设 2026/4/18 17:59:46

Python通达信数据解析完整指南:快速掌握二进制文件读取技巧

Python通达信数据解析完整指南:快速掌握二进制文件读取技巧 【免费下载链接】mootdx 通达信数据读取的一个简便使用封装 项目地址: https://gitcode.com/GitHub_Trending/mo/mootdx 通达信作为国内主流的证券分析平台,其高效的二进制数据格式为金…

作者头像 李华
网站建设 2026/4/22 16:41:30

2.3 运算符详解

文章目录前言一、算术运算符二、比较运算符三、逻辑运算符四、赋值运算符五、成员运算符六、运算符优先级前言 依次讲解了算数运算符、比较运算符、逻辑运算符、赋值运算符、成员运算符和运算符优先级等知识点。 一、算术运算符 用于基本的数学运算。 运算符名称示例结果说明…

作者头像 李华
网站建设 2026/4/20 7:40:47

3.1 字符串(String)

文章目录前言一、字符串创建与基本操作1. 创建字符串2. 字符串基本操作二、字符串索引与切片1. 索引(Indexing)2. 切片(Slicing)三、字符串常用方法1. 查找与替换方法2. 大小写转换3. 分割与连接4. 去除空白字符5. 判断方法&#…

作者头像 李华
网站建设 2026/4/23 14:44:51

如何快速掌握数据抓取:同花顺问财Python工具完整指南

如何快速掌握数据抓取:同花顺问财Python工具完整指南 【免费下载链接】pywencai 获取同花顺问财数据 项目地址: https://gitcode.com/gh_mirrors/py/pywencai 想要轻松获取同花顺问财的股票数据吗?pywencai作为一款专业的Python数据抓取工具&…

作者头像 李华