CBAMConvolutional Block Attention Module

摘要:
注意力机制可分为:通道注意力机制:为通道和分数生成掩码,以senet表示,通道注意力模块空间注意力机制:生成空间和分数掩码,以SpatialAttentionModule混合域注意力机制:对通道注意力和空间注意力进行评估和评分,以BAM、CBAM2表示。论文摘要ECCV2018论文地址:https://arxiv.org/abs/1807.06521本文提出了卷积注意模块,它是一种简单有效的前馈卷积神经网络注意模块。卷积块注意模块代表卷积模块的注意机制模块,它是一个结合了空间和通道的注意力机制模块。

目录

1. 前言

什么是注意力机制?
注意力机制(Attention Mechanism)是机器学习中的一种数据处理方法,广泛应用在自然语言处理、图像识别及语音识别等各种不同类型的机器学习任务中。
通俗来讲:注意力机制就是希望网络能够自动学出来图片或者文字序列中的需要注意的地方。比如人眼在看一幅画的时候,不会将注意力平等地分配给画中的所有像素,而是将更多注意力分配给人们关注的地方。
从实现的角度来讲:注意力机制通过神经网络的操作生成一个掩码mask, mask上的值一个打分,评价当前需要关注的点的评分。
注意力机制可以分为:
通道注意力机制:对通道生成掩码mask,进行打分,代表是senet, Channel Attention Module
空间注意力机制:对空间进行掩码的生成,进行打分,代表是Spatial Attention Module
混合域注意力机制:同时对通道注意力和空间注意力进行评价打分,代表的有BAM, CBAM

2.论文摘要

ECCV2018论文地址: https://arxiv.org/abs/1807.06521
文章提出了卷积注意力模块(CBAM -- Convolutional Block Attention Module ),这是一种用于前馈卷积神经网络的简单而有效的注意力模块。 给定一个中间特征图,CBAM模块会沿着两个独立的维度(通道和空间)依次推断注意力图,然后将注意力图与输入特征图相乘以进行自适应特征优化。 由于CBAM是轻量级的通用模块,因此可以忽略的该模块的开销而将其无缝集成到任何CNN架构中,并且可以与基础CNN一起进行端到端训练。 本文通过在ImageNet-1K,MS COCO检测和VOC 2007检测数据集上进行的广泛实验来验证CBAM。 实验表明,使用该模块在各种模型上,并在分类和检测性能方面的持续改进,证明了CBAM的广泛适用性。

在该论文中,作者研究了网络架构中的注意力,注意力不仅要告诉我们重点关注哪里,还要提高关注点的表示。 目标是通过使用注意机制来增加表现力,关注重要特征并抑制不必要的特征。为了强调空间和通道这两个维度上的有意义特征,作者依次应用通道和空间注意模块,来分别在通道和空间维度上学习关注什么、在哪里关注。此外,通过了解要强调或抑制的信息也有助于网络内的信息流动。
主要网络架构也很简单,一个是通道注意力模块,另一个是空间注意力模块,CBAM就是先后集成了通道注意力模块和空间注意力模块。

Convolutional Block Attention Module (CBAM) 表示卷积模块的注意力机制模块,是一种结合了空间(spatial)和通道(channel)的注意力机制模块。相比于senet只关注通道(channel)的注意力机制可以取得更好的效果。
CBAMConvolutional Block Attention Module第1张

3.通道注意力机制(Channel Attention Module)

CBAMConvolutional Block Attention Module第2张
通道注意力机制是将特征图在空间维度上进行压缩,得到一个一维矢量后再进行操作。在空间维度上进行压缩时,不仅考虑到了平均值池化(Average Pooling)还考虑了最大值池化(Max Pooling)。平均池化和最大池化可用来聚合特征映射的空间信息,送到一个共享网络,压缩输入特征图的空间维数,逐元素求和合并,以产生通道注意力图。单就一张图来说,通道注意力,关注的是这张图上哪些内容是有重要作用的。平均值池化对特征图上的每一个像素点都有反馈,而最大值池化在进行梯度反向传播计算时,只有特征图中响应最大的地方有梯度的反馈。通道注意力机制可以表达为:
CBAMConvolutional Block Attention Module第3张
代码实现:

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

应用代码:

        out = self.conv2(out)
        out = self.bn2(out) #[1,64,56,56]
        ca_tmp = self.ca(out) #[1,64,1,1]
        out = self.ca(out) * out #[1,64,56,56]

