通道+空间注意力:CBAM#

通道注意力:SE-Net 解决了"什么特征重要"(通道维度),空间注意力 解决了"哪里重要"(空间维度)。那么问题来了:能不能两个都要?

CBAM(Convolutional Block Attention Module)[WPLK18] 就是答案——它把通道注意力和空间注意力串联起来,形成一个更强大的注意力模块。

核心思想:先"选频道"再"调区域"#

CBAM 的工作流程很直观:先判断哪些通道重要(通道注意力),再判断这些重要通道中的哪些空间位置最关键(空间注意力)。

CBAM的直觉

想象你在看一张照片找猫:

  1. 通道注意力:首先,你决定看"颜色"这个维度(而不是"纹理"或"亮度"),因为猫的颜色信息最有用。

  2. 空间注意力:然后,你在"颜色"维度下,聚焦到照片中特定的区域——猫所在的位置。

这就是 CBAM 的两步走:先"选对频道",再"看对位置"。

graph LR A[输入特征 F] --> B[通道注意力<br/>Mc] B --> C[中间特征 F'] C --> D[空间注意力<br/>Ms] D --> E[输出特征 F''] B1[Mc = σMLPAvgPoolF + MLPMaxPoolF] -.-> B D1[Ms = σf⁷×⁷AvgPoolF' ; MaxPoolF'] -.-> D

数学形式#

CBAM 的完整计算过程为:

\[F' = M_c(F) \otimes F\]
\[F'' = M_s(F') \otimes F'\]

其中 \(\otimes\) 表示逐元素乘法。

通道注意力部分#

CBAM 的通道注意力与 SE-Net 类似,但增加了一个改进:同时使用平均池化和最大池化

\[M_c(F) = \sigma(\text{MLP}(\text{AvgPool}(F)) + \text{MLP}(\text{MaxPool}(F)))\]
CBAM通道注意力模块实现#
 1import torch
 2import torch.nn as nn
 3
 4class ChannelAttention(nn.Module):
 5    """
 6    CBAM 的通道注意力模块
 7
 8    与 SE-Net 的区别: 同时使用平均池化和最大池化, 然后相加融合
 9    - 平均池化: 捕获全局统计信息
10    - 最大池化: 捕获最显著响应
11    - 两者互补, 提供更丰富的通道描述
12
13    使用 1×1 Conv2d 替代 Linear, 保持 4D 张量格式 (兼容性更好)
14    参数量: 2 * C * (C/r) = 2C²/r
15    输入: (B, C, H, W)
16    输出: (B, C, 1, 1)
17    """
18    def __init__(self, in_channels, reduction=16):
19        super(ChannelAttention, self).__init__()
20
21        self.avg_pool = nn.AdaptiveAvgPool2d(1)
22        self.max_pool = nn.AdaptiveMaxPool2d(1)
23
24        # 共享的 MLP, 用 1×1 Conv2d 实现 (等价于 Linear, 但保持 4D)
25        self.mlp = nn.Sequential(
26            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),  # 降维
27            nn.ReLU(inplace=True),
28            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)   # 升维
29        )
30
31        self.sigmoid = nn.Sigmoid()
32
33    def forward(self, x):
34        # 两种池化分别通过共享 MLP, 然后逐元素相加
35        avg_out = self.mlp(self.avg_pool(x))  # (B, C, 1, 1)
36        max_out = self.mlp(self.max_pool(x))  # (B, C, 1, 1)
37        out = avg_out + max_out                # 融合
38        return self.sigmoid(out)
39
40if __name__ == "__main__":
41    x = torch.randn(2, 64, 32, 32)
42    ca = ChannelAttention(64, reduction=16)
43    y = ca(x)
44    print(f"Input shape: {x.shape}")
45    print(f"Attention shape: {y.shape}")
46    print(f"ChannelAttention params: {sum(p.numel() for p in ca.parameters())}")

空间注意力部分#

空间注意力 中描述的一致,使用通道池化 + \(7 \times 7\) 卷积生成空间注意力图。

CBAM空间注意力模块实现#
 1import torch
 2import torch.nn as nn
 3
 4class SpatialAttention(nn.Module):
 5    """
 6    CBAM 风格的空间注意力模块
 7
 8    沿通道维度聚合信息, 生成空间注意力图 (1×H×W):
 9    1. 通道平均池化 + 通道最大池化 → 2×H×W
10    2. 7×7 卷积压缩为 1×H×W
11    3. Sigmoid 激活
12
13    参数量: 2 * 7 * 7 = 98 (k=7 时)
14    输入: (B, C, H, W)
15    输出: (B, 1, H, W)  — 空间注意力图
16    """
17    def __init__(self, kernel_size=7):
18        super(SpatialAttention, self).__init__()
19
20        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
21        padding = 3 if kernel_size == 7 else 1
22
23        # 输入 2 通道 (avg + max), 输出 1 通道 (注意力图)
24        # 参数量: 2 * 1 * k * k = 2k²
25        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
26        self.sigmoid = nn.Sigmoid()
27
28    def forward(self, x):
29        # 沿通道维度 (dim=1) 聚合
30        # avg_out: (B, 1, H, W), 每个位置的平均激活强度
31        avg_out = torch.mean(x, dim=1, keepdim=True)
32        # max_out: (B, 1, H, W), 每个位置的最强响应
33        max_out, _ = torch.max(x, dim=1, keepdim=True)
34
35        # 拼接: (B, 2, H, W)
36        x = torch.cat([avg_out, max_out], dim=1)
37        # 卷积压缩 + Sigmoid: (B, 1, H, W)
38        x = self.conv(x)
39        return self.sigmoid(x)
40
41if __name__ == "__main__":
42    x = torch.randn(2, 64, 32, 32)
43    sa = SpatialAttention(kernel_size=7)
44    y = sa(x)
45    print(f"Input shape: {x.shape}")
46    print(f"Attention shape: {y.shape}")
47    print(f"SpatialAttention params: {sum(p.numel() for p in sa.parameters())}  (2*7²=98)")

完整CBAM模块#

完整CBAM模块实现#
 1import torch
 2import torch.nn as nn
 3from .channel_attention import ChannelAttention
 4from .spatial_attention import SpatialAttention
 5
 6
 7class CBAM(nn.Module):
 8    """
 9    完整的 CBAM 模块
10
11    先通道注意力 → 再空间注意力, 串联组合:
12    1. 通道注意力: 判断"什么特征重要" (C×1×1 权重)
13    2. 空间注意力: 判断"哪里重要" (1×H×W 权重)
14    
15    输入: (B, C, H, W)
16    输出: (B, C, H, W)
17    """
18    def __init__(self, in_channels, reduction=16, kernel_size=7):
19        super(CBAM, self).__init__()
20
21        self.channel_attention = ChannelAttention(in_channels, reduction)
22        self.spatial_attention = SpatialAttention(kernel_size)
23
24    def forward(self, x):
25        # Step 1: 通道注意力 → 重新校准通道重要性
26        # x * channel_attention(x): 每个通道乘以其重要性权重
27        x = x * self.channel_attention(x)
28
29        # Step 2: 空间注意力 → 聚焦重要空间区域
30        # x * spatial_attention(x): 每个空间位置乘以其重要性权重
31        x = x * self.spatial_attention(x)
32
33        return x
34
35if __name__ == "__main__":
36    x = torch.randn(2, 64, 32, 32)
37    cbam = CBAM(64, reduction=16)
38    y = cbam(x)
39    print(f"Input shape: {x.shape}")
40    print(f"Output shape: {y.shape}")
41    print(f"CBAM params: {sum(p.numel() for p in cbam.parameters())}")
CBAM-ResNet基础块实现#
 1import torch
 2import torch.nn as nn
 3from .cbam import CBAM
 4
 5
 6class CBAMBasicBlock(nn.Module):
 7    """
 8    带有 CBAM 模块的 ResNet 基础块
 9
10    CBAM 插入在第二个卷积之后、残差连接之前:
11    Conv1 → BN → ReLU → Conv2 → BN → CBAM → + → ReLU
12
13    残差连接 ─────────────────────────────┘
14
15    与 SEBasicBlock 的区别: CBAM 同时做通道和空间注意力
16    输入: (B, C, H, W)
17    输出: (B, planes*expansion, H/stride, W/stride)
18    """
19    expansion = 1
20
21    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
22        super(CBAMBasicBlock, self).__init__()
23
24        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
25                               stride=stride, padding=1, bias=False)
26        self.bn1 = nn.BatchNorm2d(planes)
27
28        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
29                               stride=1, padding=1, bias=False)
30        self.bn2 = nn.BatchNorm2d(planes)
31
32        # CBAM: 通道注意力 + 空间注意力
33        self.cbam = CBAM(planes, reduction)
34
35        self.downsample = downsample
36        self.stride = stride
37        self.relu = nn.ReLU(inplace=True)
38
39    def forward(self, x):
40        identity = x
41
42        out = self.conv1(x)
43        out = self.bn1(out)
44        out = self.relu(out)
45
46        out = self.conv2(out)
47        out = self.bn2(out)
48
49        # 同时进行通道和空间注意力重校准
50        out = self.cbam(out)
51
52        if self.downsample is not None:
53            identity = self.downsample(x)
54
55        out += identity
56        out = self.relu(out)
57
58        return out
59
60if __name__ == "__main__":
61    x = torch.randn(2, 64, 32, 32)
62    block = CBAMBasicBlock(64, 64)
63    y = block(x)
64    print(f"Input shape: {x.shape}")
65    print(f"Output shape: {y.shape}")
66    print(f"Block params: {sum(p.numel() for p in block.parameters())}")

