引言

Video Swin Transformer1 是微软亚洲研究院将Swin Transformer扩展到视频领域的工作,提出**3D移位窗口注意力(3D Shifted Window Attention)**机制,在视频理解和动作识别任务上取得了显著进展。该工作由Swin Transformer的原班人马完成,保持了Swin Transformer的层次化设计和高效注意力计算,同时将建模能力扩展到时间维度。

本文深入解析Video Swin Transformer的核心设计,包括3D移位窗口注意力的数学原理、层次化架构、与其他视频Transformer的对比,以及实验结果分析。


Video Swin Transformer概述

论文信息

  • 标题:《Video Swin Transformer》
  • 作者:Ze Liu, Jia Ning, Yue Cao, Yixuan Wei, Zheng Zhang, Stephen Lin, Han Hu
  • 机构:Microsoft Research Asia
  • 会议:CVPR 2022

核心思想

Video Swin Transformer的核心贡献是将Swin Transformer的2D移位窗口注意力扩展到3D时空域,同时保持层次化结构设计:

Swin Transformer (2D):
- 移位窗口注意力 (Shifted Window Attention)
- 层次化特征金字塔
- 高效的局部注意力计算

Video Swin Transformer (3D):
- 3D移位窗口注意力 (Shifted Window Attention in 3D)
- 时空层次化特征金字塔
- 联合建模时空依赖

与TimeSformer的关键区别

方面TimeSformerVideo Swin Transformer
注意力类型分解放置注意力3D移位窗口注意力
窗口设计全局稀疏局部窗口 + 移位
计算复杂度
层次化有(4个stage)
多尺度

3D移位窗口注意力

2D移位窗口注意力回顾

Swin Transformer的2D移位窗口注意力将特征图划分为不重叠的窗口,每个窗口内独立计算注意力:

标准窗口注意力:
┌─────┬─────┬─────┐
│ W11 │ W12 │ W13 │
├─────┼─────┼─────┤
│ W21 │ W22 │ W23 │    每个窗口包含 M×M 个patches
├─────┼─────┼─────┤
│ W31 │ W32 │ W33 │
└─────┴─────┴─────┘

注意力计算: 仅在窗口内进行
W11注意力 → W11内
W12注意力 → W12内
...

3D窗口划分

Video Swin Transformer将3D视频划分为3D窗口:

每个窗口 包含 个时空位置(patches),其中:

  • :时间维度的窗口大小
  • :高度维度的窗口大小
  • :宽度维度的窗口大小

典型配置:,即每个窗口包含 个时空patches。

3D窗口内注意力

对于窗口 ,窗口内注意力定义为:

其中

窗口移位机制

与Swin Transformer相同,Video Swin Transformer采用交替使用规则窗口和移位窗口的策略:

Stage 1: 规则窗口注意力 (W-MSA)
┌─────┬─────┬─────┐
│ W11 │ W12 │ W13 │
├─────┼─────┼─────┤
│ W21 │ W22 │ W23 │
├─────┼─────┼─────┤
│ W31 │ W32 │ W33 │
└─────┴─────┴─────┘

Stage 2: 移位窗口注意力 (SW-MSA)
┌─────┬─────┬─────┤ ← 移位偏移量 (⌊P/2⌋, ⌊P/2⌋)
│SW11 │SW12 │SW13│
├─────┼─────┼─────┤
│SW21 │SW22 │SW23│
├─────┼─────┼─────┤
│SW31 │SW32 │SW33│
└─────┴─────┴─────┘

移位的好处:使相邻窗口之间产生交叉连接,促进信息流动。

3D移位的特殊考虑

对于视频,移位需要在三个维度进行:

class WindowAttention3D(nn.Module):
    """
    3D移位窗口注意力
    """
    def __init__(self, window_size=(2, 7, 7)):
        self.window_size = window_size  # (P_t, P_h, P_w)
    
    def forward(self, x, shift_size):
        """
        x: [B, T, H, W, C]
        shift_size: 移位量 (Δt, Δh, Δw)
        """
        B, T, H, W, C = x.shape
        P_t, P_h, P_w = self.window_size
        
        # 1. 循环移位
        x = torch.roll(x, shifts=shift_size, dims=(1, 2, 3))
        
        # 2. 划分窗口
        x = window_partition_3d(x, window_size)
        # x: [num_windows * B, P_t * P_h * P_w, C]
        
        # 3. 窗口内注意力
        # ... 注意力计算 ...
        
        # 4. 逆窗口划分
        x = window_reverse_3d(x, window_size, T, H, W)
        
        # 5. 逆移位 (移位方向相反)
        x = torch.roll(x, shifts=[-s for s in shift_size], dims=(1, 2, 3))
        
        return x

