引言
Swin Transformer(Shifted Windows Transformer)1 是微软亚洲研究院提出的层次化视觉 Transformer,通过移位窗口注意力机制有效解决了原始 ViT 的两大痛点:计算复杂度随图像尺寸平方增长、缺乏多尺度特征建模能力。
Swin Transformer 在图像分类、目标检测、语义分割等多个视觉任务上取得了 SOTA 性能,成为视觉 Transformer 领域最具影响力的架构之一。
原始 ViT 的局限性
计算复杂度问题
原始 ViT 的全局自注意力计算复杂度为 ,其中 是 Patch 数量:
| 图像尺寸 | Patch 数量 | 全局注意力 FLOPs |
|---|---|---|
| 224×224 | 196 | 轻量 |
| 384×384 | 576 | 3.4× |
| 512×512 | 1024 | 27× |
| 1024×1024 | 4096 | 436× |
对于高分辨率图像(如检测、分割任务),计算成本急剧膨胀。
缺乏多尺度特征
ViT 使用单一尺度的 Patch Embedding,难以有效建模不同尺度的物体:
ViT: 单一尺度 (所有 token 同等对待)
↓
196 个 16×16 的 Patch
Swin: 层次化 (多尺度特征金字塔)
↓
Stage 1: 56×56 个细粒度 Patch
Stage 2: 28×28 个聚合 Patch
Stage 3: 14×14 个粗粒度 Patch
Stage 4: 7×7 个高层语义 Patch
Swin Transformer 的核心设计
整体架构
Swin Transformer 采用层次化设计,模仿 CNN 的特征金字塔结构:
Input Image (224×224×3)
│
▼
┌───────────────────┐
│ Patch Partition │ 划分为 4×4 Patches
│ (4×4 conv) │ → 56×56 tokens, dim=128
└───────────────────┘
│
▼
┌───────────────────┐
│ Stage 1 │
│ ×2 Swin Blocks │ W-MSA (窗口注意力)
│ Patch Merging │ → 28×28 tokens, dim=256
└───────────────────┘
│
▼
┌───────────────────┐
│ Stage 2 │
│ ×2 Swin Blocks │ W-MSA → SW-MSA (移位窗口)
│ Patch Merging │ → 14×14 tokens, dim=512
└───────────────────┘
│
▼
┌───────────────────┐
│ Stage 3 │
│ ×6 Swin Blocks │ W-MSA → SW-MSA
│ Patch Merging │ → 7×7 tokens, dim=1024
└───────────────────┘
│
▼
┌───────────────────┐
│ Stage 4 │
│ ×2 Swin Blocks │ W-MSA → SW-MSA
└───────────────────┘
│
▼
┌───────────────────┐
│ Global Avg Pool │ 输出多尺度特征
│ + Classification │
└───────────────────┘
配置变体
| 配置 | Stage 通道数 | Stage 层数 | 窗口大小 | 参数量 |
|---|---|---|---|---|
| Swin-T | C:96, 192, 384, 768 | 2, 2, 6, 2 | 7 | 28M |
| Swin-S | C:96, 192, 384, 768 | 2, 2, 18, 2 | 7 | 50M |
| Swin-B | C:128, 256, 512, 1024 | 2, 2, 18, 2 | 7 | 88M |
| Swin-L | C:128, 256, 512, 1024 | 2, 2, 18, 2 | 7 | 196M |
窗口注意力机制(W-MSA)
核心思想
Swin 的核心创新是将全局注意力限制在非重叠的局部窗口内:
┌─────────────────────────┐
│ M×M 个 Patch 的局部窗口 │
│ │
│ ┌───┬───┬───┬───┐ │
│ │ w1│ w2│ w3│ w4│ │
│ ├───┼───┼───┼───┤ │
│ │ w5│ w6│ w7│ w8│ │
│ ├───┼───┼───┼───┤ │
│ │ w9│w10│w11│w12│ │
│ ├───┼───┼───┼───┤ │
│ │w13│w14│w15│w16│ │
│ └───┴───┴───┴───┘ │
│ │
│ 每个窗口独立计算注意力 │
└─────────────────────────┘
计算复杂度对比
| 机制 | 计算复杂度 | 图像尺寸 224×224 |
|---|---|---|
| 全局注意力 | ||
| 窗口注意力 |
其中 是总 Patch 数, 是窗口大小(通常 ), 是特征维度。
复杂度分析:
窗口注意力将复杂度从 降低到 。
移位窗口注意力(SW-MSA)
问题:窗口间缺乏信息交互
固定窗口导致相邻窗口之间没有信息流动,限制了模型的感受野:
W-MSA 问题:
┌───┬───┬───┬───┐
│ 1 │ 2 │ 3 │ 4 │ 窗口 1 和 2 完全独立
├───┼───┼───┼───┤ 边界处无信息交换
│ 5 │ 6 │ 7 │ 8 │
├───┼───┼───┼───┤
│ 9 │10 │11 │12 │
├───┼───┼───┼───┤
│13 │14 │15 │16 │
└───┴───┴───┴───┘
解决方案:交替使用移位窗口
Swin 在连续的 Transformer Block 中交替使用 W-MSA 和 SW-MSA:
Block l: W-MSA (常规窗口)
Block l+1: SW-MSA (移位窗口)
Block l+2: W-MSA (常规窗口)
...
移位机制详解
Step 1: 循环移位(Cyclic Shift)
将特征图沿左上方向循环移位 :
原始网格:
┌───┬───┬───┬───┐
│ 1 │ 2 │ 3 │ 4 │
├───┼───┼───┼───┤
│ 5 │ 6 │ 7 │ 8 │
├───┼───┼───┼───┤
│ 9 │10 │11 │12 │
├───┼───┼───┼───┤
│13 │14 │15 │16 │
└───┴───┴───┴───┘
循环移位后 (向右下移 1):
┌───┬───┬───┬───┐
│16 │13 │14 │15 │
├───┼───┼───┼───┤
│ 4 │ 1 │ 2 │ 3 │
├───┼───┼───┼───┤
│ 8 │ 5 │ 6 │ 7 │
├───┼───┼───┼───┤
│12 │ 9 │10 │11 │
└───┴───┴───┴───┘
Step 2: 窗口分区
移位后重新划分窗口,此时:
- 窗口 A 包含来自原图多个角落的 Patch
- 实现了跨区域的信息交互
Step 3: 注意力掩码(Masking)
为避免不同区域Patch之间不应有的注意力,使用掩码机制:
# 注意力掩码示例
# 某些位置组合的注意力被mask为 -100,使其 softmax 后接近 0
attn_mask = torch.zeros((num_windows, M*M, M*M))
# 掩码某些不应当交互的位置
attn_mask[batch_idx][i][j] = -100.0 # mask 掉Step 4: 逆移位
计算完成后,将结果逆移到原始位置。
PyTorch 实现
窗口注意力
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class WindowAttention(nn.Module):
"""窗口多头自注意力"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # (M, M)
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# QKV 投影
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
# 相对位置偏置
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
)
# 获取相对位置索引
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(self, x, mask=None):
B_, N, C = x.shape # B_: batch * num_windows
# QKV 投影
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# 缩放
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
# 添加相对位置偏置
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0)
# 掩码处理
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = F.softmax(attn, dim=-1)
else:
attn = F.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
# 输出投影
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return xSwin Transformer Block
class SwinTransformerBlock(nn.Module):
"""Swin Transformer Block with W-MSA or SW-MSA"""
def __init__(self, dim, input_resolution, num_heads, window_size=7,
shift_size=0, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size
self.norm1 = nn.LayerNorm(dim)
self.attn = WindowAttention(dim, (window_size, window_size), num_heads, qkv_bias, attn_drop)
self.drop_path = DropPath(drop) if drop > 0. else nn.Identity()
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(dim, hidden_features=mlp_hidden_dim, drop=drop)
if self.shift_size > 0:
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1))
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, window_size)
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W
# 残差分支
x = x + self.drop_path(self.norm1(x))
# 循环移位
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# 窗口划分
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
# W-MSA 或 SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask)
# 合并窗口
shifted_x = window_reverse(attn_windows, self.window_size, H, W)
# 逆移位
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x + self.drop_path(self.norm2(self.mlp(x)))
return x相对位置偏置
为什么需要位置编码
与 ViT 类似,Swin 需要位置信息来编码空间结构。但与 ViT 的绝对位置编码不同,Swin 使用相对位置偏置(Relative Position Bias)。
相对位置偏置的定义
对于窗口内位置 和 之间的注意力:
其中 是可学习的相对位置偏置矩阵。
偏置的物理意义
| 相对位置 | 偏置作用 |
|---|---|
| self-attention(相同位置) | |
| 上下相邻 | |
| 左右相邻 | |
| 更远的相对位置 |
相对位置编码使模型能够学习到”局部性”概念,自然地建模空间关系。
Patch Merging:层次化下采样
实现机制
Patch Merging 将相邻的 Patch 合并为一个:
Stage 1 输出:
┌───┬───┬───┬───┐
│ a │ b │ c │ d │
├───┼───┼───┼───┤
│ e │ f │ g │ h │
├───┼───┼───┼───┤
│ i │ j │ k │ l │
├───┼───┼───┼───┤
│ m │ n │ o │ p │
└───┴───┴───┴───┘
合并后:
┌─────────┬─────────┐
│ a b e f │ c d g h │
├─────────┼─────────┤
│ i j m n │ k l o p │
└─────────┴─────────┘
维度变化
每经过一次 Patch Merging:
- 空间分辨率减半:
- 通道数翻倍:
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = nn.LayerNorm(4 * dim)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W
x = x.view(B, H, W, C)
# 划分 2x2 块并拼接
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
x = torch.cat([x0, x1, x2, x3], dim=-1)
x = x.view(B, -1, 4 * C)
x = self.norm(x)
x = self.reduction(x)
return x与 ViT 的对比
架构对比
| 特性 | ViT | Swin Transformer |
|---|---|---|
| Patch 划分 | 固定(16×16) | 固定(4×4)→ 层次化 |
| 特征尺度 | 单一 | 多尺度金字塔 |
| 注意力机制 | 全局 | 局部窗口 + 移位 |
| 位置编码 | 绝对位置编码 | 相对位置偏置 |
| 特征图 | 等尺寸 | 逐层下采样 |
| ImageNet 精度 | 77.9% (B/16) | 83.1% (B) |
计算复杂度对比
| 图像尺寸 | ViT (全局注意力) | Swin (窗口注意力) |
|---|---|---|
| 224×224 | 37.6B FLOPs | 15.3B FLOPs |
| 384×384 | 123.6B FLOPs | 52.9B FLOPs |
| 512×512 | 219.7B FLOPs | 94.1B FLOPs |
多尺度特征的优势
Swin 的多尺度特征更适合下游任务:
分类: 7×7 特征图 → 全局池化
↓
检测: C3/C4/C5 多尺度特征 → FPN
↓
分割: 多尺度 2x, 4x, 8x, 16x → UperNet
下游任务扩展
目标检测:Swin + Mask R-CNN
Swin 作为检测器的 Backbone,结合 FPN:
Swin-T/S/B/L
│
├─── Stage 2 (C2, 56×56) ─→ FPN P5
├─── Stage 3 (C3, 28×28) ─→ FPN P4
├─── Stage 4 (C4, 14×14) ─→ FPN P3
└─── Stage 5 (C5, 7×7) ─→ FPN P2
COCO 目标检测性能:
| Backbone | Box AP | Mask AP | 参数量 |
|---|---|---|---|
| ResNet-50 | 38.0 | 34.4 | 44M |
| Swin-T | 46.0 | 41.6 | 48M |
| Swin-S | 48.5 | 43.3 | 69M |
| Swin-B | 49.5 | 44.3 | 107M |
语义分割:UperNet + Swin
Swin backbone
│
├─── S2 (256 dim) ─→ PSP Module ─┐
├─── S3 (512 dim) ────────────────┤
├─── S4 (1024 dim) ───────────────┤
└─── S5 (1024 dim) ───────────────┤
↓
Semantic Segmentation Head
ADE20K 语义分割性能:
| Backbone | mIoU | 参数量 |
|---|---|---|
| ResNet-101 | 44.1 | 62M |
| Swin-T | 45.8 | 53M |
| Swin-S | 47.6 | 69M |
| Swin-B | 48.1 | 107M |
变体与发展
MViT:多尺度 Vision Transformer
MViT 通过渐进式通道扩展和分辨率缩减实现多尺度建模:
Stage 1: 96 channels, 56×56 resolution
Stage 2: 192 channels, 28×28 resolution
Stage 3: 384 channels, 14×14 resolution
Stage 4: 768 channels, 7×7 resolution
SwinV2
Swin Transformer V22 的改进:
- 连续相对位置偏置:替代离散偏置,支持更高分辨率
- Log-Spaced 位置编码:更好地外推到不同输入尺寸
- Swin-L V2:3 亿参数的超大模型
Swin-UNet
将 Swin 的注意力机制引入 U-Net 风格的分割网络:
- Encoder: Swin Blocks + Patch Merging
- Decoder: Swin Blocks + Patch Expanding
- Skip connections: 保留多尺度特征
总结
Swin Transformer 的核心贡献
| 贡献 | 说明 |
|---|---|
| 层次化设计 | 模仿 CNN 的特征金字塔结构 |
| 移位窗口注意力 | 复杂度,支持高分辨率输入 |
| 跨窗口信息交互 | 通过移位和掩码实现 |
| 相对位置偏置 | 更强的空间建模能力 |
适用场景
| 场景 | 推荐理由 |
|---|---|
| 图像分类 | 高精度、强特征提取 |
| 目标检测 | 多尺度特征适合 FPN/RetinaNet |
| 语义分割 | 层次化特征支持高分辨率输出 |
| 实例分割 | Mask R-CNN 的优秀 Backbone |
| 视频理解 | 可扩展到时空维度 |
参考
Footnotes
-
Liu, Z., et al. (2021). Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. ICCV 2021. https://arxiv.org/abs/2103.14030 ↩
-
Liu, Z., et al. (2022). Swin Transformer V2: Scaling Up Capacity and Resolution. CVPR 2022. ↩