引言
Video Swin Transformer1 是微软亚洲研究院将Swin Transformer扩展到视频领域的工作,提出**3D移位窗口注意力(3D Shifted Window Attention)**机制,在视频理解和动作识别任务上取得了显著进展。该工作由Swin Transformer的原班人马完成,保持了Swin Transformer的层次化设计和高效注意力计算,同时将建模能力扩展到时间维度。
本文深入解析Video Swin Transformer的核心设计,包括3D移位窗口注意力的数学原理、层次化架构、与其他视频Transformer的对比,以及实验结果分析。
Video Swin Transformer概述
论文信息
- 标题:《Video Swin Transformer》
- 作者:Ze Liu, Jia Ning, Yue Cao, Yixuan Wei, Zheng Zhang, Stephen Lin, Han Hu
- 机构:Microsoft Research Asia
- 会议:CVPR 2022
核心思想
Video Swin Transformer的核心贡献是将Swin Transformer的2D移位窗口注意力扩展到3D时空域,同时保持层次化结构设计:
Swin Transformer (2D):
- 移位窗口注意力 (Shifted Window Attention)
- 层次化特征金字塔
- 高效的局部注意力计算
Video Swin Transformer (3D):
- 3D移位窗口注意力 (Shifted Window Attention in 3D)
- 时空层次化特征金字塔
- 联合建模时空依赖
与TimeSformer的关键区别
| 方面 | TimeSformer | Video Swin Transformer |
|---|---|---|
| 注意力类型 | 分解放置注意力 | 3D移位窗口注意力 |
| 窗口设计 | 全局稀疏 | 局部窗口 + 移位 |
| 计算复杂度 | ||
| 层次化 | 无 | 有(4个stage) |
| 多尺度 | 无 | 有 |
3D移位窗口注意力
2D移位窗口注意力回顾
Swin Transformer的2D移位窗口注意力将特征图划分为不重叠的窗口,每个窗口内独立计算注意力:
标准窗口注意力:
┌─────┬─────┬─────┐
│ W11 │ W12 │ W13 │
├─────┼─────┼─────┤
│ W21 │ W22 │ W23 │ 每个窗口包含 M×M 个patches
├─────┼─────┼─────┤
│ W31 │ W32 │ W33 │
└─────┴─────┴─────┘
注意力计算: 仅在窗口内进行
W11注意力 → W11内
W12注意力 → W12内
...
3D窗口划分
Video Swin Transformer将3D视频划分为3D窗口:
每个窗口 包含 个时空位置(patches),其中:
- :时间维度的窗口大小
- :高度维度的窗口大小
- :宽度维度的窗口大小
典型配置:,即每个窗口包含 个时空patches。
3D窗口内注意力
对于窗口 ,窗口内注意力定义为:
其中 。
窗口移位机制
与Swin Transformer相同,Video Swin Transformer采用交替使用规则窗口和移位窗口的策略:
Stage 1: 规则窗口注意力 (W-MSA)
┌─────┬─────┬─────┐
│ W11 │ W12 │ W13 │
├─────┼─────┼─────┤
│ W21 │ W22 │ W23 │
├─────┼─────┼─────┤
│ W31 │ W32 │ W33 │
└─────┴─────┴─────┘
Stage 2: 移位窗口注意力 (SW-MSA)
┌─────┬─────┬─────┤ ← 移位偏移量 (⌊P/2⌋, ⌊P/2⌋)
│SW11 │SW12 │SW13│
├─────┼─────┼─────┤
│SW21 │SW22 │SW23│
├─────┼─────┼─────┤
│SW31 │SW32 │SW33│
└─────┴─────┴─────┘
移位的好处:使相邻窗口之间产生交叉连接,促进信息流动。
3D移位的特殊考虑
对于视频,移位需要在三个维度进行:
class WindowAttention3D(nn.Module):
"""
3D移位窗口注意力
"""
def __init__(self, window_size=(2, 7, 7)):
self.window_size = window_size # (P_t, P_h, P_w)
def forward(self, x, shift_size):
"""
x: [B, T, H, W, C]
shift_size: 移位量 (Δt, Δh, Δw)
"""
B, T, H, W, C = x.shape
P_t, P_h, P_w = self.window_size
# 1. 循环移位
x = torch.roll(x, shifts=shift_size, dims=(1, 2, 3))
# 2. 划分窗口
x = window_partition_3d(x, window_size)
# x: [num_windows * B, P_t * P_h * P_w, C]
# 3. 窗口内注意力
# ... 注意力计算 ...
# 4. 逆窗口划分
x = window_reverse_3d(x, window_size, T, H, W)
# 5. 逆移位 (移位方向相反)
x = torch.roll(x, shifts=[-s for s in shift_size], dims=(1, 2, 3))
return x循环移位与掩码
Video Swin Transformer使用**循环移位(Cyclic Shift)**实现移位窗口:
原始窗口划分 (P=4):
┌───┬───┐
│ 0 │ 1 │ 每个区域代表 P/2 的偏移
├───┼───┤
│ 2 │ 3 │
└───┴───┘
循环移位后 (Δ=P/2):
┌───┬───┐
│ 3 │ 2 │ 通过循环移位,非相邻区域被移到一起
├───┼───┤
│ 1 │ 0 │ 然后应用掩码使注意力只发生在正确区域
└───┴───┘
掩码机制:使用注意力掩码避免不相关的patch之间产生注意力。
# 注意力掩码示例
# 不同区域之间应该被掩蔽
mask = torch.zeros(num_windows, num_windows)
mask[region_0, region_1] = -100 # 不应该交互的区域
mask[region_1, region_0] = -100
mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, num_windows, num_windows]整体架构
架构概览
Input Video: [B, 3, T, H, W] = [B, 3, 32, 224, 224]
│
▼
┌───────────────────────────────────────────────────┐
│ Video to Tokens (3D Conv) │
│ Conv3d(3, 96, kernel_size=(2,4,4), stride=(2,4,4))│
│ 输出: [B, 96, T/2, H/4, W/4] │
└───────────────────────────────────────────────────┘
│
▼
┌───────────────────────────────────────────────────┐
│ Stage 1: 3D Swin Blocks (×2) │
│ - Window Attention (P_t=2, P_h=P_w=7) │
│ - 移位窗口注意力 │
│ Patch Merging │
│ 输出: [B, 192, T/4, H/16, W/16] │
└───────────────────────────────────────────────────┘
│
▼
┌───────────────────────────────────────────────────┐
│ Stage 2: 3D Swin Blocks (×2) │
│ 输出: [B, 384, T/8, H/32, W/32] │
└───────────────────────────────────────────────────┘
│
▼
┌───────────────────────────────────────────────────┐
│ Stage 3: 3D Swin Blocks (×6) │
│ 输出: [B, 768, T/16, H/64, W/64] │
└───────────────────────────────────────────────────┘
│
▼
┌───────────────────────────────────────────────────┐
│ Stage 4: 3D Swin Blocks (×2) │
│ 输出: [B, 768, T/32, H/128, W/128] │
└───────────────────────────────────────────────────┘
│
▼
┌───────────────────────────────────────────────────┐
│ Temporal Mean Pooling + Classification Head │
│ 输出: [B, num_classes] │
└───────────────────────────────────────────────────┘
各Stage配置
| Stage | 输出尺寸 | 窗口大小 | Block数 | 通道维度 |
|---|---|---|---|---|
| Token Embed | - | - | 96 | |
| Stage 1 | 2 | 192 | ||
| Stage 2 | 2 | 384 | ||
| Stage 3 | 6 | 768 | ||
| Stage 4 | 2 | 768 |
3D Patch Merging
Patch Merging在时间维度上也进行下采样:
class PatchMerging3D(nn.Module):
"""
3D Patch Merging:空间和时间同时下采样
"""
def __init__(self, dim):
super().__init__()
self.reduction = nn.Conv3d(
dim, dim * 2,
kernel_size=(2, 2, 2), # 时间×空间各下采样2倍
stride=(2, 2, 2)
)
self.norm = nn.LayerNorm(dim * 2)
def forward(self, x):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
# 重组: [B, T/2, H/2, W/2, 8*C]
x = x.reshape(B, T//2, 2, H//2, 2, W//2, 2, C)
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
x = x.reshape(B, T//2, H//2, W//2, 8*C)
x = self.reduction(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1)
x = self.norm(x)
return x # [B, T/2, H/2, W/2, 2*C]相对位置偏置
3D相对位置偏置
Video Swin Transformer扩展了2D相对位置偏置到3D:
其中 是可学习的3D相对位置偏置矩阵。
3D相对位置偏置参数:
- 空间维度:
- 时间维度:
- 总计: 个参数
class WindowAttention3D(nn.Module):
def __init__(self, window_size=(2, 7, 7), num_heads=8):
super().__init__()
self.window_size = window_size
self.num_heads = num_heads
# 3D相对位置偏置表
# (2*2-1) * (2*7-1) * (2*7-1) = 3 * 13 * 13 = 507
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2*window_size[0]-1) *
(2*window_size[1]-1) *
(2*window_size[2]-1),
num_heads)
)
# 计算相对位置索引
self.register_buffer(
"relative_position_index",
self._get_relative_position_index()
)
def _get_relative_position_index(self):
"""计算3D相对位置索引"""
coords_t = torch.arange(self.window_size[0])
coords_h = torch.arange(self.window_size[1])
coords_w = torch.arange(self.window_size[2])
# 生成网格
coords = torch.stack(
torch.meshgrid(coords_t, coords_h, coords_w, indexing='ij')
) # [3, P_t, P_h, P_w]
# 重组为序列
coords_flat = coords.reshape(3, -1) # [3, P_t*P_h*P_w]
# 计算相对位置
relative_pos = coords_flat[:, :, None] - coords_flat[:, None, :]
# [3, N, N]
# 转换为偏置表索引
relative_pos = relative_pos.permute(1, 2, 0).contiguous()
relative_pos[:, :, 0] += self.window_size[0] - 1
relative_pos[:, :, 1] += self.window_size[1] - 1
relative_pos[:, :, 2] += self.window_size[2] - 1
relative_pos_index = (
relative_pos[:, :, 0] *
(2*self.window_size[1]-1) *
(2*self.window_size[2]-1) +
relative_pos[:, :, 1] *
(2*self.window_size[2]-1) +
relative_pos[:, :, 2]
)
return relative_pos_index
def forward(self, x):
# ...
# 获取相对位置偏置
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] * self.window_size[2],
self.window_size[0] * self.window_size[1] * self.window_size[2],
-1
) # [N, N, num_heads]
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
# 添加偏置到注意力分数
attn = attn + relative_position_bias.unsqueeze(0)
# ...与TimeSformer的对比
方法论对比
| 方面 | TimeSformer | Video Swin Transformer |
|---|---|---|
| 注意力机制 | 分解放置注意力 | 3D移位窗口注意力 |
| 窗口设计 | 无窗口,全局计算 | 固定大小窗口 |
| 窗口间连接 | 无需特殊处理 | 移位机制连接 |
| 层次化结构 | 单尺度 | 4级层次化 |
| 多尺度建模 | 无 | 有(通过Patch Merging) |
| 位置编码 | 分离时空编码 | 3D相对位置偏置 |
计算复杂度对比
| 模型 | K400 Top-1 | FLOPs | 推理速度 |
|---|---|---|---|
| I3D-R50 | 76.3% | 108×G | 1× |
| SlowFast-R50 | 78.8% | 65×G | 1.5× |
| TimeSformer-B | 82.7% | 2380×G | 0.1× |
| Video Swin-B | 84.9% | 321×G | 0.4× |
结论:Video Swin Transformer在保持较高精度的同时,显著降低了计算量。
表达能力对比
Video Swin Transformer的优势:
- 层次化表示:多尺度特征有利于下游任务(检测、分割)
- 局部-全局建模:窗口内局部 + 跨窗口全局
- 平移不变性:相对位置偏置提供更好的泛化
TimeSformer的优势:
- 全局感受野:单层即可建模全视频依赖
- 实现简单:无需处理窗口移位和掩码
- 更适合分类任务:全局注意力直接聚合信息
实验结果
Kinetics-400
| 模型 | Top-1 | Top-5 | 参数量 |
|---|---|---|---|
| TSN | 72.4% | 90.4% | - |
| I3D | 74.3% | 91.4% | 25.0M |
| Non-Local I3D | 77.7% | 93.3% | 35.7M |
| SlowFast | 78.8% | 93.5% | 34.0M |
| TimeSformer | 82.7% | 95.4% | 121.4M |
| Video Swin-S | 83.6% | 95.9% | 49.6M |
| Video Swin-B | 84.9% | 96.4% | 88.8M |
Kinetics-600
| 模型 | Top-1 | Top-5 |
|---|---|---|
| SlowFast | 81.8% | 95.1% |
| TimeSformer | 85.0% | 96.4% |
| Video Swin-B | 86.4% | 97.2% |
消融实验
窗口大小的影响:
| 窗口大小 | K400 Top-1 |
|---|---|
| 84.9% | |
| 85.1% | |
| 85.0% |
移位窗口的效果:
| 配置 | K400 Top-1 |
|---|---|
| 仅规则窗口 | 83.2% |
| 规则+移位窗口 | 84.9% |
相对位置偏置的效果:
| 配置 | K400 Top-1 |
|---|---|
| 无相对位置偏置 | 83.5% |
| 2D相对位置偏置 | 84.2% |
| 3D相对位置偏置 | 84.9% |
下游任务应用
时序动作检测
Video Swin Transformer可以扩展用于时序动作检测任务:
- 特征提取:使用Video Swin作为视频编码器
- 时序建模:添加时序检测头
- 边界回归:预测动作边界
视频分割
层次化特征金字塔有利于密集预测任务:
- 多尺度特征:Stage 2-4提供不同分辨率特征
- 细粒度分割:浅层特征保留空间细节
- 语义分割:深层特征捕获高级语义
总结
Video Swin Transformer成功地将Swin Transformer的层次化设计和移位窗口注意力扩展到视频领域,其主要贡献包括:
- 3D移位窗口注意力:高效建模时空依赖
- 时空层次化结构:多尺度特征表示
- 3D相对位置偏置:增强位置感知能力
- 优秀的性能-效率权衡:在保持高精度的同时降低计算量
该工作为视频理解领域提供了一种高效且强大的架构选择。
参考文献
Footnotes
-
Liu Z, Ning J, Cao Y, et al. Video swin transformer[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022: 3202-3211. ↩