概述

视频检索增强生成(Video Retrieval Augmented Generation, VRAG)是 Chen 等人在 NeurIPS 2025 上提出的一种创新方法,旨在解决交互式视频世界模型中的核心挑战。1 该方法通过显式全局状态条件化和记忆检索机制,显著减少自回归视频生成中的复合误差(compounding errors)问题,同时提升世界模型的时空一致性(spatio-temporal consistency)

VRAG 的核心洞察在于:与大型语言模型(LLM)不同,当前的视频扩散模型缺乏强大的上下文学习能力,这使得直接应用 RAG 技术或扩展上下文窗口的方法效果有限。VRAG 通过一系列针对性训练策略改进,使视频模型能够有效利用历史帧信息进行一致的长时序视频生成。

相关工作:VRAG 与 世界模型规划推理融合 密切相关,因为视频世界模型是实现物理世界模拟和智能规划的基础。此外,VRAG 也涉及 长视频理解 中的长程依赖建模问题。


1. 问题背景

1.1 世界模型的核心要求

基础世界模型(Foundational World Models)需要同时满足两个核心要求:

要求描述重要性
可交互性(Interactivity)能够根据动作序列进行条件化生成支持智能体的决策与规划
时空一致性(Spatio-temporal Consistency)在长时间跨度内保持物体身份、空间布局和世界状态的连贯性实现可信的环境模拟

近年来,视频扩散模型在图像和视频生成方面取得了显著进展。然而,将这些模型扩展为支持动作条件化的长时序交互式世界模型仍然面临严峻挑战。1

1.2 自回归视频生成的两大根本性问题

基于自回归范式的视频生成方法面临两个相互耦合的根本性限制:

问题一:复合误差(Compounding Errors)

复合误差是指在自回归生成过程中,早期帧的微小预测误差会随时间累积,导致生成的视频与合理的未来状态产生显著偏离。

这种误差具有以下特点:

  • 累积性:每一帧的误差都会传递给下一帧,形成误差链
  • 不可逆性:一旦产生偏差,后续帧难以自我纠正
  • 本质性:作者论证这可能是自回归范式固有的问题,难以完全消除

问题二:记忆机制不足(Insufficient Memory)

自回归模型在保持长期一致性方面存在根本性缺陷:

  • 物体身份丢失:难以在整个视频序列中保持一致的对象标识
  • 空间布局漂移:场景中的物体位置关系逐渐偏离初始状态
  • 世界状态不一致:环境属性(如光照、天气)难以维持连续性

这两个问题往往相互加剧,形成恶性循环,使得长时序一致性生成变得极为困难。1


2. 技术方法

2.1 基础框架:潜在视频扩散模型

VRAG 构建在**潜在视频扩散模型(Latent Video Diffusion Model)**的基础之上。给定输入视频序列 ,首先使用预训练的变分自编码器(VAE)将其编码为潜在表示

前向扩散过程按照方差调度 逐渐向潜在表示添加高斯噪声:

模型通过预测每一步的噪声 来学习去噪过程:

2.2 动作条件化(Action Conditioning)

为实现交互式视频生成,VRAG 通过**自适应层归一化(Adaptive Layer Normalization, AdaLN)**引入动作条件。具体而言:

给定动作序列 ,首先通过可学习的嵌入层将其转换:

对于扩散模型中的每个归一化层,通过线性投影学习动作依赖的缩放和偏移参数:

最终的自适应层归一化定义为:

其中 表示中间特征图, 表示逐元素乘法。

2.3 自回归视频生成:扩散强制(Diffusion Forcing)

为实现长视频生成,VRAG 采用自回归方法,每步基于固定长度的上下文窗口 生成后续帧。在训练阶段,应用**扩散强制(Diffusion Forcing)**技术:

对输入视频序列的每一帧,根据扩散调度独立添加不同水平的噪声:

其中 表示第 帧的带噪潜在表示。这迫使模型能够鲁棒地处理上下文帧中的噪声,防止其过度依赖上下文信息。

动作条件化自回归视频模型的训练目标定义为:

2.4 架构:时空扩散 Transformer

