引言

Swin Transformer(Shifted Windows Transformer)1 是微软亚洲研究院提出的层次化视觉 Transformer,通过移位窗口注意力机制有效解决了原始 ViT 的两大痛点:计算复杂度随图像尺寸平方增长、缺乏多尺度特征建模能力。

Swin Transformer 在图像分类、目标检测、语义分割等多个视觉任务上取得了 SOTA 性能,成为视觉 Transformer 领域最具影响力的架构之一。


原始 ViT 的局限性

计算复杂度问题

原始 ViT 的全局自注意力计算复杂度为 ,其中 是 Patch 数量:

图像尺寸Patch 数量全局注意力 FLOPs
224×224196轻量
384×3845763.4×
512×512102427×
1024×10244096436×

对于高分辨率图像(如检测、分割任务),计算成本急剧膨胀。

缺乏多尺度特征

ViT 使用单一尺度的 Patch Embedding,难以有效建模不同尺度的物体:

ViT: 单一尺度 (所有 token 同等对待)
      ↓
      196 个 16×16 的 Patch
      
Swin: 层次化 (多尺度特征金字塔)
      ↓
      Stage 1: 56×56 个细粒度 Patch
      Stage 2: 28×28 个聚合 Patch
      Stage 3: 14×14 个粗粒度 Patch
      Stage 4: 7×7 个高层语义 Patch

Swin Transformer 的核心设计

整体架构

Swin Transformer 采用层次化设计,模仿 CNN 的特征金字塔结构:

Input Image (224×224×3)
        │
        ▼
┌───────────────────┐
│  Patch Partition   │  划分为 4×4 Patches
│  (4×4 conv)       │  → 56×56 tokens, dim=128
└───────────────────┘
        │
        ▼
┌───────────────────┐
│  Stage 1           │
│  ×2 Swin Blocks   │  W-MSA (窗口注意力)
│  Patch Merging     │  → 28×28 tokens, dim=256
└───────────────────┘
        │
        ▼
┌───────────────────┐
│  Stage 2           │
│  ×2 Swin Blocks   │  W-MSA → SW-MSA (移位窗口)
│  Patch Merging     │  → 14×14 tokens, dim=512
└───────────────────┘
        │
        ▼
┌───────────────────┐
│  Stage 3           │
│  ×6 Swin Blocks   │  W-MSA → SW-MSA
│  Patch Merging     │  → 7×7 tokens, dim=1024
└───────────────────┘
        │
        ▼
┌───────────────────┐
│  Stage 4           │
│  ×2 Swin Blocks   │  W-MSA → SW-MSA
└───────────────────┘
        │
        ▼
┌───────────────────┐
│  Global Avg Pool   │  输出多尺度特征
│  + Classification  │
└───────────────────┘

配置变体

配置Stage 通道数Stage 层数窗口大小参数量
Swin-TC:96, 192, 384, 7682, 2, 6, 2728M
Swin-SC:96, 192, 384, 7682, 2, 18, 2750M
Swin-BC:128, 256, 512, 10242, 2, 18, 2788M
Swin-LC:128, 256, 512, 10242, 2, 18, 27196M

窗口注意力机制(W-MSA)

核心思想

Swin 的核心创新是将全局注意力限制在非重叠的局部窗口内:

┌─────────────────────────┐
│ M×M 个 Patch 的局部窗口  │
│                         │
│  ┌───┬───┬───┬───┐      │
│  │ w1│ w2│ w3│ w4│      │
│  ├───┼───┼───┼───┤      │
│  │ w5│ w6│ w7│ w8│      │
│  ├───┼───┼───┼───┤      │
│  │ w9│w10│w11│w12│      │
│  ├───┼───┼───┼───┤      │
│  │w13│w14│w15│w16│      │
│  └───┴───┴───┴───┘      │
│                         │
│  每个窗口独立计算注意力   │
└─────────────────────────┘

计算复杂度对比

机制计算复杂度图像尺寸 224×224
全局注意力
窗口注意力

其中 是总 Patch 数, 是窗口大小(通常 ), 是特征维度。

复杂度分析

窗口注意力将复杂度从 降低到


移位窗口注意力(SW-MSA)

问题:窗口间缺乏信息交互

固定窗口导致相邻窗口之间没有信息流动,限制了模型的感受野:

