引言
现实世界中的视频往往持续数分钟甚至数小时,例如电影、体育赛事、监控录像等。**长视频理解(Long Video Understanding)**是视频分析领域最具挑战性的问题之一,需要在大量帧中捕捉长程时空依赖关系。
与短视频(通常8-32帧)不同,长视频带来了独特的挑战:计算复杂度呈线性或超线性增长、内存消耗巨大、长程依赖建模困难。本文系统介绍长视频理解的核心问题与解决方案。
长视频理解的挑战
计算复杂度问题
对于长度为 帧的视频,标准时空注意力的复杂度为:
其中 是每帧的patch数量。
问题演示:
| 视频长度 | 帧数 | 假设 | Token数 | 联合注意力 FLOPs |
|---|---|---|---|---|
| 短视频 | 8 | 196 | 1,568 | ~2.5M |
| 中视频 | 32 | 196 | 6,272 | ~39M |
| 长视频 (1min@30fps) | 1800 | 196 | 352,800 | ~124B |
| 长视频 (10min@30fps) | 18000 | 196 | 3,528,000 | ~12.4T |
即使是8帧的视频,全序列注意力也难以处理;30秒以上的视频几乎不可能用标准注意力处理。
内存限制
Transformer的注意力计算需要存储完整的注意力矩阵:
对于 :
- 矩阵大小: 元素
- 内存占用(float32):约 157MB
- 对于 :约 40GB
这远远超出了GPU显存限制。
长程依赖建模困难
长视频中存在多种长程依赖:
- 因果依赖:视频开头的事件可能影响结尾
- 动作完整性:完整动作可能跨越数百帧
- 时序一致性:同一物体在不同时间点应保持一致
- 全局上下文:整体场景理解需要全视频信息
解决方案概览
长视频理解的主要技术路线:
| 方法类型 | 核心思想 | 代表工作 |
|---|---|---|
| Memory-based | 使用外部记忆存储中间信息 | Memory Transformer, Memformer |
| Sparse Attention | 稀疏连接模式降低复杂度 | Longformer, BigBird |
| Hierarchical | 分层建模聚合全局信息 | Video Swin, 时序金字塔 |
| Segment-based | 分段处理后聚合 | Clip-based, Segment-based |
| Compressive | 压缩信息减少token | Token Merging, Token Learnt |
Memory-based 方法
核心思想
**外部记忆(External Memory)**机制允许模型在处理长序列时”记住”之前的信息,而不是将所有历史都保存在激活中。
标准 Transformer:
┌──────────────────────────────────────┐
│ 输入序列: [x₁, x₂, ..., x_T] │
│ 注意力: 直接访问所有历史位置 │
└──────────────────────────────────────┘
Memory Transformer:
┌──────────────────────────────────────┐
│ 输入序列: [x₁, x₂, ..., x_T] │
│ 外部记忆: [m₁, m₂, ..., m_K] │
│ 注意力: 输入 ←→ 记忆 + 输入 ←→ 输入 │
└──────────────────────────────────────┘
Memory Transformer
Memory Transformer1 在标准Transformer中引入外部记忆模块:
class MemoryTransformer(nn.Module):
"""
Memory Transformer: 带外部记忆的Transformer
"""
def __init__(self, d_model, n_heads, memory_size, memory_dim):
super().__init__()
self.memory_size = memory_size
# 记忆矩阵 (可学习或动态更新)
self.memory = nn.Parameter(
torch.randn(memory_size, memory_dim)
)
# 标准注意力
self.self_attn = nn.MultiheadAttention(d_model, n_heads)
# 记忆注意力 (输入查询记忆)
self.memory_attn = nn.MultiheadAttention(d_model, n_heads)
def forward(self, x):
# x: [L, B, D] - 输入序列
# 1. 自注意力 (标准)
attn_out, _ = self.self_attn(x, x, x)
# 2. 记忆注意力 (查询记忆)
# Query来自输入, Key/Value来自记忆
mem_out, _ = self.memory_attn(
query=x,
key=self.memory.expand(-1, x.size(1), -1),
value=self.memory.expand(-1, x.size(1), -1)
)
return attn_out + mem_out动态记忆更新
更先进的方法采用动态记忆更新策略:
class DynamicMemory(nn.Module):
"""
动态记忆模块: 随时间更新记忆
"""
def __init__(self, memory_size, d_model):
super().__init__()
self.memory_size = memory_size
# 记忆状态
self.memory = torch.zeros(memory_size, d_model)
# 记忆更新门控
self.update_gate = nn.Linear(d_model * 2, d_model)
self.memory_proj = nn.Linear(d_model, d_model)
def update(self, hidden_states, new_info):
"""
更新记忆
hidden_states: 当前隐藏状态
new_info: 从输入中提取的新信息
"""
# 聚合当前信息
aggregated = hidden_states.mean(dim=0) # [B, D]
# 更新门控
combined = torch.cat([self.memory, aggregated], dim=-1)
gate = torch.sigmoid(self.update_gate(combined))
# 选择性更新
self.memory = (1 - gate) * self.memory + gate * self.memory_proj(aggregated)
return self.memoryMemformer
Memformer2 是一种专门为长序列设计的高效Transformer:
- Memory-augmented Architecture:引入可读写的外部记忆
- O(NM)复杂度:其中是序列长度,是记忆大小
- 分层记忆:多层记忆捕获不同粒度的信息
Sparse Attention 稀疏注意力
核心思想
稀疏注意力通过限制每个token的连接模式,在保持部分全局建模能力的同时大幅降低复杂度。
标准全注意力:
┌─────────────┐
│ ● ● ● ● ● │ 每个位置与所有位置连接
│ ● ● ● ● ● │ O(N²) 复杂度
│ ● ● ● ● ● │
│ ● ● ● ● ● │
│ ● ● ● ● ● │
└─────────────┘
稀疏注意力:
┌─────────────┐
│ ● ● ● ● │ 局部连接 + 稀疏全局连接
│ ● ● ● │ O(N·k) 复杂度 (k为邻居数)
│ ● ● ● ● │
│ ● ● ● ● │
│ ● ● ● ● │
└─────────────┘
滑动窗口注意力
**局部注意力(Local Attention)**是最简单的稀疏模式:
class SlidingWindowAttention(nn.Module):
"""
滑动窗口注意力: 每个位置只与窗口内的邻居交互
"""
def __init__(self, window_size):
self.window_size = window_size
def forward(self, q, k, v):
# q, k, v: [B, T, N, D]
B, T, N, D = q.shape
# 计算注意力分数
attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)
# 创建因果掩码 (可选)
mask = torch.triu(
torch.ones(T, T, device=q.device),
diagonal=1
).bool()
# 创建窗口掩码
# 距离超过window_size的位置掩蔽
relative_positions = torch.arange(T, device=q.device).unsqueeze(0) - \
torch.arange(T, device=q.device).unsqueeze(1)
window_mask = torch.abs(relative_positions) > self.window_size
# 组合掩码
final_mask = mask | window_mask
attn = attn.masked_fill(final_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
# Softmax
attn = F.softmax(attn, dim=-1)
return torch.matmul(attn, v)Longformer的稀疏模式
Longformer3 提出三种稀疏注意力模式:
-
局部窗口注意力(Local Attention)
- 窗口大小
- 复杂度
-
全局注意力(Global Attention)
- 特殊token(如[CLS])与所有位置交互
- 用于聚合全局信息
-
随机注意力(Random Attention)
- 每个位置随机连接个位置
- 提供全局信息流动
class LongformerAttention(nn.Module):
"""
Longformer稀疏注意力实现
"""
def __init__(self, d_model, n_heads, window_size=512, num_random=256):
super().__init__()
self.window_size = window_size
self.num_random = num_random
# 全局token索引 (需要全局注意力的位置)
self.global_tokens = [0] # [CLS] token
self.attention = nn.MultiheadAttention(d_model, n_heads)
def forward(self, x):
B, T, N, D = x.shape
# 构建稀疏注意力模式
attention_mode = self._get_attention_mode(T)
# 根据模式计算注意力
if attention_mode == 'local':
# 滑动窗口注意力
attn = self._local_attention(x)
elif attention_mode == 'global':
# 全局注意力
attn = self._global_attention(x)
elif attention_mode == 'random':
# 随机注意力
attn = self._random_attention(x)
return attnBigBird的稀疏模式
BigBird4 提出类似Longformer的设计,但添加了更多组件:
BigBird稀疏模式:
┌─────────────────────────┐
│ ● ● ● ● ● ● ● │ w: 窗口注意力
│ ● ● ● ● ● ● ● │ g: 全局注意力
│ ● ● ● ● ● ● ● │ r: 随机注意力
│ ● ● ● ● ● ● ● │
│ ● ● ● ● ● ● ● │
└─────────────────────────┘
复杂度分析:
- 全注意力:
- BigBird:
Hierarchical Modeling 分层建模
核心思想
**分层建模(Hierarchical Modeling)**将长视频分解为多个层次,逐步聚合信息:
视频 → 片段 → Clip → 视频级表示
↓ ↓ ↓ ↓
帧级 片段级 Clip级 视频级
时序金字塔
class TemporalPyramid(nn.Module):
"""
时序金字塔: 多尺度时序建模
"""
def __init__(self, d_model, pyramid_levels=[1, 2, 4]):
super().__init__()
self.pyramid_levels = pyramid_levels
# 不同尺度的聚合器
self.aggregators = nn.ModuleDict({
str(level): TemporalAggregator(d_model, level)
for level in pyramid_levels
})
# 融合层
self.fusion = nn.Linear(d_model * len(pyramid_levels), d_model)
def forward(self, x):
"""
x: [B, T, N, D] - 时空特征
"""
multi_scale_features = []
for level in self.pyramid_levels:
# 按level分组
pooled = self.aggregators[str(level)](x, level)
multi_scale_features.append(pooled)
# 多尺度特征融合
fused = torch.cat(multi_scale_features, dim=-1)
return self.fusion(fused)
def forward_with_grouping(self, x, level):
"""
按level对时间维度分组聚合
"""
B, T, N, D = x.shape
# 确保T能被level整除
T_padded = ((T + level - 1) // level) * level
if T_padded > T:
x = F.pad(x, (0, 0, 0, 0, 0, T_padded - T))
# 重塑: [B, T//level, level, N, D]
x = x.view(B, T_padded // level, level, N, D)
# 聚合: [B, T//level, N, D]
return x.mean(dim=2)分段处理与聚合
class SegmentBasedVideoModel(nn.Module):
"""
基于分段的视频模型: 分段处理后聚合
"""
def __init__(self, segment_encoder, segment_aggregator):
super().__init__()
self.segment_encoder = segment_encoder # 处理单个片段
self.segment_aggregator = segment_aggregator # 聚合片段特征
def forward(self, video, num_segments=8):
"""
video: [B, T, C, H, W]
"""
B, T, C, H, W = video.shape
# 均匀分段
segment_length = T // num_segments
segments = video.reshape(B, num_segments, segment_length, C, H, W)
# 逐段编码
segment_features = []
for i in range(num_segments):
feat = self.segment_encoder(segments[:, i])
segment_features.append(feat)
# Stack: [B, num_segments, D]
segment_features = torch.stack(segment_features, dim=1)
# 聚合为视频级表示
video_repr = self.segment_aggregator(segment_features)
return video_repr滑动窗口方法
class SlidingWindowVideoModel(nn.Module):
"""
滑动窗口视频模型: 窗口滑动 + 聚合
"""
def __init__(self, encoder, aggregator, window_size=16, stride=8):
super().__init__()
self.window_size = window_size
self.stride = stride
self.encoder = encoder
self.aggregator = aggregator
def forward(self, video):
"""
video: [B, T, C, H, W]
"""
B, T, C, H, W = video.shape
window_features = []
# 滑动窗口
for start in range(0, T - self.window_size + 1, self.stride):
end = start + self.window_size
window = video[:, start:end]
# 编码窗口
feat = self.encoder(window)
window_features.append(feat)
# 处理尾部 (如果视频长度不是stride的整数倍)
if (T - self.window_size) % self.stride != 0:
window = video[:, -self.window_size:]
feat = self.encoder(window)
window_features.append(feat)
# Stack: [B, num_windows, D]
window_features = torch.stack(window_features, dim=1)
# 聚合窗口特征
return self.aggregator(window_features)Token 稀疏化
空间Token稀疏化
减少每帧的token数量:
class TokenMerging(nn.Module):
"""
Token合并: 减少空间token数量
"""
def forward(self, x, merge_ratio=2):
# x: [B, T, N, D]
B, T, N, D = x.shape
# 重塑为网格
H = W = int(math.sqrt(N))
x = x.view(B, T, H, W, D)
# 合并相邻token
x = x.reshape(B, T, H // merge_ratio, merge_ratio,
W // merge_ratio, merge_ratio, D)
x = x.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
x = x.reshape(B, T, H // merge_ratio,
W // merge_ratio, merge_ratio * merge_ratio * D)
# 线性投影恢复维度
return self.proj(x)时间Token稀疏化
减少时间维度的token:
class TemporalTokenSampling(nn.Module):
"""
时间Token采样: 动态选择重要帧
"""
def __init__(self, d_model):
super().__init__()
self.importance_scorer = nn.Linear(d_model, 1)
def forward(self, x, num_samples=None):
# x: [B, T, N, D]
B, T, N, D = x.shape
# 计算每帧的重要性分数
# 聚合空间维度
x_pooled = x.mean(dim=2) # [B, T, D]
importance = self.importance_scorer(x_pooled).squeeze(-1) # [B, T]
# 软采样 (使用注意力权重)
if num_samples is not None:
weights = F.softmax(importance, dim=1)
# 展开为权重分布
weights = weights.unsqueeze(-1).unsqueeze(-1) # [B, T, 1, 1]
sampled = (x * weights).sum(dim=1) # [B, N, D]
return sampled, weights
else:
return importance位置编码扩展
相对位置编码
对于长视频,相对位置编码比绝对位置编码更适合:
class ExtendedRelativePosition(nn.Module):
"""
扩展相对位置编码: 支持长序列
"""
def __init__(self, max_len, num_heads):
super().__init__()
self.max_len = max_len
self.num_heads = num_heads
# 相对位置偏置表
# 使用对数间隔,支持更长距离
self.relative_attention_bias = nn.Parameter(
torch.zeros(2 * max_len - 1, num_heads)
)
# 可学习的距离衰减
self.distance_decay = nn.Parameter(torch.ones(1))
def forward(self, seq_len):
# 生成相对位置索引
position_ids = torch.arange(seq_len)
relative_pos = position_ids.unsqueeze(0) - position_ids.unsqueeze(1)
relative_pos = relative_pos + self.max_len - 1
# 对数距离变换
log_distance = torch.log(1 + torch.abs(relative_pos.float()))
scaled_pos = (log_distance * self.distance_decay).long()
scaled_pos = torch.clamp(scaled_pos, 0, 2 * self.max_len - 2)
# 获取偏置
relative_bias = self.relative_attention_bias[scaled_pos]
return relative_bias.permute(2, 0, 1) # [num_heads, seq_len, seq_len]方法对比与选择
复杂度对比
| 方法 | 时间复杂度 | 空间复杂度 | 长程建模 | 实现复杂度 |
|---|---|---|---|---|
| 标准注意力 | 最强 | 低 | ||
| 滑动窗口 | 弱 | 低 | ||
| Memory-based | 中等 | 中 | ||
| 分层建模 | 强 | 中 | ||
| 稀疏注意力 | 中等 | 中 |
其中 是窗口大小, 是记忆大小, 是局部窗口, 是随机连接数。
选择指南
| 视频长度 | 推荐方法 | 理由 |
|---|---|---|
| < 1分钟 | 标准Transformer | 足够处理 |
| 1-10分钟 | 滑动窗口 + 分层 | 简单有效 |
| 10-60分钟 | Memory-based | 平衡效率和建模 |
| > 1小时 | 稀疏注意力 + 分层 | 必须稀疏化 |
总结
长视频理解是视频分析领域的重要挑战,本文介绍了四种主要的解决方案:
- Memory-based方法:通过外部记忆存储和检索信息,降低激活内存
- Sparse Attention:通过稀疏连接模式,将复杂度降为
- Hierarchical Modeling:分层聚合多尺度信息,捕获不同粒度的依赖
- Token稀疏化:减少输入token数量,降低总体计算量
这些方法可以组合使用,例如在分层建模的每一层使用滑动窗口注意力,以进一步降低计算复杂度。
参考文献
Footnotes
-
Wu K, Peng H, Zhou Z, et al. Memformer: A memory-augmented transformer for efficient long-range sequence modeling[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2022. ↩
-
Pang J, Zhang C, Yu H, et al. Hierarchical memory for long video modeling[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022. ↩
-
Beltagy I, Peters M E, Cohan A. Longformer: The long-document transformer[J]. arXiv preprint arXiv:2004.05150, 2020. ↩
-
Zaheer M, Guruganesh G, Dubey K A, et al. Big bird: Transformers for longer sequences[J]. Advances in Neural Information Processing Systems, 2020, 33: 17283-17297. ↩