通道注意力:SE-Net#
为什么CNN需要注意力机制? 中我们讨论了注意力的两个维度。先从通道注意力开始——这是最简单、最直观的注意力形式,由 SE-Net(Squeeze-and-Excitation Networks)在 2017 年提出 [HSS18]。
核心思想:三个步骤的直觉#
SE-Net 的核心思想可以用三个词概括:压缩 → 激励 → 缩放。
想想你在管理一个团队:你有很多个员工(通道),每个员工都有不同的专长。但你的资源有限,需要决定谁更重要。
SE-Net的三步直觉
压缩(Squeeze):你让每个员工交一份"工作总结"——用一句话概括他这周做了什么。对应到网络:把每个通道的整个特征图压缩成一个数字(全局平均池化)。
激励(Excite):你看了所有总结,判断谁的工作更重要——给每个员工分配一个"重要性分数"。对应到网络:用两个全连接层学习通道间的依赖关系,输出每个通道的权重。
缩放(Scale):按重要性分配资源——重要员工获得更多支持,不重要员工减少资源。对应到网络:把权重乘回原始特征图,重要通道被增强,不重要通道被抑制。
关键洞察:SE-Net 不改变网络的结构(通道数不变),而是改变"流经每个通道的信息量"——就像调节音量旋钮,而不是换音箱。
三步详解#
1. Squeeze:压缩#
输入是一个特征图 \(X \in \mathbb{R}^{C \times H \times W}\)(\(C\) 个通道,每个 \(H \times W\))。Squeeze 操作把每个通道压缩成单个数字:
这其实就是全局平均池化(Global Average Pooling)。为什么用平均?因为我们想获取每个通道的"全局统计信息"——不只看局部响应,而是看整个特征图的平均激活强度。
备注
直觉:如果一个通道的平均激活值很高,说明该通道编码的特征在整张图片中都很"活跃",可能很重要。反之,平均激活值低的通道可能对当前任务贡献不大。
2. Excitation:学习权重#
有了每个通道的"工作总结"\(z \in \mathbb{R}^C\),接下来要学习通道间的依赖关系:
其中:
\(W_1 \in \mathbb{R}^{C/r \times C}\):降维层,先把 \(C\) 维压缩到 \(C/r\) 维
\(\delta\):ReLU 激活函数
\(W_2 \in \mathbb{R}^{C \times C/r}\):升维层,恢复回 \(C\) 维
\(\sigma\):Sigmoid 激活函数,输出 \((0,1)\) 之间的权重
为什么用瓶颈结构(先降维再升维)?
瓶颈结构的设计动机
减少参数量:直接学习 \(C \times C\) 的变换需要 \(C^2\) 个参数。用瓶颈结构只需要 \(2C^2/r\) 个参数,当 \(r=16\) 时减少了 8 倍。
引入非线性:降维→ReLU→升维的结构比单层线性变换有更强的表达能力。
学习通道间关系:瓶颈迫使信息通过一个低维"瓶颈",这迫使网络学到通道间的紧凑表示。
压缩比 \(r\) 控制瓶颈的宽度:\(r\) 越大参数越少,但表达能力也越弱。通常取 \(r=16\)。
3. Scale:加权#
把学习到的权重 \(s_c \in (0,1)\) 乘回原始特征图的对应通道:
这就是一个逐通道的缩放操作。权重接近 1 的通道被保留甚至增强,权重接近 0 的通道被抑制。
PyTorch 实现#
1import torch
2import torch.nn as nn
3
4class SEBlock(nn.Module):
5 """
6 Squeeze-and-Excitation 模块
7
8 三步核心操作:
9 1. Squeeze: 全局平均池化, 将 C×H×W 压缩为 C×1×1
10 2. Excitation: 两个全连接层学习通道权重 (瓶颈结构)
11 3. Scale: 权重乘回原始特征图
12
13 参数量: 2 * C * (C/r) = 2C²/r
14 输入: (B, C, H, W)
15 输出: (B, C, H, W)
16 """
17 def __init__(self, channels, reduction=16):
18 super(SEBlock, self).__init__()
19
20 # Squeeze: 全局平均池化, 每个通道压缩为一个标量
21 # 输入 (B, C, H, W) → 输出 (B, C, 1, 1)
22 self.avg_pool = nn.AdaptiveAvgPool2d(1)
23
24 # Excitation: 瓶颈结构, 压缩比 r=16 减少参数量
25 # C → C/r → C, 参数: C*(C/r) + (C/r)*C = 2C²/r
26 self.fc = nn.Sequential(
27 nn.Linear(channels, channels // reduction, bias=False), # 降维
28 nn.ReLU(inplace=True), # 非线性
29 nn.Linear(channels // reduction, channels, bias=False), # 升维
30 nn.Sigmoid() # 输出 (0,1) 权重
31 )
32
33 def forward(self, x):
34 b, c, _, _ = x.size()
35
36 # Squeeze: (B, C, H, W) → (B, C, 1, 1) → (B, C)
37 y = self.avg_pool(x).view(b, c)
38
39 # Excitation: (B, C) → (B, C/r) → (B, C) → (B, C, 1, 1)
40 y = self.fc(y).view(b, c, 1, 1)
41
42 # Scale: 逐通道乘法, 广播到 H×W
43 # 每个通道乘以对应的标量权重
44 return x * y.expand_as(x)
45
46if __name__ == "__main__":
47 x = torch.randn(2, 64, 32, 32)
48 se = SEBlock(64, reduction=16)
49 y = se(x)
50 print(f"Input shape: {x.shape}")
51 print(f"Output shape: {y.shape}")
52 print(f"SEBlock parameters: {sum(p.numel() for p in se.parameters())} (2*64²/16 = 512)")
代码要点:
AdaptiveAvgPool2d((1, 1))实现 Squeeze:把任意尺寸的特征图压缩到 \(1 \times 1\)Linear层的输入输出维度由压缩比 \(r\) 控制Sigmoid 保证输出在 \((0,1)\) 范围内
x * self.scale实现逐通道加权
参数量与计算开销#
SE 模块的参数量为:
当 \(C=256, r=16\) 时,参数数量为 \(2 \times 256^2 / 16 = 8,192\)——相对于 ResNet-50 的 2500 万参数可以忽略不计。计算开销增加也小于 1%。
SE-Net的关键贡献
仅增加约 1% 计算量,提升 1-2% 准确率
即插即用:可集成到任何CNN架构中(ResNet、MobileNet、Inception等)
在 ImageNet 上,SE-ResNet-50 达到 77.62% Top-1 准确率,比基线提升 +1.47%
SE-ResNet 集成#
SE 模块通常插入到残差块中、残差连接之前:
1import torch
2import torch.nn as nn
3from .se_block import SEBlock
4
5
6class SEBasicBlock(nn.Module):
7 """
8 带有 SE 模块的 ResNet 基础块
9
10 SE 模块插入在第二个卷积之后、残差连接之前:
11 Conv1 → BN → ReLU → Conv2 → BN → SE → + → ReLU
12 ↑
13 残差连接 ─┘
14
15 输入: (B, C, H, W)
16 输出: (B, planes*expansion, H/stride, W/stride)
17 """
18 expansion = 1
19
20 def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
21 super(SEBasicBlock, self).__init__()
22
23 self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
24 stride=stride, padding=1, bias=False)
25 self.bn1 = nn.BatchNorm2d(planes)
26
27 self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
28 stride=1, padding=1, bias=False)
29 self.bn2 = nn.BatchNorm2d(planes)
30
31 # SE 模块: 在第二个卷积后, 残差连接前
32 self.se = SEBlock(planes, reduction)
33
34 self.downsample = downsample
35 self.stride = stride
36 self.relu = nn.ReLU(inplace=True)
37
38 def forward(self, x):
39 identity = x
40
41 out = self.conv1(x)
42 out = self.bn1(out)
43 out = self.relu(out)
44
45 out = self.conv2(out)
46 out = self.bn2(out)
47
48 # 注意力: 在特征融合前重新校准通道重要性
49 out = self.se(out)
50
51 # 残差连接
52 if self.downsample is not None:
53 identity = self.downsample(x)
54
55 out += identity
56 out = self.relu(out)
57
58 return out
59
60if __name__ == "__main__":
61 x = torch.randn(2, 64, 32, 32)
62 block = SEBasicBlock(64, 64)
63 y = block(x)
64 print(f"Input shape: {x.shape}")
65 print(f"Output shape: {y.shape}")
66 print(f"Block params: {sum(p.numel() for p in block.parameters())}")
本章小结#
SE-Net 通过 Squeeze → Excitation → Scale 三步实现通道注意力
核心是让网络学习"每个通道的重要性权重",然后据此重新校准特征
结构简单、计算开销小、即插即用
下一步#
通道注意力解决了"什么特征重要",但还有一个问题没回答:重要的特征出现在哪里? 空间注意力 我们将学习关注"空间位置"的注意力机制。
参考文献#
Jie Hu, Li Shen, and Gang Sun. Squeeze-and-excitation networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 7132–7141. 2018.