循环移位与掩码

Video Swin Transformer使用**循环移位(Cyclic Shift)**实现移位窗口:

原始窗口划分 (P=4):
┌───┬───┐
│ 0 │ 1 │    每个区域代表 P/2 的偏移
├───┼───┤
│ 2 │ 3 │
└───┴───┘

循环移位后 (Δ=P/2):
┌───┬───┐
│ 3 │ 2 │    通过循环移位,非相邻区域被移到一起
├───┼───┤
│ 1 │ 0 │    然后应用掩码使注意力只发生在正确区域
└───┴───┘

掩码机制:使用注意力掩码避免不相关的patch之间产生注意力。

# 注意力掩码示例
# 不同区域之间应该被掩蔽
mask = torch.zeros(num_windows, num_windows)
mask[region_0, region_1] = -100  # 不应该交互的区域
mask[region_1, region_0] = -100
mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, num_windows, num_windows]

整体架构

架构概览

Input Video: [B, 3, T, H, W] = [B, 3, 32, 224, 224]
        │
        ▼
┌───────────────────────────────────────────────────┐
│  Video to Tokens (3D Conv)                        │
│  Conv3d(3, 96, kernel_size=(2,4,4), stride=(2,4,4))│
│  输出: [B, 96, T/2, H/4, W/4]                    │
└───────────────────────────────────────────────────┘
        │
        ▼
┌───────────────────────────────────────────────────┐
│  Stage 1: 3D Swin Blocks (×2)                    │
│  - Window Attention (P_t=2, P_h=P_w=7)           │
│  - 移位窗口注意力                                 │
│  Patch Merging                                    │
│  输出: [B, 192, T/4, H/16, W/16]                │
└───────────────────────────────────────────────────┘
        │
        ▼
┌───────────────────────────────────────────────────┐
│  Stage 2: 3D Swin Blocks (×2)                    │
│  输出: [B, 384, T/8, H/32, W/32]                 │
└───────────────────────────────────────────────────┘
        │
        ▼
┌───────────────────────────────────────────────────┐
│  Stage 3: 3D Swin Blocks (×6)                    │
│  输出: [B, 768, T/16, H/64, W/64]               │
└───────────────────────────────────────────────────┘
        │
        ▼
┌───────────────────────────────────────────────────┐
│  Stage 4: 3D Swin Blocks (×2)                    │
│  输出: [B, 768, T/32, H/128, W/128]              │
└───────────────────────────────────────────────────┘
        │
        ▼
┌───────────────────────────────────────────────────┐
│  Temporal Mean Pooling + Classification Head      │
│  输出: [B, num_classes]                          │
└───────────────────────────────────────────────────┘

各Stage配置

Stage输出尺寸窗口大小Block数通道维度
Token Embed--96
Stage 12192
Stage 22384
Stage 36768
Stage 42768

3D Patch Merging

Patch Merging在时间维度上也进行下采样:

class PatchMerging3D(nn.Module):
    """
    3D Patch Merging:空间和时间同时下采样
    """
    def __init__(self, dim):
        super().__init__()
        self.reduction = nn.Conv3d(
            dim, dim * 2,
            kernel_size=(2, 2, 2),  # 时间×空间各下采样2倍
            stride=(2, 2, 2)
        )
        self.norm = nn.LayerNorm(dim * 2)
    
    def forward(self, x):
        # x: [B, T, H, W, C]
        B, T, H, W, C = x.shape
        
        # 重组: [B, T/2, H/2, W/2, 8*C]
        x = x.reshape(B, T//2, 2, H//2, 2, W//2, 2, C)
        x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
        x = x.reshape(B, T//2, H//2, W//2, 8*C)
        
        x = self.reduction(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1)
        x = self.norm(x)
        
        return x  # [B, T/2, H/2, W/2, 2*C]

相对位置偏置

3D相对位置偏置

Video Swin Transformer扩展了2D相对位置偏置到3D:

其中 是可学习的3D相对位置偏置矩阵。

3D相对位置偏置参数

  • 空间维度:
  • 时间维度:
  • 总计: 个参数
class WindowAttention3D(nn.Module):
    def __init__(self, window_size=(2, 7, 7), num_heads=8):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        
        # 3D相对位置偏置表
        # (2*2-1) * (2*7-1) * (2*7-1) = 3 * 13 * 13 = 507
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2*window_size[0]-1) * 
                       (2*window_size[1]-1) * 
                       (2*window_size[2]-1),
                       num_heads)
        )
        
        # 计算相对位置索引
        self.register_buffer(
            "relative_position_index",
            self._get_relative_position_index()
        )
    
    def _get_relative_position_index(self):
        """计算3D相对位置索引"""
        coords_t = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        
        # 生成网格
        coords = torch.stack(
            torch.meshgrid(coords_t, coords_h, coords_w, indexing='ij')
        )  # [3, P_t, P_h, P_w]
        
        # 重组为序列
        coords_flat = coords.reshape(3, -1)  # [3, P_t*P_h*P_w]
        
        # 计算相对位置
        relative_pos = coords_flat[:, :, None] - coords_flat[:, None, :]
        # [3, N, N]
        
        # 转换为偏置表索引
        relative_pos = relative_pos.permute(1, 2, 0).contiguous()
        relative_pos[:, :, 0] += self.window_size[0] - 1
        relative_pos[:, :, 1] += self.window_size[1] - 1
        relative_pos[:, :, 2] += self.window_size[2] - 1
        
        relative_pos_index = (
            relative_pos[:, :, 0] * 
            (2*self.window_size[1]-1) * 
            (2*self.window_size[2]-1) +
            relative_pos[:, :, 1] * 
            (2*self.window_size[2]-1) +
            relative_pos[:, :, 2]
        )
        
        return relative_pos_index
    
    def forward(self, x):
        # ...
        # 获取相对位置偏置
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1] * self.window_size[2],
            self.window_size[0] * self.window_size[1] * self.window_size[2],
            -1
        )  # [N, N, num_heads]
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        
        # 添加偏置到注意力分数
        attn = attn + relative_position_bias.unsqueeze(0)
        # ...