W-MSA 问题:
┌───┬───┬───┬───┐
│ 1 │ 2 │ 3 │ 4 │   窗口 1 和 2 完全独立
├───┼───┼───┼───┤   边界处无信息交换
│ 5 │ 6 │ 7 │ 8 │
├───┼───┼───┼───┤
│ 9 │10 │11 │12 │
├───┼───┼───┼───┤
│13 │14 │15 │16 │
└───┴───┴───┴───┘

解决方案:交替使用移位窗口

Swin 在连续的 Transformer Block 中交替使用 W-MSA 和 SW-MSA:

Block l:     W-MSA (常规窗口)
Block l+1:   SW-MSA (移位窗口)
Block l+2:   W-MSA (常规窗口)
...

移位机制详解

Step 1: 循环移位(Cyclic Shift)

将特征图沿左上方向循环移位

原始网格:
┌───┬───┬───┬───┐
│ 1 │ 2 │ 3 │ 4 │
├───┼───┼───┼───┤
│ 5 │ 6 │ 7 │ 8 │
├───┼───┼───┼───┤
│ 9 │10 │11 │12 │
├───┼───┼───┼───┤
│13 │14 │15 │16 │
└───┴───┴───┴───┘

循环移位后 (向右下移 1):
┌───┬───┬───┬───┐
│16 │13 │14 │15 │
├───┼───┼───┼───┤
│ 4 │ 1 │ 2 │ 3 │
├───┼───┼───┼───┤
│ 8 │ 5 │ 6 │ 7 │
├───┼───┼───┼───┤
│12 │ 9 │10 │11 │
└───┴───┴───┴───┘

Step 2: 窗口分区

移位后重新划分窗口,此时:

  • 窗口 A 包含来自原图多个角落的 Patch
  • 实现了跨区域的信息交互

Step 3: 注意力掩码(Masking)

为避免不同区域Patch之间不应有的注意力,使用掩码机制:

# 注意力掩码示例
# 某些位置组合的注意力被mask为 -100,使其 softmax 后接近 0
 
attn_mask = torch.zeros((num_windows, M*M, M*M))
 
# 掩码某些不应当交互的位置
attn_mask[batch_idx][i][j] = -100.0  # mask 掉

Step 4: 逆移位

计算完成后,将结果逆移到原始位置。


PyTorch 实现

窗口注意力

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
 
class WindowAttention(nn.Module):
    """窗口多头自注意力"""
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (M, M)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        
        # QKV 投影
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        
        # 相对位置偏置
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )
        
        # 获取相对位置索引
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
    
    def forward(self, x, mask=None):
        B_, N, C = x.shape  # B_: batch * num_windows
        
        # QKV 投影
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # 缩放
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        
        # 添加相对位置偏置
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        
        # 掩码处理
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = F.softmax(attn, dim=-1)
        else:
            attn = F.softmax(attn, dim=-1)
        
        attn = self.attn_drop(attn)
        
        # 输出投影
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

Swin Transformer Block

