通道+空间注意力:CBAM#
通道注意力:SE-Net 解决了"什么特征重要"(通道维度),空间注意力 解决了"哪里重要"(空间维度)。那么问题来了:能不能两个都要?
CBAM(Convolutional Block Attention Module)[WPLK18] 就是答案——它把通道注意力和空间注意力串联起来,形成一个更强大的注意力模块。
核心思想:先"选频道"再"调区域"#
CBAM 的工作流程很直观:先判断哪些通道重要(通道注意力),再判断这些重要通道中的哪些空间位置最关键(空间注意力)。
CBAM的直觉
想象你在看一张照片找猫:
通道注意力:首先,你决定看"颜色"这个维度(而不是"纹理"或"亮度"),因为猫的颜色信息最有用。
空间注意力:然后,你在"颜色"维度下,聚焦到照片中特定的区域——猫所在的位置。
这就是 CBAM 的两步走:先"选对频道",再"看对位置"。
数学形式#
CBAM 的完整计算过程为:
其中 \(\otimes\) 表示逐元素乘法。
通道注意力部分#
CBAM 的通道注意力与 SE-Net 类似,但增加了一个改进:同时使用平均池化和最大池化:
1import torch
2import torch.nn as nn
3
4class ChannelAttention(nn.Module):
5 """
6 CBAM 的通道注意力模块
7
8 与 SE-Net 的区别: 同时使用平均池化和最大池化, 然后相加融合
9 - 平均池化: 捕获全局统计信息
10 - 最大池化: 捕获最显著响应
11 - 两者互补, 提供更丰富的通道描述
12
13 使用 1×1 Conv2d 替代 Linear, 保持 4D 张量格式 (兼容性更好)
14 参数量: 2 * C * (C/r) = 2C²/r
15 输入: (B, C, H, W)
16 输出: (B, C, 1, 1)
17 """
18 def __init__(self, in_channels, reduction=16):
19 super(ChannelAttention, self).__init__()
20
21 self.avg_pool = nn.AdaptiveAvgPool2d(1)
22 self.max_pool = nn.AdaptiveMaxPool2d(1)
23
24 # 共享的 MLP, 用 1×1 Conv2d 实现 (等价于 Linear, 但保持 4D)
25 self.mlp = nn.Sequential(
26 nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False), # 降维
27 nn.ReLU(inplace=True),
28 nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False) # 升维
29 )
30
31 self.sigmoid = nn.Sigmoid()
32
33 def forward(self, x):
34 # 两种池化分别通过共享 MLP, 然后逐元素相加
35 avg_out = self.mlp(self.avg_pool(x)) # (B, C, 1, 1)
36 max_out = self.mlp(self.max_pool(x)) # (B, C, 1, 1)
37 out = avg_out + max_out # 融合
38 return self.sigmoid(out)
39
40if __name__ == "__main__":
41 x = torch.randn(2, 64, 32, 32)
42 ca = ChannelAttention(64, reduction=16)
43 y = ca(x)
44 print(f"Input shape: {x.shape}")
45 print(f"Attention shape: {y.shape}")
46 print(f"ChannelAttention params: {sum(p.numel() for p in ca.parameters())}")
空间注意力部分#
与 空间注意力 中描述的一致,使用通道池化 + \(7 \times 7\) 卷积生成空间注意力图。
1import torch
2import torch.nn as nn
3
4class SpatialAttention(nn.Module):
5 """
6 CBAM 风格的空间注意力模块
7
8 沿通道维度聚合信息, 生成空间注意力图 (1×H×W):
9 1. 通道平均池化 + 通道最大池化 → 2×H×W
10 2. 7×7 卷积压缩为 1×H×W
11 3. Sigmoid 激活
12
13 参数量: 2 * 7 * 7 = 98 (k=7 时)
14 输入: (B, C, H, W)
15 输出: (B, 1, H, W) — 空间注意力图
16 """
17 def __init__(self, kernel_size=7):
18 super(SpatialAttention, self).__init__()
19
20 assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
21 padding = 3 if kernel_size == 7 else 1
22
23 # 输入 2 通道 (avg + max), 输出 1 通道 (注意力图)
24 # 参数量: 2 * 1 * k * k = 2k²
25 self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
26 self.sigmoid = nn.Sigmoid()
27
28 def forward(self, x):
29 # 沿通道维度 (dim=1) 聚合
30 # avg_out: (B, 1, H, W), 每个位置的平均激活强度
31 avg_out = torch.mean(x, dim=1, keepdim=True)
32 # max_out: (B, 1, H, W), 每个位置的最强响应
33 max_out, _ = torch.max(x, dim=1, keepdim=True)
34
35 # 拼接: (B, 2, H, W)
36 x = torch.cat([avg_out, max_out], dim=1)
37 # 卷积压缩 + Sigmoid: (B, 1, H, W)
38 x = self.conv(x)
39 return self.sigmoid(x)
40
41if __name__ == "__main__":
42 x = torch.randn(2, 64, 32, 32)
43 sa = SpatialAttention(kernel_size=7)
44 y = sa(x)
45 print(f"Input shape: {x.shape}")
46 print(f"Attention shape: {y.shape}")
47 print(f"SpatialAttention params: {sum(p.numel() for p in sa.parameters())} (2*7²=98)")
完整CBAM模块#
1import torch
2import torch.nn as nn
3from .channel_attention import ChannelAttention
4from .spatial_attention import SpatialAttention
5
6
7class CBAM(nn.Module):
8 """
9 完整的 CBAM 模块
10
11 先通道注意力 → 再空间注意力, 串联组合:
12 1. 通道注意力: 判断"什么特征重要" (C×1×1 权重)
13 2. 空间注意力: 判断"哪里重要" (1×H×W 权重)
14
15 输入: (B, C, H, W)
16 输出: (B, C, H, W)
17 """
18 def __init__(self, in_channels, reduction=16, kernel_size=7):
19 super(CBAM, self).__init__()
20
21 self.channel_attention = ChannelAttention(in_channels, reduction)
22 self.spatial_attention = SpatialAttention(kernel_size)
23
24 def forward(self, x):
25 # Step 1: 通道注意力 → 重新校准通道重要性
26 # x * channel_attention(x): 每个通道乘以其重要性权重
27 x = x * self.channel_attention(x)
28
29 # Step 2: 空间注意力 → 聚焦重要空间区域
30 # x * spatial_attention(x): 每个空间位置乘以其重要性权重
31 x = x * self.spatial_attention(x)
32
33 return x
34
35if __name__ == "__main__":
36 x = torch.randn(2, 64, 32, 32)
37 cbam = CBAM(64, reduction=16)
38 y = cbam(x)
39 print(f"Input shape: {x.shape}")
40 print(f"Output shape: {y.shape}")
41 print(f"CBAM params: {sum(p.numel() for p in cbam.parameters())}")
1import torch
2import torch.nn as nn
3from .cbam import CBAM
4
5
6class CBAMBasicBlock(nn.Module):
7 """
8 带有 CBAM 模块的 ResNet 基础块
9
10 CBAM 插入在第二个卷积之后、残差连接之前:
11 Conv1 → BN → ReLU → Conv2 → BN → CBAM → + → ReLU
12 ↑
13 残差连接 ─────────────────────────────┘
14
15 与 SEBasicBlock 的区别: CBAM 同时做通道和空间注意力
16 输入: (B, C, H, W)
17 输出: (B, planes*expansion, H/stride, W/stride)
18 """
19 expansion = 1
20
21 def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
22 super(CBAMBasicBlock, self).__init__()
23
24 self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
25 stride=stride, padding=1, bias=False)
26 self.bn1 = nn.BatchNorm2d(planes)
27
28 self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
29 stride=1, padding=1, bias=False)
30 self.bn2 = nn.BatchNorm2d(planes)
31
32 # CBAM: 通道注意力 + 空间注意力
33 self.cbam = CBAM(planes, reduction)
34
35 self.downsample = downsample
36 self.stride = stride
37 self.relu = nn.ReLU(inplace=True)
38
39 def forward(self, x):
40 identity = x
41
42 out = self.conv1(x)
43 out = self.bn1(out)
44 out = self.relu(out)
45
46 out = self.conv2(out)
47 out = self.bn2(out)
48
49 # 同时进行通道和空间注意力重校准
50 out = self.cbam(out)
51
52 if self.downsample is not None:
53 identity = self.downsample(x)
54
55 out += identity
56 out = self.relu(out)
57
58 return out
59
60if __name__ == "__main__":
61 x = torch.randn(2, 64, 32, 32)
62 block = CBAMBasicBlock(64, 64)
63 y = block(x)
64 print(f"Input shape: {x.shape}")
65 print(f"Output shape: {y.shape}")
66 print(f"Block params: {sum(p.numel() for p in block.parameters())}")
SE-Net vs CBAM#
对比项 |
SE-Net |
CBAM |
|---|---|---|
注意力维度 |
仅通道 |
通道 + 空间 |
通道池化方式 |
仅平均池化 |
平均 + 最大池化 |
参数量增加 |
\(2C^2/r\) |
\(2C^2/r + k^2\) |
计算开销 |
~1% |
~2% |
适用场景 |
分类任务 |
检测/分割等需要定位的任务 |
CBAM 的实验表明 [WPLK18]:在 ImageNet 上,CBAM-ResNet-50 达到 78.49% Top-1 准确率,比基线 ResNet-50(76.15%)提升 +2.34%,比 SE-ResNet-50(77.62%)额外提升 +0.87%。
为什么要先通道再空间?
CBAM 的作者实验发现先通道再空间的效果优于先空间再通道或并行。直觉上:先通过通道注意力突出重要的语义特征,再通过空间注意力定位这些特征在图像中的具体位置——这个顺序更符合人的认知习惯:“先知道看什么,再知道看哪里”。
下一步#
理解了 SE-Net、空间注意力和 CBAM 后,你可能想知道:实际项目中应该选哪个? 注意力机制的选择与应用 将通过实验数据帮你做出选择。