VRAG 采用**时空扩散 Transformer(SpatioTemporal DiT)**作为核心架构:

  • 空间-时间 DiT 块:包含分离的空间注意力和时间注意力模块
  • 旋转位置编码(RoPE):在空间和时间注意力中均应用 RoPE
  • 因果时间注意力:时间注意力采用因果掩码,确保自回归特性

3. VRAG 核心方法

3.1 全局状态条件化(Global State Conditioning)

为增强空间一致性,VRAG 引入全局状态信息作为额外条件信号。全局状态向量 包含两个关键分量:

  • :表示角色的 3D 位置坐标
  • :表示角色的方向角度

给定动作序列 和全局状态序列 ,通过可学习的嵌入层进行变换:

生成的嵌入特征输入到扩散模型的 AdaLN 层,实现空间感知的视频生成。

3.2 视频检索增强生成(VRAG)

超越基础条件化,VRAG 提出记忆检索增强生成,以增强模型利用历史上下文同时维持时间一致性的能力。

3.2.1 检索机制

历史帧从固定长度的缓冲区 中检索,缓冲区存储先前生成的帧。基于全局状态的相似度评分定义为:

其中:

  • 是距离度量(如欧氏距离)
  • 是待预测帧的全局状态
  • 是权重向量,调节不同状态组件的重要性

选择与当前状态最相似的 个历史帧及其状态,组成检索上下文。

3.2.2 VRAG 与标准 RAG 的关键差异

与 LLM 中的 RAG 不同,视频扩散模型的上下文学习能力较弱,直接将历史帧作为上下文进行推理效果不佳。VRAG 通过以下四项关键改进实现有效的记忆增强视频生成:

改进描述目的
时间偏移(Temporal Offset)为检索帧的 RoPE 嵌入添加时间偏移量 区分检索帧与正常上下文帧
部分去噪(Partial Denoising)对检索帧应用较低噪声水平 模拟推理时部分去噪的历史帧,增强鲁棒性
损失掩码(Loss Masking)对检索帧掩码扩散损失 确保模型专注于当前上下文去噪
选择性条件化(Selective Conditioning)检索帧仅基于全局状态 ,掩码动作条件 避免动作序列的时间不连续性

3.2.3 训练目标

VRAG 的训练目标定义为:

其中:


4. 与其他方法的对比分析

4.1 为什么朴素方法效果不佳?

4.1.1 扩展上下文窗口的局限性

受 LLM 成功经验的启发,研究者尝试扩展视频模型的上下文窗口。然而:

  • 计算开销巨大:更长的上下文窗口带来超线性增长的计算和内存开销
  • 视频模型缺乏强上下文学习能力:视频扩散模型不像 LLM 那样具备强大的上下文学习能力
  • 更长上下文对一致性改善有限:单纯增加上下文长度难以解决根本的一致性问题

4.1.2 标准 RAG 的局限性

**检索增强生成(RAG)**在 LLM 中被证明是引入外部知识的有效技术。然而:

  • 视频模型上下文学习能力弱:视频扩散模型难以有效利用检索到的历史帧作为条件
  • 像素级表示的内在局限性:仅从像素或潜在表示中隐式学习世界一致性是不够的
  • 缺乏显式状态建模:没有显式的全局状态条件化,模型难以建立稳定的世界表示

4.1.3 神经记忆增强的局限

研究者还尝试了基于神经记忆模块的方法,但同样面临:

  • 记忆整合困难:如何有效整合外部记忆信息到视频生成过程
  • 长期依赖建模不足:标准注意力机制难以捕捉超长距离的依赖关系

4.2 方法对比总结

方法原理优势局限性
VRAG(本文)显式全局状态条件化 + 检索增强生成显著减少复合误差,提升时空一致性需要额外的状态标注
扩散强制(Diffusion Forcing)每帧独立噪声调度支持自回归采样,无需教师强制复合误差仍然存在
历史缓冲区(History Buffer)存储并利用历史帧简单直接上下文学习能力不足
YaRN 上下文扩展位置编码插值扩展上下文窗口视频模型上下文学习能力弱
纯动作条件化仅基于动作序列支持交互式生成缺乏全局一致性建模

5. PyTorch 实现

以下是基于 VRAG 核心思想的简化 PyTorch 实现示例:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Optional, Tuple
 
