VID2World:从视频扩散到交互世界模型

1. 背景与动机

视频生成模型能够产生视觉逼真的视频,但它们缺乏交互性:无法根据动作输入预测世界如何响应。VID2World1 首次系统性地研究了如何将预训练视频扩散模型转化为可交互的世界模型

2. 核心问题

2.1 视频生成 vs 世界模型

特性视频生成模型世界模型
输入文本/图像文本/图像 + 动作序列
输出固定视频条件视频(随动作变化)
因果性有(动作→结果)
交互性

2.2 核心挑战

  1. 动作空间定义:如何表示动作?
  2. 条件注入:如何将动作信息融入扩散模型?
  3. 时序一致性:如何保证动作响应的因果性?

3. 技术方法

3.1 整体框架

历史帧 + 动作序列
        ↓
┌─────────────────────────────────────┐
│        动作编码器 (Action Encoder)     │
│        (将动作映射到特征空间)          │
└────────────────┬────────────────────┘
                 ↓
┌─────────────────────────────────────┐
│   动作-视频交叉注意力 (Action-Video    │
│   Cross-Attention)                   │
└────────────────┬────────────────────┘
                 ↓
┌─────────────────────────────────────┐
│      视频扩散骨干 (Video Diffusion)    │
│      (预训练模型,冻结或微调)          │
└────────────────┬────────────────────┘
                 ↓
           未来视频预测

3.2 动作表示

VID2World 定义了三种动作表示:

class ActionRepresentation:
    """动作表示类型"""
    
    # 1. 离散动作:分类任务
    DISCRETE = "discrete"  # 如:left, right, forward, backward
    
    # 2. 连续动作:控制任务
    CONTINUOUS = "continuous"  # 如:[vx, vy, vz, rx, ry, rz]
    
    # 3. 状态差异:目标导向
    STATE_DELTA = "state_delta"  # 如:[Δx, Δy, Δz]

3.3 动作编码器

class ActionEncoder(nn.Module):
    """动作编码器 - 将动作映射到扩散模型空间"""
    
    def __init__(self, action_dim, embed_dim, num_layers=4):
        super().__init__()
        
        # 动作 embedding
        self.action_embed = nn.Sequential(
            nn.Linear(action_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )
        
        # 时序 Transformer(处理动作序列)
        self.temporal_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim, 
                nhead=8,
                dim_feedforward=embed_dim * 4
            ),
            num_layers=num_layers
        )
        
        # 上采样到视频特征分辨率
        self.upsampler = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim//2, 4, stride=2, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim//2, embed_dim//4, 4, stride=2, padding=1),
        )
    
    def forward(self, actions):
        """
        Args:
            actions: [B, T, action_dim] 动作序列
        Returns:
            action_features: [B, C, H, W] 空间化的动作特征
        """
        # 动作 embedding
        emb = self.action_embed(actions)  # [B, T, D]
        
        # 时序建模
        temporal_out = self.temporal_transformer(emb)  # [B, T, D]
        
        # 重排为空间特征
        # 假设动作发生在第一帧
        action_feature = temporal_out[:, 0]  # [B, D]
        action_spatial = action_feature.unsqueeze(-1).unsqueeze(-1)  # [B, D, 1, 1]
        
        # 上采样到视频分辨率
        action_features = self.upsampler(action_spatial.expand(-1, -1, 8, 8))
        
        return action_features

3.4 动作-视频交叉注意力