class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block with W-MSA or SW-MSA"""
    def __init__(self, dim, input_resolution, num_heads, window_size=7,
                 shift_size=0, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size
        
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, (window_size, window_size), num_heads, qkv_bias, attn_drop)
        
        self.drop_path = DropPath(drop) if drop > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(dim, hidden_features=mlp_hidden_dim, drop=drop)
        
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None)
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None)
            )
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            
            mask_windows = window_partition(img_mask, window_size)
            mask_windows = mask_windows.view(-1, window_size * window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        
        self.register_buffer("attn_mask", attn_mask)
    
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W
        
        # 残差分支
        x = x + self.drop_path(self.norm1(x))
        
        # 循环移位
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        
        # 窗口划分
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        
        # W-MSA 或 SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        
        # 合并窗口
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        
        # 逆移位
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        
        x = x + self.drop_path(self.norm2(self.mlp(x)))
        return x

相对位置偏置

为什么需要位置编码

与 ViT 类似,Swin 需要位置信息来编码空间结构。但与 ViT 的绝对位置编码不同,Swin 使用相对位置偏置(Relative Position Bias)。

相对位置偏置的定义

对于窗口内位置 之间的注意力:

其中 是可学习的相对位置偏置矩阵。

偏置的物理意义

相对位置偏置作用
self-attention(相同位置)
上下相邻
左右相邻
更远的相对位置

相对位置编码使模型能够学习到”局部性”概念,自然地建模空间关系。


Patch Merging:层次化下采样

实现机制

Patch Merging 将相邻的 Patch 合并为一个:

Stage 1 输出:
┌───┬───┬───┬───┐
│ a │ b │ c │ d │
├───┼───┼───┼───┤
│ e │ f │ g │ h │
├───┼───┼───┼───┤
│ i │ j │ k │ l │
├───┼───┼───┼───┤
│ m │ n │ o │ p │
└───┴───┴───┴───┘

合并后:
┌─────────┬─────────┐
│ a b e f │ c d g h │
├─────────┼─────────┤
│ i j m n │ k l o p │
└─────────┴─────────┘

维度变化

每经过一次 Patch Merging:

  • 空间分辨率减半:
  • 通道数翻倍:
class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)
    
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W
        
        x = x.view(B, H, W, C)
        
        # 划分 2x2 块并拼接
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)
        
        x = x.view(B, -1, 4 * C)
        x = self.norm(x)
        x = self.reduction(x)
        
        return x

与 ViT 的对比

架构对比

特性ViTSwin Transformer
Patch 划分固定(16×16)固定(4×4)→ 层次化
特征尺度单一多尺度金字塔
注意力机制全局局部窗口 + 移位
位置编码绝对位置编码相对位置偏置
特征图等尺寸逐层下采样
ImageNet 精度77.9% (B/16)83.1% (B)

计算复杂度对比

图像尺寸ViT (全局注意力)Swin (窗口注意力)
224×22437.6B FLOPs15.3B FLOPs
384×384123.6B FLOPs52.9B FLOPs
512×512219.7B FLOPs94.1B FLOPs

多尺度特征的优势

Swin 的多尺度特征更适合下游任务:

分类: 7×7 特征图 → 全局池化
                ↓
检测: C3/C4/C5 多尺度特征 → FPN
                ↓
分割: 多尺度 2x, 4x, 8x, 16x → UperNet

下游任务扩展

目标检测:Swin + Mask R-CNN

Swin 作为检测器的 Backbone,结合 FPN:

Swin-T/S/B/L
    │
    ├─── Stage 2 (C2, 56×56) ─→ FPN P5
    ├─── Stage 3 (C3, 28×28) ─→ FPN P4
    ├─── Stage 4 (C4, 14×14) ─→ FPN P3
    └─── Stage 5 (C5, 7×7)   ─→ FPN P2

COCO 目标检测性能:

BackboneBox APMask AP参数量
ResNet-5038.034.444M
Swin-T46.041.648M
Swin-S48.543.369M
Swin-B49.544.3107M

语义分割:UperNet + Swin

Swin backbone
    │
    ├─── S2 (256 dim) ─→ PSP Module ─┐
    ├─── S3 (512 dim) ────────────────┤
    ├─── S4 (1024 dim) ───────────────┤
    └─── S5 (1024 dim) ───────────────┤
                                        ↓
                               Semantic Segmentation Head

ADE20K 语义分割性能:

BackbonemIoU参数量
ResNet-10144.162M
Swin-T45.853M
Swin-S47.669M
Swin-B48.1107M

变体与发展

MViT:多尺度 Vision Transformer

MViT 通过渐进式通道扩展和分辨率缩减实现多尺度建模:

Stage 1: 96 channels, 56×56 resolution
Stage 2: 192 channels, 28×28 resolution
Stage 3: 384 channels, 14×14 resolution
Stage 4: 768 channels, 7×7 resolution

SwinV2

Swin Transformer V22 的改进:

  1. 连续相对位置偏置:替代离散偏置,支持更高分辨率
  2. Log-Spaced 位置编码:更好地外推到不同输入尺寸
  3. Swin-L V2:3 亿参数的超大模型

Swin-UNet

将 Swin 的注意力机制引入 U-Net 风格的分割网络:

  • Encoder: Swin Blocks + Patch Merging
  • Decoder: Swin Blocks + Patch Expanding
  • Skip connections: 保留多尺度特征

总结

Swin Transformer 的核心贡献

贡献说明
层次化设计模仿 CNN 的特征金字塔结构
移位窗口注意力 复杂度,支持高分辨率输入
跨窗口信息交互通过移位和掩码实现
相对位置偏置更强的空间建模能力

适用场景

场景推荐理由
图像分类高精度、强特征提取
目标检测多尺度特征适合 FPN/RetinaNet
语义分割层次化特征支持高分辨率输出
实例分割Mask R-CNN 的优秀 Backbone
视频理解可扩展到时空维度

参考

Footnotes

  1. Liu, Z., et al. (2021). Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. ICCV 2021. https://arxiv.org/abs/2103.14030

  2. Liu, Z., et al. (2022). Swin Transformer V2: Scaling Up Capacity and Resolution. CVPR 2022.