class VRAGModel(nn.Module):
    """
    VRAG: Video Retrieval Augmented Generation Model
    
    核心组件:
    1. 时空扩散 Transformer (SpatioTemporal DiT)
    2. 全局状态条件化 (Global State Conditioning)
    3. 检索增强注意力 (Retrieval-Augmented Attention)
    """
    
    def __init__(
        self,
        latent_dim: int = 16,
        hidden_size: int = 1024,
        num_heads: int = 16,
        num_layers: int = 12,
        action_dim: int = 10,
        state_dim: int = 6,  # 位置(3) + 朝向(3)
        retrieve_num: int = 10,
        context_frames: int = 20,
    ):
        super().__init__()
        
        self.retrieve_num = retrieve_num
        self.context_frames = context_frames
        self.state_dim = state_dim
        self.action_dim = action_dim
        
        # 嵌入层
        self.action_embed = nn.Linear(action_dim, hidden_size)
        self.state_embed = nn.Linear(state_dim, hidden_size)
        self.time_embed = TimestepEmbedder(hidden_size)
        
        # 外部条件嵌入(动作 + 状态)
        self.external_cond_dim = action_dim + state_dim
        self.external_cond_proj = nn.Linear(
            self.external_cond_dim, hidden_size
        )
        
        # 空间-时间 DiT 块
        self.blocks = nn.ModuleList([
            SpatioTemporalDiTBlock(
                hidden_size=hidden_size,
                num_heads=num_heads,
                retrieve_num=retrieve_num,
            )
            for _ in range(num_layers)
        ])
        
        # 输出层
        self.final_layer = FinalLayer(hidden_size, latent_dim)
        
    def forward(
        self,
        x: torch.Tensor,                    # (B, T, C, H, W) 当前上下文帧
        t: torch.Tensor,                    # (B, T,) 扩散时间步
        actions: torch.Tensor,              # (B, T, A) 动作序列
        states: torch.Tensor,               # (B, T, S) 全局状态序列
        retrieved_frames: Optional[torch.Tensor] = None,  # (B, L_h, C, H, W) 检索帧
        retrieved_states: Optional[torch.Tensor] = None,  # (B, L_h, S) 检索状态
    ) -> torch.Tensor:
        """
        VRAG 前向传播
        
        Args:
            x: 当前上下文帧的潜在表示
            t: 扩散时间步
            actions: 当前帧的动作序列
            states: 当前帧的全局状态
            retrieved_frames: 检索到的历史帧
            retrieved_states: 检索帧对应的全局状态
            
        Returns:
            预测的噪声或速度场
        """
        B, T, C, H, W = x.shape
        
        # 步骤1: 构建扩展上下文
        if retrieved_frames is not None:
            # 检索帧与当前帧拼接
            x = torch.cat([retrieved_frames, x], dim=1)  # (B, L_h+T, C, H, W)
            states = torch.cat([retrieved_states, states], dim=1)  # (B, L_h+T, S)
            # 动作序列中检索部分置零(选择性条件化)
            actions_retrieved = torch.zeros(
                B, retrieved_frames.shape[1], self.action_dim,
                device=x.device, dtype=x.dtype
            )
            actions = torch.cat([actions_retrieved, actions], dim=1)
            t = torch.cat([
                t[:, :1].expand(-1, retrieved_frames.shape[1]), t
            ], dim=1)
        
        # 步骤2: 嵌入时间步
        t_emb = self.time_embed(t.flatten())  # (B*T,)
        t_emb = rearrange(t_emb, "(b t) d -> b t d", b=B)
        
        # 步骤3: 嵌入外部条件(动作 + 状态)
        cond = torch.cat([actions, states], dim=-1)  # (B, T_total, A+S)
        cond_emb = self.external_cond_proj(cond)
        cond_emb = rearrange(cond_emb, "b t d -> b t d")
        
        # 步骤4: 组合条件
        c = t_emb + cond_emb  # (B, T_total, D)
        
        # 步骤5: 通过 DiT 块处理
        x = rearrange(x, "b t c h w -> b t h w c")
        for block in self.blocks:
            x = block(x, c)
        
        # 步骤6: 输出层
        x = self.final_layer(x, c)
        x = rearrange(x, "b t h w c -> b t c h w")
        
        # 如果有检索帧,只返回当前上下文部分
        if retrieved_frames is not None:
            x = x[:, retrieved_frames.shape[1]:]
        
        return x
 
 