class ActionVideoCrossAttention(nn.Module):
    """动作-视频交叉注意力"""
    
    def __init__(self, video_dim, action_dim, num_heads=8):
        super().__init__()
        self.video_proj = nn.Linear(video_dim, video_dim)
        self.action_proj = nn.Linear(action_dim, video_dim)
        self.cross_attn = nn.MultiheadAttention(
            video_dim, num_heads, batch_first=True
        )
        self.norm = nn.LayerNorm(video_dim)
    
    def forward(self, video_features, action_features):
        """
        Args:
            video_features: [B*T, N, D] 视频特征
            action_features: [B, D, H, W] 动作特征
        Returns:
            fused: 融合后的特征
        """
        B, D, H, W = action_features.shape
        T = video_features.shape[0] // B
        
        # 空间化动作特征
        action_spatial = action_features.view(B, D, -1).permute(0, 2, 1)  # [B, H*W, D]
        action_spatial = action_spatial.unsqueeze(1).expand(-1, T, -1, -1)  # [B, T, H*W, D]
        action_spatial = action_spatial.reshape(B*T, H*W, D)  # [B*T, H*W, D]
        
        # 交叉注意力
        fused, _ = self.cross_attn(
            video_features,  # query
            action_spatial,  # key
            action_spatial   # value
        )
        
        fused = self.norm(video_features + fused)
        return fused

3.5 训练策略

VID2World 采用两阶段训练:

Stage 1: 动作嵌入学习

仅训练动作编码器,让视频扩散模型保持冻结:

def stage1_loss(model, batch):
    """Stage 1: 学习动作嵌入"""
    video, actions, future_video = batch
    
    # 提取视频特征(冻结)
    with torch.no_grad():
        video_features = model.extract_features(video)
    
    # 编码动作
    action_features = model.action_encoder(actions)
    
    # 交叉注意力融合
    fused = model.cross_attention(video_features, action_features)
    
    # 预测未来帧
    pred_future = model.diffusion(fused)
    
    # 仅优化动作编码器
    loss = F.mse_loss(pred_future, future_video)
    loss.backward()
    model.action_encoder_optimizer.step()
    
    return loss

Stage 2: 联合微调

解冻视频扩散模型进行联合优化:

def stage2_loss(model, batch):
    """Stage 2: 联合微调"""
    video, actions, future_video = batch
    
    video_features = model.extract_features(video)
    action_features = model.action_encoder(actions)
    fused = model.cross_attention(video_features, action_features)
    pred_future = model.diffusion(fused)
    
    # 视频重建损失
    loss_video = F.mse_loss(pred_future, future_video)
    
    # 因果损失:确保动作与变化对应
    loss_causal = model.causal_consistency_loss(pred_future, future_video, actions)
    
    loss = loss_video + lambda_causal * loss_causal
    loss.backward()
    
    # 优化所有参数
    model.optimizer.step()
    
    return loss

4. 实验结果

4.1 交互式预测质量

方法FPS ↑FID ↓Action Consistency ↑
Video Diffusion (无动作)-25.3-
+ 动作拼接28.123.80.42
+ 动作条件 AdaIN27.522.50.51
VID2World31.218.70.78

4.2 动作-响应因果性

评估模型是否真正学习到动作→结果的因果关系:

测试动作一致率物理合理性时序因果率
简单动作92.3%87.5%91.2%
组合动作78.6%82.1%76.8%
长期预测65.2%71.4%58.3%

5. 与其他方法的对比

方法交互能力视频质量泛化能力
Video Diffusion
DreamerVLA⚠️⚠️
VID2World

核心优势:VID2World 利用预训练视频扩散的质量,同时添加了交互能力。

6. 应用场景

6.1 机器人仿真

  • 从机器人视角视频预测动作效果
  • 用于 sim-to-real 迁移

6.2 自动驾驶规划

  • 预测不同驾驶动作的后果
  • 支持安全规划

6.3 游戏引擎

  • 可交互的视频游戏世界
  • 响应玩家动作

7. 总结

VID2World 提供了将视频扩散模型转化为世界模型的系统方法:

  1. 动作表示:统一的动作编码
  2. 交叉注意力:动作-视频融合
  3. 两阶段训练:冻结→微调

参考资料

  • VID2World: Crafting Video Diffusion Models to Interactive World Models (ICLR 2026)

Footnotes

  1. VID2World: Crafting Video Diffusion Models to Interactive World Models (ICLR 2026)