通道注意力:SE-Net#

为什么CNN需要注意力机制? 中我们讨论了注意力的两个维度。先从通道注意力开始——这是最简单、最直观的注意力形式,由 SE-Net(Squeeze-and-Excitation Networks)在 2017 年提出 [HSS18]

核心思想:三个步骤的直觉#

SE-Net 的核心思想可以用三个词概括:压缩 → 激励 → 缩放

想想你在管理一个团队:你有很多个员工(通道),每个员工都有不同的专长。但你的资源有限,需要决定谁更重要。

SE-Net的三步直觉

  1. 压缩(Squeeze):你让每个员工交一份"工作总结"——用一句话概括他这周做了什么。对应到网络:把每个通道的整个特征图压缩成一个数字(全局平均池化)。

  2. 激励(Excite):你看了所有总结,判断谁的工作更重要——给每个员工分配一个"重要性分数"。对应到网络:用两个全连接层学习通道间的依赖关系,输出每个通道的权重。

  3. 缩放(Scale):按重要性分配资源——重要员工获得更多支持,不重要员工减少资源。对应到网络:把权重乘回原始特征图,重要通道被增强,不重要通道被抑制。

关键洞察:SE-Net 不改变网络的结构(通道数不变),而是改变"流经每个通道的信息量"——就像调节音量旋钮,而不是换音箱。

三步详解#

1. Squeeze:压缩#

输入是一个特征图 \(X \in \mathbb{R}^{C \times H \times W}\)\(C\) 个通道,每个 \(H \times W\))。Squeeze 操作把每个通道压缩成单个数字:

\[z_c = \frac{1}{H \times W} \sum_{i=1}^H \sum_{j=1}^W x_c(i,j)\]

这其实就是全局平均池化(Global Average Pooling)。为什么用平均?因为我们想获取每个通道的"全局统计信息"——不只看局部响应,而是看整个特征图的平均激活强度。

备注

直觉:如果一个通道的平均激活值很高,说明该通道编码的特征在整张图片中都很"活跃",可能很重要。反之,平均激活值低的通道可能对当前任务贡献不大。

2. Excitation:学习权重#

有了每个通道的"工作总结"\(z \in \mathbb{R}^C\),接下来要学习通道间的依赖关系:

\[s = \sigma(W_2 \cdot \delta(W_1 \cdot z))\]

其中:

  • \(W_1 \in \mathbb{R}^{C/r \times C}\):降维层,先把 \(C\) 维压缩到 \(C/r\)

  • \(\delta\):ReLU 激活函数

  • \(W_2 \in \mathbb{R}^{C \times C/r}\):升维层,恢复回 \(C\)

  • \(\sigma\):Sigmoid 激活函数,输出 \((0,1)\) 之间的权重

为什么用瓶颈结构(先降维再升维)?

瓶颈结构的设计动机

  1. 减少参数量:直接学习 \(C \times C\) 的变换需要 \(C^2\) 个参数。用瓶颈结构只需要 \(2C^2/r\) 个参数,当 \(r=16\) 时减少了 8 倍。

  2. 引入非线性:降维→ReLU→升维的结构比单层线性变换有更强的表达能力。

  3. 学习通道间关系:瓶颈迫使信息通过一个低维"瓶颈",这迫使网络学到通道间的紧凑表示。

压缩比 \(r\) 控制瓶颈的宽度:\(r\) 越大参数越少,但表达能力也越弱。通常取 \(r=16\)

3. Scale:加权#

把学习到的权重 \(s_c \in (0,1)\) 乘回原始特征图的对应通道:

\[\tilde{x}_c = s_c \cdot x_c\]

这就是一个逐通道的缩放操作。权重接近 1 的通道被保留甚至增强,权重接近 0 的通道被抑制。

PyTorch 实现#