class SpatioTemporalDiTBlock(nn.Module):
    """
    时空扩散 Transformer 块
    
    包含:
    - 空间注意力 (Spatial Attention)
    - 时间注意力 (Temporal Attention) 带检索增强
    - AdaLN 条件化
    """
    
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        retrieve_num: int = 10,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        self.retrieve_num = retrieve_num
        
        # 空间模块
        self.s_norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.s_attn = SpatialAttention(hidden_size, num_heads)
        self.s_norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.s_mlp = Mlp(hidden_size, int(hidden_size * mlp_ratio))
        self.s_adaLN = nn.Sequential(
            nn.SiLU(), 
            nn.Linear(hidden_size, 6 * hidden_size)
        )
        
        # 时间模块(带检索增强)
        self.t_norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.t_attn = TemporalRetrievalAttention(
            hidden_size, num_heads, retrieve_num=retrieve_num
        )
        self.t_norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.t_mlp = Mlp(hidden_size, int(hidden_size * mlp_ratio))
        self.t_adaLN = nn.Sequential(
            nn.SiLU(), 
            nn.Linear(hidden_size, 6 * hidden_size)
        )
        
    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T, H, W, D) 输入特征
            c: (B, T, D) 条件嵌入
            
        Returns:
            (B, T, H, W, D) 输出特征
        """
        B, T, H, W, D = x.shape
        
        # === 空间块 ===
        # 自适应层归一化
        s_shift, s_scale, s_gate, mlp_shift, mlp_scale, mlp_gate = \
            self.s_adaLN(c).chunk(6, dim=-1)
        
        # 空间注意力 + 门控
        h = self.s_norm1(x)
        h = modulate(h, s_shift, s_scale)  # 应用条件化
        h = self.s_attn(h)
        x = x + s_gate * h
        
        # 空间 MLP
        h = self.s_norm2(x)
        h = modulate(h, mlp_shift, mlp_scale)
        h = self.s_mlp(h)
        x = x + mlp_gate * h
        
        # === 时间块 ===
        t_shift, t_scale, t_gate, mlp_shift, mlp_scale, mlp_gate = \
            self.t_adaLN(c).chunk(6, dim=-1)
        
        # 时间注意力(检索增强)+ 门控
        h = self.t_norm1(x)
        h = modulate(h, t_shift, t_scale)
        h = self.t_attn(h)
        x = x + t_gate * h
        
        # 时间 MLP
        h = self.t_norm2(x)
        h = modulate(h, mlp_shift, mlp_scale)
        h = self.t_mlp(h)
        x = x + mlp_gate * h
        
        return x
 
 
class TemporalRetrievalAttention(nn.Module):
    """
    带检索增强的时间注意力
    
    关键特性:
    1. 为检索帧添加时间偏移以区分来源
    2. 检索帧与当前帧使用不同的位置编码
    """
    
    def __init__(
        self,
        dim: int,
        num_heads: int,
        retrieve_num: int = 10,
        temporal_offset: int = 100,
    ):
        super().__init__()
        self.retrieve_num = retrieve_num
        self.temporal_offset = temporal_offset
        self.head_dim = dim // num_heads
        
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.to_out = nn.Linear(dim, dim)
        
        self.rope = RotaryEmbedding(dim=self.head_dim)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T, H, W, D) 特征
            
        Returns:
            (B, T, H, W, D) 注意力输出
        """
        B, T, H, W, D = x.shape
        
        # 重塑为 (B*H*W, T, D)
        x = rearrange(x, "b t h w d -> (b h w) t d")
        
        q, k, v = self.to_qkv(x).chunk(3, dim=-2)
        
        # 分离检索帧和当前帧
        retrieve_q = q[:, :self.retrieve_num]
        retrieve_k = k[:, :self.retrieve_num]
        condition_q = q[:, self.retrieve_num:]
        condition_k = k[:, self.retrieve_num:]
        
        # 应用旋转位置编码(检索帧添加偏移)
        retrieve_q = self.rope.rotate(retrieve_q, offset=self.temporal_offset)
        retrieve_k = self.rope.rotate(retrieve_k, offset=self.temporal_offset)
        condition_q = self.rope.rotate(condition_q)
        condition_k = self.rope.rotate(condition_k)
        
        # 拼接
        q = torch.cat([retrieve_q, condition_q], dim=1)
        k = torch.cat([retrieve_k, condition_k], dim=1)
        
        # 缩放点积注意力(因果)
        attn = F.scaled_dot_product_attention(
            q, k, v, is_causal=True
        )
        
        # 恢复形状
        attn = rearrange(attn, "(b h w) t d -> b t h w d", b=B, h=H, w=W)
        
        return self.to_out(attn)
 
 
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    自适应层归一化的调制操作
    
    Args:
        x: (..., D) 输入特征
        shift: (..., D) 偏移参数
        scale: (..., D) 缩放参数
        
    Returns:
        调制后的特征
    """
    return x * (1 + scale.unsqueeze(-2)) + shift.unsqueeze(-2)
 
 
class TimestepEmbedder(nn.Module):
    """时间步嵌入器"""
    
    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        t_emb = self.mlp(t)
        return t_emb
 
 
class Mlp(nn.Module):
    """MLP 模块"""
    
    def __init__(self, in_features: int, hidden_features: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
 
 
class RotaryEmbedding(nn.Module):
    """旋转位置编码"""
    
    def __init__(self, dim: int, max_freq: float = 256):
        super().__init__()
        self.dim = dim
        
        # 生成频率
        freqs = 1.0 / (max_freq ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("freqs", freqs)
        
    def rotate(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """
        应用旋转位置编码
        
        Args:
            x: (N, T, D) 查询或键
            offset: 位置偏移量
            
        Returns:
            旋转后的嵌入
        """
        N, T, D = x.shape
        x = rearrange(x, "n t (d r) -> n t d r", r=2)
        
        # 应用旋转
        angle = self.freqs * torch.arange(T, device=x.device).unsqueeze(-1)
        angle = angle + offset  # 添加偏移
        
        cos = angle.cos()
        sin = angle.sin()
        
        x1, x2 = x.unbind(-1)
        rx1 = x1 * cos - x2 * sin
        rx2 = x1 * sin + x2 * cos
        
        x = torch.stack([rx1, rx2], dim=-1)
        x = rearrange(x, "n t d r -> n t (d r)")
        
        return x
 
 
def retrieval_similarity_search(
    history_states: torch.Tensor,      # (B, N, S) 历史状态序列
    current_state: torch.Tensor,       # (B, S) 当前状态
    weight: torch.Tensor,              # (S,) 状态权重
    k: int = 10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    基于状态的相似度检索
    
    基于全局状态的加权相似度,从历史序列中检索最相关的帧。
    
    Args:
        history_states: 历史帧的状态序列
        current_state: 当前帧的状态
        weight: 状态组件的权重向量
        k: 检索帧数量
        
    Returns:
        (indices, scores): 检索帧索引和相似度分数
    """
    # 加权状态
    weighted_history = history_states * weight.unsqueeze(0).unsqueeze(-1)
    weighted_current = current_state * weight
    
    # 计算欧氏距离
    distances = torch.norm(
        weighted_history - weighted_current.unsqueeze(1), 
        dim=-1
    )
    
    # 选择距离最小的 k 个
    scores, indices = torch.topk(-distances, k=k, dim=1)
    
    return indices, -scores  # 返回正的距离分数

