news 2026/6/10 1:09:43

对比tensorflow,从0开始学pytorch(五)--CBAM

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
对比tensorflow,从0开始学pytorch(五)--CBAM

CBAM = 通道注意力(两种SENet--GAP+GMP的组合)+空间注意力

CBAM是深度学习里程碑式的产物,但代码非常简单,其实就是一个概念:给模型增加可训练可学习的参数矩阵。

有了SENet的经验,CBAM1个小时就搞定了,很丝滑,pytorch还有有一定优势的,代码写熟了以后可以快速复用。先上CBAM原论文图:

上图是总流程图,原文中做了一堆实验,一堆数据,不用管,记住结论就行:先通道注意,后空间注意,效果最好。其实也很好理解。对于隐层,先挑选出哪些隐层最值得关注(通道注意力);然后再对挑出的隐层内容进行重点内容挑选(空间注意力)。

一、通道注意力

从概念上理解后,就是两个注意力机制逐一实现的问题了。首先看通道注意力机制:

是不是特别熟悉?对比一下SENet的图:

SENet的中间展开,就是Fex(., w)展开:

这玩儿意和CBAM的通道注意力中间的那块不能说是一模一样,简直毫无差别…………

二、空间注意力

依然,线上原图:

思路和通道注意力一样,都是Max+Avg,然后通过sigmoid得到一个可以训练的加权的矩阵,然后这个加权的矩阵再和所有隐层做乘法就行。

三、填坑

写代码的时候,发现网上的参考代码居然有问题(不知道是不是我自己写的问题,但有的是明确有问题的),如下:

1. 网上“空间注意力”的代码写错了,导致百度AI给出的代码也是错的,具体如下:

2. 通道注意力机制,没有做尺度变化

这个问题我不确定,反正按照SENet来写,一定要做尺度变化,不然我这会报错。没直接运行网上的代码,感觉有问题。

3. 通道注意力机制默认的7*7卷积核确定比3*3要好么?

此处做了两个修改,一是将7*7的大卷积核改为了3*3,padding不用去算,默认是3,改为same即可。

因为后续技术发展已证明3*3的卷积核是主流,所以这里还是修改一下为好。

四:结果

的确有点用处,大部分都能到99%,之前都是98.8x%上下,有一点点提升。

附上我修改,并且确定可用的CBAM代码,如下:

import torch import torch.nn as nn import torchsummary class ChannelAttention(nn.Module): def __init__(self, input_channels:int, ratio=4): super().__init__() self.gap = nn.AdaptiveAvgPool2d(1) self.gmp = nn.AdaptiveMaxPool2d(1) self.fc1 = nn.Linear(input_channels, input_channels // ratio) self.fc2 = nn.Linear(input_channels // ratio, input_channels) self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() def forward(self, x): gap_weight = self.gap(x) gap_weight = gap_weight.view(-1, x.shape[1]) gap_weight = self.fc1(gap_weight) gap_weight = self.relu(gap_weight) gap_weight = self.fc2(gap_weight) gmp_weight = self.gmp(x) gmp_weight = gmp_weight.view(-1, x.shape[1]) gmp_weight = self.fc1(gmp_weight) gmp_weight = self.relu(gmp_weight) gmp_weight = self.fc2(gmp_weight) out_put = self.sigmoid(gap_weight + gmp_weight) return out_put class SpatialAttention(nn.Module): def __init__(self): super().__init__() self.conv2d = nn.Conv2d(2,1,3,1,padding="same") self.sigmoid = nn.Sigmoid() def forward(self, x): avg_weight = torch.mean(x, dim=1, keepdim=True) # print(avg_weight.shape) max_weight = torch.max(x, dim=1,keepdim=True)[0] # print(avg_weight.shape) out_put = torch.cat((avg_weight, max_weight), dim=1) # print(out_put.shape) out_put = self.conv2d(out_put) # print(out_put.shape) out_put = self.sigmoid(out_put) return out_put class CBAM(nn.Module): def __init__(self, channels): super().__init__() self.ChannelAttention = ChannelAttention(channels) self.SpatioAttention = SpatialAttention() def forward(self, x): out_put = self.ChannelAttention(x) out_put = out_put.view(out_put.shape[0], out_put.shape[1], 1, 1) out_put = out_put * x out_put = self.SpatioAttention(out_put) * out_put return out_put # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # CBAM = CBAM(28).to(device) # torchsummary.summary(CBAM, input_size=(28,28,28))

用起来就很简单了,任何一个隐层后面都可以直接加入:

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/9 22:08:39

AutoGPT运行资源消耗测试:需要多少GPU显存?

AutoGPT运行资源消耗测试:需要多少GPU显存? 在当前AI技术快速演进的背景下,大型语言模型(LLM)正从被动应答工具向具备自主决策能力的智能体转型。像AutoGPT这样的开源项目,已经能够基于一个简单目标——比如…

作者头像 李华
网站建设 2026/6/9 13:05:24

椭圆曲线的“加法”群规则

这四个式子是在讲椭圆曲线的“加法”群规则(chord-and-tangent)。核心口诀是: 同一条直线与椭圆曲线的三个交点(按重数计算)相加等于 0(单位元) 也就是:若直线与曲线交于 A,B,C,则 A+B+C=0。 这里的 0(图里写 0)指的是无穷远点 O,是加法单位元。 同时,点的相反数是…

作者头像 李华
网站建设 2026/6/9 11:49:26

支持多模型接入的LobeChat,如何实现低成本高回报的Token售卖?

支持多模型接入的LobeChat,如何实现低成本高回报的Token售卖? 在AI应用爆发式增长的今天,越来越多企业开始尝试将大语言模型(LLM)集成到自己的产品中。然而,直接调用闭源API成本高昂,而自建系统…

作者头像 李华
网站建设 2026/6/9 22:17:22

【ROS 2】ROS 2 机器人操作系统简介 ( 概念简介 | DDS 数据分发服务 | ROS 2 版本 | Humble 文档 | ROS 2 生态简介 )

文章目录一、ROS 简介1、概念简介2、通信框架对比选择3、ROS 架构4、DDS 数据分发服务 简介二、ROS 2 版本1、ROS 2 发布版本2、ROS 2 版本文档3、Humble Hawksbill 版本 ROS 2 文档① 文档主页② 安装文档③ 教程文档④ 文档指南⑤ 概念术语三、ROS 2 生态简介1、ROS 2 通信机…

作者头像 李华
网站建设 2026/6/7 22:51:13

网络协议TCP

网络编程TCPTCP的核心特点:面向字节流(UDP是数据报),所有的读写的基本单位都是byteServerSocket:专门给服务器使用的,负责连接,不对数据进行操作Socket:服务器和客户端都可以使用当服…

作者头像 李华
网站建设 2026/6/9 19:37:27

重庆市大学生信息安全竞赛部分writeup

免责声明:本文章发布于比赛正式结束后,不存在提前泄露比赛信息及违规泄露wp的情况,作者不对读者基于本文内容而产生的任何行为或后果承担责任。如有任何侵权问题,请联系作者删除。 WEB5 传一句话木马,dirsearch扫出来…

作者头像 李华