SE模块的PyTorch实现#
 1import torch
 2import torch.nn as nn
 3
 4class SEBlock(nn.Module):
 5    """
 6    Squeeze-and-Excitation 模块
 7
 8    三步核心操作:
 9    1. Squeeze: 全局平均池化, 将 C×H×W 压缩为 C×1×1
10    2. Excitation: 两个全连接层学习通道权重 (瓶颈结构)
11    3. Scale: 权重乘回原始特征图
12
13    参数量: 2 * C * (C/r) = 2C²/r
14    输入: (B, C, H, W)
15    输出: (B, C, H, W)
16    """
17    def __init__(self, channels, reduction=16):
18        super(SEBlock, self).__init__()
19
20        # Squeeze: 全局平均池化, 每个通道压缩为一个标量
21        # 输入 (B, C, H, W) → 输出 (B, C, 1, 1)
22        self.avg_pool = nn.AdaptiveAvgPool2d(1)
23
24        # Excitation: 瓶颈结构, 压缩比 r=16 减少参数量
25        # C → C/r → C, 参数: C*(C/r) + (C/r)*C = 2C²/r
26        self.fc = nn.Sequential(
27            nn.Linear(channels, channels // reduction, bias=False),  # 降维
28            nn.ReLU(inplace=True),                                   # 非线性
29            nn.Linear(channels // reduction, channels, bias=False),  # 升维
30            nn.Sigmoid()                                             # 输出 (0,1) 权重
31        )
32
33    def forward(self, x):
34        b, c, _, _ = x.size()
35
36        # Squeeze: (B, C, H, W) → (B, C, 1, 1) → (B, C)
37        y = self.avg_pool(x).view(b, c)
38
39        # Excitation: (B, C) → (B, C/r) → (B, C) → (B, C, 1, 1)
40        y = self.fc(y).view(b, c, 1, 1)
41
42        # Scale: 逐通道乘法, 广播到 H×W
43        # 每个通道乘以对应的标量权重
44        return x * y.expand_as(x)
45
46if __name__ == "__main__":
47    x = torch.randn(2, 64, 32, 32)
48    se = SEBlock(64, reduction=16)
49    y = se(x)
50    print(f"Input shape: {x.shape}")
51    print(f"Output shape: {y.shape}")
52    print(f"SEBlock parameters: {sum(p.numel() for p in se.parameters())}  (2*64²/16 = 512)")

代码要点

  • AdaptiveAvgPool2d((1, 1)) 实现 Squeeze:把任意尺寸的特征图压缩到 \(1 \times 1\)

  • Linear 层的输入输出维度由压缩比 \(r\) 控制

  • Sigmoid 保证输出在 \((0,1)\) 范围内

  • x * self.scale 实现逐通道加权

参数量与计算开销#

SE 模块的参数量为:

\[\text{Params} = \frac{C}{r} \times C + C \times \frac{C}{r} = \frac{2C^2}{r}\]

\(C=256, r=16\) 时,参数数量为 \(2 \times 256^2 / 16 = 8,192\)——相对于 ResNet-50 的 2500 万参数可以忽略不计。计算开销增加也小于 1%。

SE-Net的关键贡献

  • 仅增加约 1% 计算量,提升 1-2% 准确率

  • 即插即用:可集成到任何CNN架构中(ResNet、MobileNet、Inception等)

  • 在 ImageNet 上,SE-ResNet-50 达到 77.62% Top-1 准确率,比基线提升 +1.47%

SE-ResNet 集成#

SE 模块通常插入到残差块中、残差连接之前:

SE-ResNet基础块实现#
 1import torch
 2import torch.nn as nn
 3from .se_block import SEBlock
 4
 5
 6class SEBasicBlock(nn.Module):
 7    """
 8    带有 SE 模块的 ResNet 基础块
 9
10    SE 模块插入在第二个卷积之后、残差连接之前:
11    Conv1 → BN → ReLU → Conv2 → BN → SE → + → ReLU
12
13                             残差连接 ─┘
14
15    输入: (B, C, H, W)
16    输出: (B, planes*expansion, H/stride, W/stride)
17    """
18    expansion = 1
19
20    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
21        super(SEBasicBlock, self).__init__()
22
23        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
24                               stride=stride, padding=1, bias=False)
25        self.bn1 = nn.BatchNorm2d(planes)
26
27        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
28                               stride=1, padding=1, bias=False)
29        self.bn2 = nn.BatchNorm2d(planes)
30
31        # SE 模块: 在第二个卷积后, 残差连接前
32        self.se = SEBlock(planes, reduction)
33
34        self.downsample = downsample
35        self.stride = stride
36        self.relu = nn.ReLU(inplace=True)
37
38    def forward(self, x):
39        identity = x
40
41        out = self.conv1(x)
42        out = self.bn1(out)
43        out = self.relu(out)
44
45        out = self.conv2(out)
46        out = self.bn2(out)
47
48        # 注意力: 在特征融合前重新校准通道重要性
49        out = self.se(out)
50
51        # 残差连接
52        if self.downsample is not None:
53            identity = self.downsample(x)
54
55        out += identity
56        out = self.relu(out)
57
58        return out
59
60if __name__ == "__main__":
61    x = torch.randn(2, 64, 32, 32)
62    block = SEBasicBlock(64, 64)
63    y = block(x)
64    print(f"Input shape: {x.shape}")
65    print(f"Output shape: {y.shape}")
66    print(f"Block params: {sum(p.numel() for p in block.parameters())}")

本章小结#

  • SE-Net 通过 Squeeze → Excitation → Scale 三步实现通道注意力

  • 核心是让网络学习"每个通道的重要性权重",然后据此重新校准特征

  • 结构简单、计算开销小、即插即用

下一步#

通道注意力解决了"什么特征重要",但还有一个问题没回答:重要的特征出现在哪里? 空间注意力 我们将学习关注"空间位置"的注意力机制。


参考文献#

[HSS18]

Jie Hu, Li Shen, and Gang Sun. Squeeze-and-excitation networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 7132–7141. 2018.

贡献者与修订历史

查看详细修订记录
  • 2231276 2026-04-28 - Heyan Zhu: feat(attention-mechanisms): restructure and enhance attention mechanisms documentation
  • 0c291d7 2025-12-10 - Heyan Zhu: docs: restructure course materials and add new content