5.1 推理流程

@torch.no_grad()
def vrag_inference(
    model: VRAGModel,
    vae: nn.Module,
    initial_frames: torch.Tensor,      # (1, 100, 3, H, W) 初始帧
    actions: torch.Tensor,              # (1, T, A) 动作序列
    states: torch.Tensor,               # (1, T, S) 状态序列
    history_buffer: list,               # 历史帧缓冲区
    num_generation_steps: int = 200,
    cfg: float = 1.0,
) -> torch.Tensor:
    """
    VRAG 推理流程
    
    Args:
        model: VRAG 模型
        vae: VAE 编码器/解码器
        initial_frames: 初始帧序列
        actions: 动作序列
        states: 全局状态序列
        history_buffer: 历史帧缓冲区
        num_generation_steps: 生成帧数
        cfg: 分类器自由引导强度
        
    Returns:
        生成的视频序列
    """
    device = model.device
    generated_frames = []
    
    # 初始化:编码初始帧
    current_latent = vae_encode(vae, initial_frames)
    
    for step in range(num_generation_steps):
        # 1. 从历史缓冲区检索相关帧
        retrieved_frames, retrieved_states = retrieve_from_buffer(
            history_buffer,
            current_state=states[:, step],
            k=model.retrieve_num,
            weight=get_state_weight(),
        )
        
        # 2. 准备模型输入
        # 添加部分噪声以模拟推理过程
        noise_level = compute_noise_level(step, total_steps=num_generation_steps)
        current_noisy = add_noise(current_latent, noise_level)
        
        # 3. 模型预测
        noise_pred = model(
            x=current_noisy,
            t=torch.full((1, 1), noise_level, device=device),
            actions=actions[:, step:step+1],
            states=states[:, step:step+1],
            retrieved_frames=retrieved_frames,
            retrieved_states=retrieved_states,
        )
        
        # 4. 分类器自由引导(可选)
        if cfg > 1.0:
            noise_pred_uncond = model(
                x=current_noisy,
                t=torch.full((1, 1), noise_level, device=device),
                actions=torch.zeros_like(actions[:, step:step+1]),  # 无条件
                states=states[:, step:step+1],
                retrieved_frames=retrieved_frames,
                retrieved_states=retrieved_states,
            )
            noise_pred = (1 + cfg) * noise_pred - cfg * noise_pred_uncond
        
        # 5. 去噪步骤
        current_latent = denoise_step(
            current_noisy, 
            noise_pred, 
            noise_level
        )
        
        # 6. 解码并保存
        frame = vae_decode(vae, current_latent)
        generated_frames.append(frame)
        
        # 7. 更新历史缓冲区
        update_history_buffer(history_buffer, frame, states[:, step])
    
    return torch.cat(generated_frames, dim=1)

