空间注意力#
通道注意力:SE-Net 中,我们学习了通道注意力——让网络关注"什么特征重要"。但还有一个维度没覆盖:特征出现在哪里?一张猫的图片中,猫所在的空间位置比背景更重要。这就是空间注意力要解决的问题。
核心思想#
空间注意力(Spatial Attention)的目标是生成一个空间注意力图——一个与输入特征图同等宽高的二维权重图,每个位置的值表示该位置的重要性。
空间注意力的直觉
通道注意力是"给每个频道调音量",空间注意力则是**“给屏幕的每个区域调亮度”**——重要的区域(如物体位置)更亮,不重要的区域(如背景)更暗。
数学形式#
输入特征图 \(F \in \mathbb{R}^{C \times H \times W}\),空间注意力模块生成注意力图 \(M_s \in \mathbb{R}^{1 \times H \times W}\):
其中 \(\odot\) 表示逐元素乘法(注意力图广播到所有通道)。
关键问题:如何从 \(C\) 个通道的特征图生成 \(1\) 个通道的空间注意力图?答案是沿通道维度聚合信息。
CBAM风格的空间注意力#
最常用的空间注意力方法来自 CBAM [WPLK18],步骤如下:
通道池化:对每个空间位置,分别计算该位置在所有通道上的平均值和最大值,得到两个 \(1 \times H \times W\) 的聚合特征图。
拼接:把两个聚合图拼成 \(2 \times H \times W\)。
卷积:用一个 \(7 \times 7\) 卷积层将 \(2\) 个通道压缩回 \(1\) 个通道。
激活:Sigmoid 将值映射到 \((0,1)\),得到空间注意力图。
为什么同时用平均池化和最大池化?
平均池化捕获"该位置的整体激活强度"——相当于问"这个区域整体上多活跃?“。最大池化捕获"该位置的最强响应”——相当于问"这个区域最突出的特征是什么?"。两者互补,结合起来能更全面地描述每个空间位置的信息量。
1import torch
2import torch.nn as nn
3
4class SpatialAttention(nn.Module):
5 """
6 CBAM 风格的空间注意力模块
7
8 沿通道维度聚合信息, 生成空间注意力图 (1×H×W):
9 1. 通道平均池化 + 通道最大池化 → 2×H×W
10 2. 7×7 卷积压缩为 1×H×W
11 3. Sigmoid 激活
12
13 参数量: 2 * 7 * 7 = 98 (k=7 时)
14 输入: (B, C, H, W)
15 输出: (B, 1, H, W) — 空间注意力图
16 """
17 def __init__(self, kernel_size=7):
18 super(SpatialAttention, self).__init__()
19
20 assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
21 padding = 3 if kernel_size == 7 else 1
22
23 # 输入 2 通道 (avg + max), 输出 1 通道 (注意力图)
24 # 参数量: 2 * 1 * k * k = 2k²
25 self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
26 self.sigmoid = nn.Sigmoid()
27
28 def forward(self, x):
29 # 沿通道维度 (dim=1) 聚合
30 # avg_out: (B, 1, H, W), 每个位置的平均激活强度
31 avg_out = torch.mean(x, dim=1, keepdim=True)
32 # max_out: (B, 1, H, W), 每个位置的最强响应
33 max_out, _ = torch.max(x, dim=1, keepdim=True)
34
35 # 拼接: (B, 2, H, W)
36 x = torch.cat([avg_out, max_out], dim=1)
37 # 卷积压缩 + Sigmoid: (B, 1, H, W)
38 x = self.conv(x)
39 return self.sigmoid(x)
40
41if __name__ == "__main__":
42 x = torch.randn(2, 64, 32, 32)
43 sa = SpatialAttention(kernel_size=7)
44 y = sa(x)
45 print(f"Input shape: {x.shape}")
46 print(f"Attention shape: {y.shape}")
47 print(f"SpatialAttention params: {sum(p.numel() for p in sa.parameters())} (2*7²=98)")
空间注意力的其他形式#
除了 CBAM 风格外,还有其他实现方式:
卷积生成:直接用 \(1 \times 1\) 卷积降维后接 \(3 \times 3\) 卷积生成注意力图。参数量更多,但可学习性更强。
自注意力(Non-local) [WGGH18]:计算所有空间位置两两之间的相似度,生成全局注意力图。计算量大(\(O(H^2W^2)\)),但能捕捉长距离依赖。
对于初学者,建议从 CBAM风格开始——它最简单、最轻量,在大多数任务上效果也足够好。
通道 vs 空间:类比总结#
维度 |
关心的问题 |
输出形状 |
生活类比 |
|---|---|---|---|
通道注意力 |
什么特征重要 |
\(C \times 1 \times 1\) |
选频道 |
空间注意力 |
哪里重要 |
\(1 \times H \times W\) |
调区域亮度 |
两者不冲突,而是互补。下一节我们将学习如何将它们结合起来。
下一步#
通道+空间注意力:CBAM 将展示如何把通道注意力和空间注意力组合成一个更强大的模块。
参考文献#
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. Non-local neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 7794–7803. 2018.
Sanghyun Woo, Jongchan Park, Joon-Young Lee, and In So Kweon. Cbam: convolutional block attention module. In Proceedings of the European Conference on Computer Vision (ECCV), 3–19. 2018.