基于Paddle2.0的注意力卷积网络CBAM和BAM

想给卷积网络添加注意力机制吗?是否已经厌倦了使用SE-NET?本项目使用Paddle2.0复现了含有注意力机制的卷积网络CBAM和BAM,并在动物分类数据集上进行了训练和验证。

项目背景
CBAM是2018年ECCV上的一篇论文CBAM: Convolutional Block Attention Module中提出的基于注意力机制的卷积网络模型。BAM是2018年BMVC上的一篇论文BAM: Bottleneck Attention Module中提出的基于注意力机制的网络模型。本项目即对其进行复现。
计算机视觉领域的注意力机制主要涵盖空间注意力和通道注意力两个方面。其中空间注意力用来捕获像素间的关系,而通道注意力用来捕获通道间的关系。CBAM提出了Convolutional Block Attention Module(CBAM)模块,该模块从空间注意力和通道注意力两个方面生成注意力特征图,然后将注意力特征图和输入进行相乘来调节注意力特征图的参数。本项目复现CBAM和BAM并用其来完成动物图像分类的实验。项目简介
本项目首次使用paddle2.0复现了含有注意力机制的网络CBAM和BAM,并在动物数据集上进行了训练和验证。
动物数据集的划分是按8:2的的划分方法进行训练集与验证集划分的。
模型简介
CBAM网络的核心思想是提出了CBAM模块。该模块对输入先经过通道注意力模块,和输入相乘后再经过空间注意力模块,和输入再次相乘后得到调整参数的注意力特征图。如图1所示。

图1 CBAM模块细节示意图
BAM网络的核心思想是提出了BAM模块。BAM可以认为是并行版的CBAM。如图2所示。

图2 BAM模块细节示意图
具体实现可以fork后见代码细节。
论文原文:CBAM: Convolutional Block Attention Module
参考代码:
PyTorch的实现
数据集介绍
本项目使用10分类的动物数据集进行训练和测试.
该十分类动物数据集,包含dog,horse,elephant,butterfly,chicken,cat,cow,sheep,spider和squirrel。每一分类的图片数量为2k-5k。
文件结构
解压数据集
In [1]!unzip -q data/data70196/animals.zip -d work/dataset登录后复制
查看图片
In [ ]import osimport randomfrom matplotlib import pyplot as pltfrom PIL import Imageimgs = []paths = os.listdir('work/dataset')for path in paths: img_path = os.path.join('work/dataset', path) if os.path.isdir(img_path): img_paths = os.listdir(img_path) img = Image.open(os.path.join(img_path, random.choice(img_paths))) imgs.append((img, path))f, ax = plt.subplots(3, 3, figsize=(12,12))for i, img in enumerate(imgs[:9]): ax[i//3, i%3].imshow(img[0]) ax[i//3, i%3].axis('off') ax[i//3, i%3].set_title('label: %s' % img[1])plt.show()登录后复制登录后复制
划分训练集和验证集
In [2]!python code/train_val_split.py登录后复制
finished train val split!登录后复制
使用CBAM-ResNet50网络进行动物分类的训练并验证
训练
In [1]!python code/train.py --net 'cbam_resnet'登录后复制
验证
In [32]!python code/eval.py --net 'cbam_resnet'登录后复制
W0218 15:48:02.818117 23045 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1W0218 15:48:02.824904 23045 device_context.cc:372] device: 0, cuDNN Version: 7.6.Eval begin...The loss value printed in the log is the current batch, and the metric is the average value of previous step.step 103/103 [==============================] - loss: 0.3478 - acc: 0.8544 - 232ms/step Eval samples: 3276{'loss': [0.347824], 'acc': 0.8543956043956044}登录后复制使用BAM-ResNet50网络进行动物分类的训练并验证
训练
In [31]!python code/train.py --net 'bam_resnet'登录后复制
W0218 15:48:47.528769 23145 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1W0218 15:48:47.532490 23145 device_context.cc:372] device: 0, cuDNN Version: 7.6.The loss value printed in the log is the current step, and the metric is the average value of previous step.Epoch 1/50登录后复制
验证
In [34]!python code/eval.py --net 'bam_resnet'登录后复制
W0218 19:49:38.340137 5185 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1W0218 19:49:38.343930 5185 device_context.cc:372] device: 0, cuDNN Version: 7.6.Eval begin...The loss value printed in the log is the current batch, and the metric is the average value of previous step.step 103/103 [==============================] - loss: 0.2684 - acc: 0.8504 - 199ms/step Eval samples: 3276{'loss': [0.2684111], 'acc': 0.8504273504273504}登录后复制图示训练验证过程

图3. 使用CBAM和BAM的训练验证图示
使用resnet50网络进行动物分类的训练并验证
训练
In [2]!python code/train.py --net 'resnet'登录后复制
验证
In [ ]!python code/eval.py --net 'resnet'登录后复制
W0213 21:34:50.038996 12684 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1W0213 21:34:50.043457 12684 device_context.cc:372] device: 0, cuDNN Version: 7.6.Eval begin...The loss value printed in the log is the current batch, and the metric is the average value of previous step.step 103/103 [==============================] - loss: 1.4232 - acc: 0.5888 - 191ms/step Eval samples: 3276{'loss': [1.4232028], 'acc': 0.5888278388278388}登录后复制图示训练验证过程

图4. 使用ResNet的训练验证图示
比较

图5. 使用CBAM、BAM和ResNet的验证比较图示
免责声明
游乐网为非赢利性网站,所展示的游戏/软件/文章内容均来自于互联网或第三方用户上传分享,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系youleyoucom@outlook.com。
同类文章
谷歌AI转型实绩:传统业务焕新,为互联网企业带来增长新路径
谷歌母公司Alphabet近日公布了最新季度财报,总营收达1023 46亿美元,同比增长16%,超出华尔街预期超20亿美元。分业务线来看,各板块表现均优于市场预期,摊薄每股收益达2 87美元,盘后股
马斯克“硬刚”维基百科:人类知识运营的深层矛盾解析
埃隆·马斯克近期对维基百科发起多轮公开批评,并推出由人工智能驱动的在线百科项目GrokiPedia,引发两大知识平台的隔空交锋。面对科技巨头的挑战,维基百科在最新募捐公告中以独特方式作出回应,强调其
黄仁勋、周鸿祎共论AI:是伙伴非工具,推动经济与个体升级
在近期科技界关于人工智能(AI)发展的讨论中,两位科技行业领军人物对AI本质的认知出现了高度契合的观点。英伟达创始人黄仁勋与360集团创始人周鸿祎不约而同地提出,AI不应被简单定义为技术工具,而应被
谷歌CEO:全力押注生成式AI,Gemini下载量突破65亿次
在最新公布的季度财报中,科技巨头Alphabet交出了一份亮眼成绩单,公司第三季度营收成功突破千亿美元大关。在随后召开的财报电话会议上,首席执行官桑达尔·皮查伊着重阐述了公司对生成式人工智能的战略布
环球音乐与Udio和解:版权纠纷落幕,2026年推AI音乐平台
环球音乐集团(UMG)与人工智能音乐创作平台Udio近日宣布达成一项具有开创性的战略合作协议,这一举措在音乐行业引发广泛关注。此前,双方曾因版权问题陷入法律纠纷,此次合作不仅化解了矛盾,更开启了音乐
相关攻略
热门教程
更多- 游戏攻略
- 安卓教程
- 苹果教程
- 电脑教程