6. 实验结果与分析

6.1 实验设置

VRAG 在 Minecraft 游戏环境中进行评估,该环境具有以下特点:

  • 长时序依赖:需要保持数百帧的时空一致性
  • 动作条件化:角色位置和朝向影响视频生成
  • 交互性:支持不同动作序列产生不同结果

6.2 主要结果

实验结果表明,VRAG 在以下方面显著优于基线方法:

指标描述VRAG vs 基线
时空一致性物体身份、空间布局的保持程度显著提升
复合误差减少长期生成的误差累积程度显著减少
视觉质量生成视频的客观质量指标保持或提升

6.3 消融实验分析

通过消融实验,验证了各组件的贡献:

  • 全局状态条件化:对空间一致性贡献最大
  • 检索增强机制:对减少长期复合误差有效
  • 选择性条件化(掩码动作):防止动作序列不连续

7. 局限性

尽管 VRAG 取得了显著进展,但仍存在以下局限性:

7.1 当前视频模型的上下文学习能力不足

与 LLM 相比,当前的视频扩散模型在上下文学习方面能力较弱。这限制了一些直接应用 LLM 技术的效果,如简单的上下文窗口扩展或标准 RAG。

7.2 全局状态标注的依赖

VRAG 依赖显式的全局状态标注(如角色位置、朝向)。在缺乏此类标注的场景中,需要额外的状态估计模块。

7.3 推理效率

检索机制的引入增加了推理计算量,在实时应用中需要进一步优化。

7.4 泛化能力

当前实验主要在 Minecraft 环境中进行,在其他类型环境中的泛化能力需要进一步验证。


8. 相关研究方向

VRAG 与多个前沿研究领域密切相关:


参考文献


本文档基于 arXiv:2505.21996 论文内容编写,涵盖 VRAG 的核心技术方法、实现细节和实验分析。

Footnotes

  1. Chen, T., Hu, X., Ding, Z., & Jin, C. (2025). Learning World Models for Interactive Video Generation. Advances in Neural Information Processing Systems (NeurIPS 2025). arXiv:2505.21996 | Project Page | GitHub 2 3