与TimeSformer的对比

方法论对比

方面TimeSformerVideo Swin Transformer
注意力机制分解放置注意力3D移位窗口注意力
窗口设计无窗口,全局计算固定大小窗口
窗口间连接无需特殊处理移位机制连接
层次化结构单尺度4级层次化
多尺度建模有(通过Patch Merging)
位置编码分离时空编码3D相对位置偏置

计算复杂度对比

模型K400 Top-1FLOPs推理速度
I3D-R5076.3%108×G
SlowFast-R5078.8%65×G1.5×
TimeSformer-B82.7%2380×G0.1×
Video Swin-B84.9%321×G0.4×

结论:Video Swin Transformer在保持较高精度的同时,显著降低了计算量。

表达能力对比

Video Swin Transformer的优势

  1. 层次化表示:多尺度特征有利于下游任务(检测、分割)
  2. 局部-全局建模:窗口内局部 + 跨窗口全局
  3. 平移不变性:相对位置偏置提供更好的泛化

TimeSformer的优势

  1. 全局感受野:单层即可建模全视频依赖
  2. 实现简单:无需处理窗口移位和掩码
  3. 更适合分类任务:全局注意力直接聚合信息

实验结果

Kinetics-400

模型Top-1Top-5参数量
TSN72.4%90.4%-
I3D74.3%91.4%25.0M
Non-Local I3D77.7%93.3%35.7M
SlowFast78.8%93.5%34.0M
TimeSformer82.7%95.4%121.4M
Video Swin-S83.6%95.9%49.6M
Video Swin-B84.9%96.4%88.8M

Kinetics-600

模型Top-1Top-5
SlowFast81.8%95.1%
TimeSformer85.0%96.4%
Video Swin-B86.4%97.2%

消融实验

窗口大小的影响

窗口大小 K400 Top-1
84.9%
85.1%
85.0%

移位窗口的效果

配置K400 Top-1
仅规则窗口83.2%
规则+移位窗口84.9%

相对位置偏置的效果

配置K400 Top-1
无相对位置偏置83.5%
2D相对位置偏置84.2%
3D相对位置偏置84.9%

下游任务应用

时序动作检测

Video Swin Transformer可以扩展用于时序动作检测任务:

  1. 特征提取:使用Video Swin作为视频编码器
  2. 时序建模:添加时序检测头
  3. 边界回归:预测动作边界

视频分割

层次化特征金字塔有利于密集预测任务:

  1. 多尺度特征:Stage 2-4提供不同分辨率特征
  2. 细粒度分割:浅层特征保留空间细节
  3. 语义分割:深层特征捕获高级语义

总结

Video Swin Transformer成功地将Swin Transformer的层次化设计和移位窗口注意力扩展到视频领域,其主要贡献包括:

  1. 3D移位窗口注意力:高效建模时空依赖
  2. 时空层次化结构:多尺度特征表示
  3. 3D相对位置偏置:增强位置感知能力
  4. 优秀的性能-效率权衡:在保持高精度的同时降低计算量

该工作为视频理解领域提供了一种高效且强大的架构选择。


参考文献

Footnotes

  1. Liu Z, Ning J, Cao Y, et al. Video swin transformer[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022: 3202-3211.