引言

现实世界中的视频往往持续数分钟甚至数小时,例如电影、体育赛事、监控录像等。**长视频理解(Long Video Understanding)**是视频分析领域最具挑战性的问题之一,需要在大量帧中捕捉长程时空依赖关系。

与短视频(通常8-32帧)不同,长视频带来了独特的挑战:计算复杂度呈线性或超线性增长、内存消耗巨大、长程依赖建模困难。本文系统介绍长视频理解的核心问题与解决方案。


长视频理解的挑战

计算复杂度问题

对于长度为 帧的视频,标准时空注意力的复杂度为:

其中 是每帧的patch数量。

问题演示

视频长度帧数 假设 Token数联合注意力 FLOPs
短视频81961,568~2.5M
中视频321966,272~39M
长视频 (1min@30fps)1800196352,800~124B
长视频 (10min@30fps)180001963,528,000~12.4T

即使是8帧的视频,全序列注意力也难以处理;30秒以上的视频几乎不可能用标准注意力处理。

内存限制

Transformer的注意力计算需要存储完整的注意力矩阵:

对于

  • 矩阵大小: 元素
  • 内存占用(float32):约 157MB
  • 对于 :约 40GB

这远远超出了GPU显存限制。

长程依赖建模困难

长视频中存在多种长程依赖:

  1. 因果依赖:视频开头的事件可能影响结尾
  2. 动作完整性:完整动作可能跨越数百帧
  3. 时序一致性:同一物体在不同时间点应保持一致
  4. 全局上下文:整体场景理解需要全视频信息

解决方案概览

长视频理解的主要技术路线:

方法类型核心思想代表工作
Memory-based使用外部记忆存储中间信息Memory Transformer, Memformer
Sparse Attention稀疏连接模式降低复杂度Longformer, BigBird
Hierarchical分层建模聚合全局信息Video Swin, 时序金字塔
Segment-based分段处理后聚合Clip-based, Segment-based
Compressive压缩信息减少tokenToken Merging, Token Learnt

Memory-based 方法

核心思想

**外部记忆(External Memory)**机制允许模型在处理长序列时”记住”之前的信息,而不是将所有历史都保存在激活中。

标准 Transformer:
┌──────────────────────────────────────┐
│  输入序列: [x₁, x₂, ..., x_T]        │
│  注意力: 直接访问所有历史位置          │
└──────────────────────────────────────┘

Memory Transformer:
┌──────────────────────────────────────┐
│  输入序列: [x₁, x₂, ..., x_T]        │
│  外部记忆: [m₁, m₂, ..., m_K]        │
│  注意力: 输入 ←→ 记忆 + 输入 ←→ 输入   │
└──────────────────────────────────────┘

Memory Transformer

Memory Transformer1 在标准Transformer中引入外部记忆模块:

class MemoryTransformer(nn.Module):
    """
    Memory Transformer: 带外部记忆的Transformer
    """
    def __init__(self, d_model, n_heads, memory_size, memory_dim):
        super().__init__()
        self.memory_size = memory_size
        
        # 记忆矩阵 (可学习或动态更新)
        self.memory = nn.Parameter(
            torch.randn(memory_size, memory_dim)
        )
        
        # 标准注意力
        self.self_attn = nn.MultiheadAttention(d_model, n_heads)
        
        # 记忆注意力 (输入查询记忆)
        self.memory_attn = nn.MultiheadAttention(d_model, n_heads)
    
    def forward(self, x):
        # x: [L, B, D] - 输入序列
        
        # 1. 自注意力 (标准)
        attn_out, _ = self.self_attn(x, x, x)
        
        # 2. 记忆注意力 (查询记忆)
        # Query来自输入, Key/Value来自记忆
        mem_out, _ = self.memory_attn(
            query=x,
            key=self.memory.expand(-1, x.size(1), -1),
            value=self.memory.expand(-1, x.size(1), -1)
        )
        
        return attn_out + mem_out

动态记忆更新

更先进的方法采用动态记忆更新策略:

class DynamicMemory(nn.Module):
    """
    动态记忆模块: 随时间更新记忆
    """
    def __init__(self, memory_size, d_model):
        super().__init__()
        self.memory_size = memory_size
        
        # 记忆状态
        self.memory = torch.zeros(memory_size, d_model)
        
        # 记忆更新门控
        self.update_gate = nn.Linear(d_model * 2, d_model)
        self.memory_proj = nn.Linear(d_model, d_model)
    
    def update(self, hidden_states, new_info):
        """
        更新记忆
        hidden_states: 当前隐藏状态
        new_info: 从输入中提取的新信息
        """
        # 聚合当前信息
        aggregated = hidden_states.mean(dim=0)  # [B, D]
        
        # 更新门控
        combined = torch.cat([self.memory, aggregated], dim=-1)
        gate = torch.sigmoid(self.update_gate(combined))
        
        # 选择性更新
        self.memory = (1 - gate) * self.memory + gate * self.memory_proj(aggregated)
        
        return self.memory

Memformer

Memformer2 是一种专门为长序列设计的高效Transformer:

  1. Memory-augmented Architecture:引入可读写的外部记忆
  2. O(NM)复杂度:其中是序列长度,是记忆大小
  3. 分层记忆:多层记忆捕获不同粒度的信息

Sparse Attention 稀疏注意力

核心思想

稀疏注意力通过限制每个token的连接模式,在保持部分全局建模能力的同时大幅降低复杂度。

标准全注意力:
┌─────────────┐
│ ● ● ● ● ● │  每个位置与所有位置连接
│ ● ● ● ● ● │  O(N²) 复杂度
│ ● ● ● ● ● │
│ ● ● ● ● ● │
│ ● ● ● ● ● │
└─────────────┘

稀疏注意力:
┌─────────────┐
│ ● ●   ●   ● │  局部连接 + 稀疏全局连接
│ ● ● ●       │  O(N·k) 复杂度 (k为邻居数)
│   ● ● ● ●   │
│ ●   ● ● ●   │
│ ●   ● ● ●   │
└─────────────┘

滑动窗口注意力

**局部注意力(Local Attention)**是最简单的稀疏模式:

class SlidingWindowAttention(nn.Module):
    """
    滑动窗口注意力: 每个位置只与窗口内的邻居交互
    """
    def __init__(self, window_size):
        self.window_size = window_size
    
    def forward(self, q, k, v):
        # q, k, v: [B, T, N, D]
        B, T, N, D = q.shape
        
        # 计算注意力分数
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)
        
        # 创建因果掩码 (可选)
        mask = torch.triu(
            torch.ones(T, T, device=q.device), 
            diagonal=1
        ).bool()
        
        # 创建窗口掩码
        # 距离超过window_size的位置掩蔽
        relative_positions = torch.arange(T, device=q.device).unsqueeze(0) - \
                           torch.arange(T, device=q.device).unsqueeze(1)
        window_mask = torch.abs(relative_positions) > self.window_size
        
        # 组合掩码
        final_mask = mask | window_mask
        attn = attn.masked_fill(final_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        
        # Softmax
        attn = F.softmax(attn, dim=-1)
        
        return torch.matmul(attn, v)

Longformer的稀疏模式

Longformer3 提出三种稀疏注意力模式:

  1. 局部窗口注意力(Local Attention)

    • 窗口大小
    • 复杂度
  2. 全局注意力(Global Attention)

    • 特殊token(如[CLS])与所有位置交互
    • 用于聚合全局信息
  3. 随机注意力(Random Attention)

    • 每个位置随机连接个位置
    • 提供全局信息流动
class LongformerAttention(nn.Module):
    """
    Longformer稀疏注意力实现
    """
    def __init__(self, d_model, n_heads, window_size=512, num_random=256):
        super().__init__()
        self.window_size = window_size
        self.num_random = num_random
        
        # 全局token索引 (需要全局注意力的位置)
        self.global_tokens = [0]  # [CLS] token
        
        self.attention = nn.MultiheadAttention(d_model, n_heads)
    
    def forward(self, x):
        B, T, N, D = x.shape
        
        # 构建稀疏注意力模式
        attention_mode = self._get_attention_mode(T)
        
        # 根据模式计算注意力
        if attention_mode == 'local':
            # 滑动窗口注意力
            attn = self._local_attention(x)
        elif attention_mode == 'global':
            # 全局注意力
            attn = self._global_attention(x)
        elif attention_mode == 'random':
            # 随机注意力
            attn = self._random_attention(x)
        
        return attn

BigBird的稀疏模式

BigBird4 提出类似Longformer的设计,但添加了更多组件:

BigBird稀疏模式:
┌─────────────────────────┐
│ ● ●   ●   ● ●   ●   ● │  w: 窗口注意力
│ ● ● ●   ●   ● ● ●     │  g: 全局注意力  
│ ●   ● ●   ● ●   ● ●   │  r: 随机注意力
│ ● ●   ●   ● ●   ●   ● │
│ ●   ● ● ●   ●   ● ●   │
└─────────────────────────┘

复杂度分析

  • 全注意力:
  • BigBird:

Hierarchical Modeling 分层建模

核心思想

**分层建模(Hierarchical Modeling)**将长视频分解为多个层次,逐步聚合信息:

视频 → 片段 → Clip → 视频级表示
  ↓       ↓      ↓        ↓
帧级    片段级   Clip级   视频级

时序金字塔

class TemporalPyramid(nn.Module):
    """
    时序金字塔: 多尺度时序建模
    """
    def __init__(self, d_model, pyramid_levels=[1, 2, 4]):
        super().__init__()
        self.pyramid_levels = pyramid_levels
        
        # 不同尺度的聚合器
        self.aggregators = nn.ModuleDict({
            str(level): TemporalAggregator(d_model, level)
            for level in pyramid_levels
        })
        
        # 融合层
        self.fusion = nn.Linear(d_model * len(pyramid_levels), d_model)
    
    def forward(self, x):
        """
        x: [B, T, N, D] - 时空特征
        """
        multi_scale_features = []
        
        for level in self.pyramid_levels:
            # 按level分组
            pooled = self.aggregators[str(level)](x, level)
            multi_scale_features.append(pooled)
        
        # 多尺度特征融合
        fused = torch.cat(multi_scale_features, dim=-1)
        return self.fusion(fused)
    
    def forward_with_grouping(self, x, level):
        """
        按level对时间维度分组聚合
        """
        B, T, N, D = x.shape
        
        # 确保T能被level整除
        T_padded = ((T + level - 1) // level) * level
        if T_padded > T:
            x = F.pad(x, (0, 0, 0, 0, 0, T_padded - T))
        
        # 重塑: [B, T//level, level, N, D]
        x = x.view(B, T_padded // level, level, N, D)
        
        # 聚合: [B, T//level, N, D]
        return x.mean(dim=2)

分段处理与聚合

class SegmentBasedVideoModel(nn.Module):
    """
    基于分段的视频模型: 分段处理后聚合
    """
    def __init__(self, segment_encoder, segment_aggregator):
        super().__init__()
        self.segment_encoder = segment_encoder  # 处理单个片段
        self.segment_aggregator = segment_aggregator  # 聚合片段特征
    
    def forward(self, video, num_segments=8):
        """
        video: [B, T, C, H, W]
        """
        B, T, C, H, W = video.shape
        
        # 均匀分段
        segment_length = T // num_segments
        segments = video.reshape(B, num_segments, segment_length, C, H, W)
        
        # 逐段编码
        segment_features = []
        for i in range(num_segments):
            feat = self.segment_encoder(segments[:, i])
            segment_features.append(feat)
        
        # Stack: [B, num_segments, D]
        segment_features = torch.stack(segment_features, dim=1)
        
        # 聚合为视频级表示
        video_repr = self.segment_aggregator(segment_features)
        
        return video_repr

滑动窗口方法

class SlidingWindowVideoModel(nn.Module):
    """
    滑动窗口视频模型: 窗口滑动 + 聚合
    """
    def __init__(self, encoder, aggregator, window_size=16, stride=8):
        super().__init__()
        self.window_size = window_size
        self.stride = stride
        self.encoder = encoder
        self.aggregator = aggregator
    
    def forward(self, video):
        """
        video: [B, T, C, H, W]
        """
        B, T, C, H, W = video.shape
        window_features = []
        
        # 滑动窗口
        for start in range(0, T - self.window_size + 1, self.stride):
            end = start + self.window_size
            window = video[:, start:end]
            
            # 编码窗口
            feat = self.encoder(window)
            window_features.append(feat)
        
        # 处理尾部 (如果视频长度不是stride的整数倍)
        if (T - self.window_size) % self.stride != 0:
            window = video[:, -self.window_size:]
            feat = self.encoder(window)
            window_features.append(feat)
        
        # Stack: [B, num_windows, D]
        window_features = torch.stack(window_features, dim=1)
        
        # 聚合窗口特征
        return self.aggregator(window_features)

Token 稀疏化

空间Token稀疏化

减少每帧的token数量:

class TokenMerging(nn.Module):
    """
    Token合并: 减少空间token数量
    """
    def forward(self, x, merge_ratio=2):
        # x: [B, T, N, D]
        B, T, N, D = x.shape
        
        # 重塑为网格
        H = W = int(math.sqrt(N))
        x = x.view(B, T, H, W, D)
        
        # 合并相邻token
        x = x.reshape(B, T, H // merge_ratio, merge_ratio, 
                      W // merge_ratio, merge_ratio, D)
        x = x.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
        x = x.reshape(B, T, H // merge_ratio, 
                      W // merge_ratio, merge_ratio * merge_ratio * D)
        
        # 线性投影恢复维度
        return self.proj(x)

时间Token稀疏化

减少时间维度的token:

class TemporalTokenSampling(nn.Module):
    """
    时间Token采样: 动态选择重要帧
    """
    def __init__(self, d_model):
        super().__init__()
        self.importance_scorer = nn.Linear(d_model, 1)
    
    def forward(self, x, num_samples=None):
        # x: [B, T, N, D]
        B, T, N, D = x.shape
        
        # 计算每帧的重要性分数
        # 聚合空间维度
        x_pooled = x.mean(dim=2)  # [B, T, D]
        importance = self.importance_scorer(x_pooled).squeeze(-1)  # [B, T]
        
        # 软采样 (使用注意力权重)
        if num_samples is not None:
            weights = F.softmax(importance, dim=1)
            
            # 展开为权重分布
            weights = weights.unsqueeze(-1).unsqueeze(-1)  # [B, T, 1, 1]
            sampled = (x * weights).sum(dim=1)  # [B, N, D]
            
            return sampled, weights
        else:
            return importance

位置编码扩展

相对位置编码

对于长视频,相对位置编码比绝对位置编码更适合:

class ExtendedRelativePosition(nn.Module):
    """
    扩展相对位置编码: 支持长序列
    """
    def __init__(self, max_len, num_heads):
        super().__init__()
        self.max_len = max_len
        self.num_heads = num_heads
        
        # 相对位置偏置表
        # 使用对数间隔,支持更长距离
        self.relative_attention_bias = nn.Parameter(
            torch.zeros(2 * max_len - 1, num_heads)
        )
        
        # 可学习的距离衰减
        self.distance_decay = nn.Parameter(torch.ones(1))
    
    def forward(self, seq_len):
        # 生成相对位置索引
        position_ids = torch.arange(seq_len)
        relative_pos = position_ids.unsqueeze(0) - position_ids.unsqueeze(1)
        relative_pos = relative_pos + self.max_len - 1
        
        # 对数距离变换
        log_distance = torch.log(1 + torch.abs(relative_pos.float()))
        scaled_pos = (log_distance * self.distance_decay).long()
        scaled_pos = torch.clamp(scaled_pos, 0, 2 * self.max_len - 2)
        
        # 获取偏置
        relative_bias = self.relative_attention_bias[scaled_pos]
        
        return relative_bias.permute(2, 0, 1)  # [num_heads, seq_len, seq_len]

方法对比与选择

复杂度对比

方法时间复杂度空间复杂度长程建模实现复杂度
标准注意力最强
滑动窗口
Memory-based中等
分层建模
稀疏注意力中等

其中 是窗口大小, 是记忆大小, 是局部窗口, 是随机连接数。

选择指南

视频长度推荐方法理由
< 1分钟标准Transformer足够处理
1-10分钟滑动窗口 + 分层简单有效
10-60分钟Memory-based平衡效率和建模
> 1小时稀疏注意力 + 分层必须稀疏化

总结

长视频理解是视频分析领域的重要挑战,本文介绍了四种主要的解决方案:

  1. Memory-based方法:通过外部记忆存储和检索信息,降低激活内存
  2. Sparse Attention:通过稀疏连接模式,将复杂度降为
  3. Hierarchical Modeling:分层聚合多尺度信息,捕获不同粒度的依赖
  4. Token稀疏化:减少输入token数量,降低总体计算量

这些方法可以组合使用,例如在分层建模的每一层使用滑动窗口注意力,以进一步降低计算复杂度。


参考文献

Footnotes

  1. Wu K, Peng H, Zhou Z, et al. Memformer: A memory-augmented transformer for efficient long-range sequence modeling[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2022.

  2. Pang J, Zhang C, Yu H, et al. Hierarchical memory for long video modeling[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.

  3. Beltagy I, Peters M E, Cohan A. Longformer: The long-document transformer[J]. arXiv preprint arXiv:2004.05150, 2020.

  4. Zaheer M, Guruganesh G, Dubey K A, et al. Big bird: Transformers for longer sequences[J]. Advances in Neural Information Processing Systems, 2020, 33: 17283-17297.