4.空间注意力机制(Spatial Attention Module)

CBAMConvolutional Block Attention Module第4张
空间注意力机制是对通道进行压缩,在通道维度分别进行了平均值池化和最大值池化。MaxPool的操作就是在通道上提取最大值,提取的次数是高乘以宽;AvgPool的操作就是在通道上提取平均值,提取的次数也是是高乘以宽;接着将前面所提取到的特征图(通道数都为1)合并得到一个2通道的特征图。
CBAMConvolutional Block Attention Module第5张
其中, CBAMConvolutional Block Attention Module第6张为sigmoid操作,77表示卷积核的大小,77的卷积核比3*3的卷积核效果更好。
代码实现:

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

应用代码:

        out = self.conv2(out)
        out = self.bn2(out) #[1,64,56,56]

        sa_tmp = self.sa(out) #[1,1,56,56]
        out = self.sa(out) * out #[1,64,56,56]

通道注意力和空间注意力这两个模块能够以并行或者顺序的方式组合在一块儿,可是做者发现顺序组合而且将通道注意力放在前面能够取得更好的效果。

5.CBAM与ResNet网络结构组合

CBAMConvolutional Block Attention Module第7张

6.可视化效果图

最后,是使用Grad-cam进行了可视化,以来证明CBAM是真正地提取出了积极有效的特征。
CBAMConvolutional Block Attention Module第8张

7.代码resnet_cbam.py

最关键的部分:

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out) #[1,64,56,56]

        ca_tmp = self.ca(out) #[1,64,1,1]
        sa_tmp = self.sa(out) #[1,1,56,56]

        out = self.ca(out) * out #[1,64,56,56]
        out = self.sa(out) * out #[1,64,56,56]

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

resnet_cbam.py

import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18_cbam', 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam',
           'resnet152_cbam']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out) #[1,64,56,56]

        ca_tmp = self.ca(out) #[1,64,1,1]
        sa_tmp = self.sa(out) #[1,1,56,56]

        out = self.ca(out) * out #[1,64,56,56]
        out = self.sa(out) * out #[1,64,56,56]

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

        self.ca = ChannelAttention(planes * 4)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet34_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet50_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet101_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet152_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model

if __name__ == '__main__':
    net = resnet18_cbam()
    input = torch.randn(1,3,224,224)
    out = net(input)
    print(out.shape)
    a = 0

免责声明:文章转载自《CBAMConvolutional Block Attention Module》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇Numerical AnalysisWebApi 部署后一直返回404的解决办法下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

随便看看

IPI 通信(SMP)【转】

在MIPS架构下的IPI通信被关闭和中断后,IPIMIPS接口结构平台也将被发送_ smp_Ops{void;void;…}IPI通信是多个处理器之间的通信。send_ ipi_Single:一对一聊天send_ ipi_Mask:Mask posting,Mask表示Mask posting/*Octeon Tellanothercore of Lushi...

TCL基本语法2

TCL基本语法21、format和scan两个基本的函数,和C语言中的sprintf和scanf的作用基本相同。format将不同类型的数据压缩在字符串中,scan将字符串中的数据提取出来。setnameJacksetage100setworker[format"%sis%dyearsold"$name$age]puts$workerscan$worker"...

mysql修改字段防止锁表

步骤1:修改大表、addcolumn或dropcolumn的字段,操作完成后将锁定该表。此时,查询ok、insert和update将等待锁定。...

说说接口封装

今天,我为同事封装了一个接口。当谈到接口封装时,有很多关于它的讨论。在很多情况下,说一个服务好,一个服务坏,实际上是在吐槽服务团队之外暴露的界面质量。无论哪种语言,抽象的封装接口都由一个函数名、几个参数和几个返回值组成。总之,参数不应该被封装……我们在内部尝试接口_Catch不会抛出异常,所有信息都将以错误代码的形式返回。就php而言,建议进行异常处理。...

flutter Radio单选框

单选框,允许用户从一组中选择一个选项。...

nginx做本地目录映射

nginx做本地目录映射有时候需要访问服务器上的一些静态资源,比如挂载其他设备上的图片到本地的目录,而本地的目录不在nginx根目录下,这个时候就需要简单的做一下目录映射来解决,比如想通过浏览器http://ip/image/2016/04/29/10/abc.jpg访问到系统目录/image_data/2016/04/29/10/abc.jpg需要在ngi...