SE-Net vs CBAM#

对比项

SE-Net

CBAM

注意力维度

仅通道

通道 + 空间

通道池化方式

仅平均池化

平均 + 最大池化

参数量增加

\(2C^2/r\)

\(2C^2/r + k^2\)

计算开销

~1%

~2%

适用场景

分类任务

检测/分割等需要定位的任务

CBAM 的实验表明 [WPLK18]:在 ImageNet 上,CBAM-ResNet-50 达到 78.49% Top-1 准确率,比基线 ResNet-50(76.15%)提升 +2.34%,比 SE-ResNet-50(77.62%)额外提升 +0.87%

为什么要先通道再空间?

CBAM 的作者实验发现先通道再空间的效果优于先空间再通道或并行。直觉上:先通过通道注意力突出重要的语义特征,再通过空间注意力定位这些特征在图像中的具体位置——这个顺序更符合人的认知习惯:“先知道看什么,再知道看哪里”。

下一步#

理解了 SE-Net、空间注意力和 CBAM 后,你可能想知道:实际项目中应该选哪个? 注意力机制的选择与应用 将通过实验数据帮你做出选择。


参考文献#

[WPLK18] (1,2)

Sanghyun Woo, Jongchan Park, Joon-Young Lee, and In So Kweon. Cbam: convolutional block attention module. In Proceedings of the European Conference on Computer Vision (ECCV), 3–19. 2018.

贡献者与修订历史

查看详细修订记录
  • 2231276 2026-04-28 - Heyan Zhu: feat(attention-mechanisms): restructure and enhance attention mechanisms documentation
  • 0c291d7 2025-12-10 - Heyan Zhu: docs: restructure course materials and add new content