World Model Architectures

The architecture of a world model determines its representational capacity, computational efficiency, and generalization performance. This document analyzes the architectural choices and design principles of each component in detail.

Overall Architecture Taxonomy

┌─────────────────────────────────────────────────────────────────┐
│               World Model Architecture Taxonomy                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   ┌─────────────────┐    ┌─────────────────┐                    │
│   │  Latent World   │    │  Observational  │                    │
│   │  Model          │    │  World Model    │                    │
│   ├─────────────────┤    ├─────────────────┤                    │
│   │  Dreamer family │    │  DreamerV3      │                    │
│   │  SimPLe         │    │  MuZero Real    │                    │
│   │  MERLIN         │    │  TD-MPC         │                    │
│   └─────────────────┘    └─────────────────┘                    │
│                                                                 │
│   ┌─────────────────┐    ┌─────────────────┐                    │
│   │  Video Gen      │    │  Multimodal     │                    │
│   │  as World Model │    │  World Model    │                    │
│   ├─────────────────┤    ├─────────────────┤                    │
│   │  Sora/Genie     │    │  Vision-Lang-   │                    │
│   │  Stable Video   │    │  Action Models  │                    │
│   └─────────────────┘    └─────────────────┘                    │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

1. Encoder Architectures

The encoder compresses high-dimensional observations into a compact latent representation.

1.1 Convolutional Neural Network (CNN)

Typical use case: image / video-frame inputs

class CNNEncoder(nn.Module):
    def __init__(self, obs_shape, latent_dim):
        super().__init__()
        # Atari (84x84x4) -> latent_dim
        self.conv = nn.Sequential(
            nn.Conv2d(obs_shape[0], 32, 4, stride=2),  # 84 -> 41
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),             # 41 -> 19
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2),            # 19 -> 8
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2),           # 8 -> 3
            nn.ReLU(),
            nn.Flatten(),
        )
        
        with torch.no_grad():
            dummy = torch.zeros(1, *obs_shape)
            n_flatten = self.conv(dummy).shape[1]
        
        self.fc = nn.Linear(n_flatten, latent_dim)
        
    def forward(self, x):
        x = x / 255.0  # scale pixel values to [0, 1]
        h = self.conv(x)
        return self.fc(h)

Design notes:

  • Typically 3-4 convolutional layers
  • Stride-2 convolutions instead of pooling layers
  • A final fully connected layer maps the features to the latent dimension
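The spatial size after each stride-2 convolution follows from the standard convolution arithmetic, out = (in + 2p - k) // s + 1. A quick check for the 84x84 Atari case:

```python
def conv_out(size, kernel, stride, padding=0):
    """Output spatial size of a conv layer: floor((in + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

# Four kernel-4, stride-2 convolutions applied to an 84x84 frame:
size = 84
sizes = []
for _ in range(4):
    size = conv_out(size, kernel=4, stride=2)
    sizes.append(size)

print(sizes)  # [41, 19, 8, 3]
```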

1.2 Vision Transformer (ViT)

Typical use case: high-resolution images, variable-length sequences

class ViTEncoder(nn.Module):
    def __init__(self, image_size=224, patch_size=16, latent_dim=512):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2
        
        self.patch_embed = nn.Conv2d(3, 768, patch_size, stride=patch_size)
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, 768))
        
        # TransformerBlock(dim, n_heads, mlp_dim): a standard attention + MLP
        # block, assumed to be defined elsewhere
        self.blocks = nn.ModuleList([
            TransformerBlock(768, 12, 3072)
            for _ in range(12)
        ])
        
        self.norm = nn.LayerNorm(768)
        self.proj = nn.Linear(768, latent_dim)
        
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, 768, n_patches_h, n_patches_w)
        x = x.flatten(2).transpose(1, 2)  # (B, n_patches, 768)
        
        x = torch.cat([self.cls_token.expand(B, -1, -1), x], dim=1)
        x = x + self.pos_embed
        
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x[:, 0])  # CLS token
        return self.proj(x)

1.3 Variational Autoencoder (VAE)

Used to learn a stochastic latent representation:

class VAEEncoder(nn.Module):
    def __init__(self, encoder, latent_dim):
        super().__init__()
        self.encoder = encoder
        self.fc_mu = nn.Linear(encoder.out_dim, latent_dim)
        self.fc_logvar = nn.Linear(encoder.out_dim, latent_dim)
        
    def forward(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        
        # reparameterization trick
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        
        return z, mu, logvar
    
    def kl_loss(self, mu, logvar):
        return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
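A minimal training-step sketch for the encoder above. The `Backbone` and `decoder` here are hypothetical Linear stand-ins (the wrapped encoder only needs to expose an `out_dim` attribute, as `VAEEncoder` assumes); dimensions and the KL weight `beta` are illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, obs_dim = 8, 32

# hypothetical stand-ins for the encoder backbone and the decoder
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.out_dim = 64  # VAEEncoder reads this attribute
        self.net = nn.Linear(obs_dim, self.out_dim)
    def forward(self, x):
        return torch.relu(self.net(x))

backbone = Backbone()
fc_mu = nn.Linear(backbone.out_dim, latent_dim)
fc_logvar = nn.Linear(backbone.out_dim, latent_dim)
decoder = nn.Linear(latent_dim, obs_dim)

x = torch.randn(16, obs_dim)
h = backbone(x)
mu, logvar = fc_mu(h), fc_logvar(h)
z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterization

recon = decoder(z)
recon_loss = ((recon - x) ** 2).mean()
kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
beta = 1.0  # KL weight (a beta-VAE when != 1)
loss = recon_loss + beta * kl
```

The analytic KL of a diagonal Gaussian against the standard normal is always non-negative, which makes a useful sanity check during training.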

2. Dynamics Model Architectures

The dynamics model is the core component that predicts the next latent state.

2.1 RSSM (Recurrent State Space Model)

The core component of the Dreamer family:

class RSSM(nn.Module):
    def __init__(self, action_dim, latent_dim, obs_dim, hidden_dim=512):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Prior network: predicts the next state from state and action.
        # (nn.GRU cannot be placed inside nn.Sequential -- it returns a
        # tuple -- so plain MLP heads are used here.)
        self.prior_net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, latent_dim * 2)  # mean, logvar
        )
        
        # Posterior network: refines the state using the observation
        self.posterior_net = nn.Sequential(
            nn.Linear(latent_dim + action_dim + obs_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, latent_dim * 2)
        )
        
    def forward(self, z_prev, action, obs=None):
        # prior
        prior_input = torch.cat([z_prev, action], dim=-1)
        prior_h = self.prior_net(prior_input)
        prior_mean, prior_logvar = prior_h.chunk(2, dim=-1)
        prior_std = torch.exp(0.5 * prior_logvar)
        z_prior = prior_mean + prior_std * torch.randn_like(prior_std)
        
        if obs is not None:
            # posterior (conditioned on the observation)
            posterior_input = torch.cat([z_prev, action, obs], dim=-1)
            posterior_h = self.posterior_net(posterior_input)
            posterior_mean, posterior_logvar = posterior_h.chunk(2, dim=-1)
            posterior_std = torch.exp(0.5 * posterior_logvar)
            z_posterior = posterior_mean + posterior_std * torch.randn_like(posterior_std)
            return z_posterior, (prior_mean, prior_logvar), (posterior_mean, posterior_logvar)
        
        return z_prior, (prior_mean, prior_logvar), None
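During training, the posterior is regularized toward the prior. DreamerV2/V3 use "KL balancing": the KL term is computed twice, with gradients stopped on alternating sides, so the prior is pulled toward the posterior more strongly than vice versa. A sketch under the diagonal-Gaussian parameterization above (the 0.8/0.2 mix is the commonly cited setting, not taken from this document):

```python
import torch

def gaussian_kl(mean_q, logvar_q, mean_p, logvar_p):
    """KL( N(mean_q, var_q) || N(mean_p, var_p) ), diagonal, summed over dims."""
    var_q, var_p = logvar_q.exp(), logvar_p.exp()
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mean_q - mean_p) ** 2) / var_p - 1)
    return kl.sum(dim=-1)

def balanced_kl(post_mean, post_logvar, prior_mean, prior_logvar, alpha=0.8):
    # train the prior toward the (frozen) posterior ...
    prior_term = gaussian_kl(post_mean.detach(), post_logvar.detach(),
                             prior_mean, prior_logvar)
    # ... and only weakly regularize the posterior toward the (frozen) prior
    post_term = gaussian_kl(post_mean, post_logvar,
                            prior_mean.detach(), prior_logvar.detach())
    return alpha * prior_term + (1 - alpha) * post_term
```

Identical prior and posterior give a KL of exactly zero, which makes the function easy to sanity-check.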

2.2 Transformer Dynamics Model

Handles long-range dependencies:

class TransformerDynamics(nn.Module):
    def __init__(self, latent_dim, action_dim, n_heads=8, n_layers=4):
        super().__init__()
        
        # TemporalPositionalEncoding: positional encoding over the time axis
        # (assumed to be defined elsewhere)
        self.temporal_pos_encoding = TemporalPositionalEncoding(latent_dim)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=n_heads,
            dim_feedforward=latent_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        
        # action embedding
        self.action_embed = nn.Linear(action_dim, latent_dim)
        
        # prediction head
        self.predict_head = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 2),
            nn.GELU(),
            nn.Linear(latent_dim * 2, latent_dim * 2)  # mean, logvar
        )
        
    def forward(self, z_sequence, actions):
        """
        Args:
            z_sequence: (B, T, latent_dim) 历史潜在状态
            actions: (B, T, action_dim) 历史动作
        """
        # 添加动作信息
        action_emb = self.action_embed(actions)
        x = z_sequence + action_emb
        
        # 添加位置编码
        x = self.temporal_pos_encoding(x)
        
        # Transformer 处理
        h = self.transformer(x)
        
        # 预测下一个状态
        next_h = h[:, -1]  # 最后一步的表示
        pred = self.predict_head(next_h)
        mean, logvar = pred.chunk(2, dim=-1)
        
        return mean, logvar
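As written, forward predicts only one step, from the last position. To train on all T transitions of a sequence at once without leaking future states into earlier predictions, a causal mask can be passed to the encoder. A sketch using PyTorch's built-in mask helper (dimensions are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, latent_dim = 2, 5, 32

encoder_layer = nn.TransformerEncoderLayer(
    d_model=latent_dim, nhead=4, dim_feedforward=latent_dim * 4, batch_first=True
)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
transformer.eval()  # disable dropout for a deterministic check

x = torch.randn(B, T, latent_dim)
# float mask with -inf above the diagonal: step t attends only to steps <= t
causal_mask = nn.Transformer.generate_square_subsequent_mask(T)
h = transformer(x, mask=causal_mask)

# every position now predicts from its own prefix, so the next-state loss can
# be applied at all T-1 transitions instead of only the last one
print(h.shape)  # torch.Size([2, 5, 32])
```

With the mask in place, perturbing the input at the last timestep must leave the outputs at all earlier timesteps unchanged.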

2.3 SSM (State Space Model)

Efficient architectures such as Mamba / RetNet:

class SSMDynamics(nn.Module):
    def __init__(self, latent_dim, state_dim=16):
        super().__init__()
        self.latent_dim = latent_dim
        self.state_dim = state_dim
        
        # SSM parameters
        self.A = nn.Parameter(torch.randn(state_dim, state_dim))
        self.B = nn.Linear(latent_dim, state_dim)
        self.C = nn.Linear(state_dim, latent_dim)
        self.D = nn.Linear(latent_dim, latent_dim)  # feedthrough (skip) path
        
        # selective mechanism (Mamba-style): input-dependent step size and gate
        self.x_proj = nn.Linear(latent_dim, state_dim * 2)
        
    def forward(self, z_prev, action, h=None):
        """
        Selective SSM forward pass (single step).
        """
        u = z_prev + action  # combine inputs (assumes action is already projected to latent_dim)
        
        # input-dependent (selective) parameters: step size and gate
        dt, gate = self.x_proj(u).chunk(2, dim=-1)
        dt = torch.sigmoid(dt)
        
        if h is None:
            h = torch.zeros(u.shape[0], self.state_dim, device=u.device)
        
        # discretized state update, first-order (Euler) approximation:
        #   h' = h + dt * (A h + B u)
        h = h + dt * (h @ self.A.T + self.B(u))
        h = h * torch.sigmoid(gate)  # selective gating of the state
        
        y = self.C(h) + self.D(u)
        return y, h  # return the state so the recurrence can be continued
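The commented state update corresponds to a discretization of a continuous linear system. The exact zero-order-hold (ZOH) discretization used by S4/Mamba, with input-dependent step size Δ, is:

```latex
\bar{A} = \exp(\Delta A), \qquad
\bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B,
```
```latex
h_t = \bar{A}\, h_{t-1} + \bar{B}\, u_t, \qquad
y_t = C\, h_t + D\, u_t .
```

A practical implementation may replace the matrix exponential with the first-order approximation \(\bar{A} \approx I + \Delta A\), which recovers the simple Euler step sketched in the code.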

3. Reward Predictor

3.1 Simple MLP

class RewardPredictor(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        
    def forward(self, z, action):
        x = torch.cat([z, action], dim=-1)
        return self.net(x)

3.2 With an Attention Mechanism

class AttentionRewardPredictor(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(action_dim, latent_dim)
        self.value = nn.Linear(action_dim, latent_dim)
        
        self.net = nn.Sequential(
            nn.Linear(latent_dim * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        
    def forward(self, z_sequence, current_action):
        """
        Args:
            z_sequence: (B, T, latent_dim)
            current_action: (B, action_dim) or (B, T, action_dim)
        """
        if current_action.dim() == 2:
            current_action = current_action.unsqueeze(1).expand(-1, z_sequence.shape[1], -1)
        
        q = self.query(z_sequence)  # (B, T, latent)
        k = self.key(current_action)  # (B, T, latent)
        v = self.value(current_action)  # (B, T, latent)
        
        # attention weights (scale by q's width; `latent_dim` is not in scope here)
        attn = torch.softmax(q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5), dim=-1)
        context = attn @ v  # (B, T, latent)
        
        # predict the reward from the last state
        z_last = z_sequence[:, -1]
        context_last = context[:, -1]
        
        x = torch.cat([z_last, context_last], dim=-1)
        return self.net(x)
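The scaled dot-product step in the middle of forward can be checked in isolation: each row of the attention matrix is a distribution over the T timesteps, so it must sum to 1. Shapes here are illustrative:

```python
import torch

torch.manual_seed(0)
B, T, d = 2, 4, 8
q = torch.randn(B, T, d)
k = torch.randn(B, T, d)
v = torch.randn(B, T, d)

# scaled dot-product attention, scaling by the query width
attn = torch.softmax(q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5), dim=-1)
context = attn @ v

print(attn.shape, context.shape)  # torch.Size([2, 4, 4]) torch.Size([2, 4, 8])
```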

4. Policy Network

4.1 Continuous Actions: Gaussian Policy

class GaussianPolicy(nn.Module):
    def __init__(self, latent_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        
        self.mean_net = nn.Linear(hidden_dim, action_dim)
        self.log_std_net = nn.Linear(hidden_dim, action_dim)
        
        # initialization: small orthogonal gains keep early actions near zero
        nn.init.orthogonal_(self.mean_net.weight, gain=0.01)
        nn.init.orthogonal_(self.log_std_net.weight, gain=0.01)
        
    def forward(self, z):
        h = self.network(z)
        mean = self.mean_net(h)
        log_std = torch.tanh(self.log_std_net(h))  # squash to [-1, 1]
        log_std = log_std * 2 - 1  # affine map to [-3, 1]
        std = torch.exp(log_std)
        
        return mean, std
    
    def sample(self, z):
        mean, std = self(z)
        dist = torch.distributions.Normal(mean, std)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        
        # Squashing (Tanh)
        action = torch.tanh(action)
        log_prob -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=-1)
        
        return action, log_prob
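The correction term in sample follows from the change-of-variables formula for the tanh squashing (as in SAC): if \(u \sim \mathcal{N}(m(z), \sigma(z))\) and \(a = \tanh(u)\), the density picks up the inverse Jacobian of the transform, \(\mathrm{d}a/\mathrm{d}u = 1 - \tanh^2(u)\):

```latex
\log \pi(a \mid z)
= \log \mathcal{N}\!\left(u \mid m(z), \sigma(z)\right)
- \sum_{i} \log\!\left(1 - \tanh^2(u_i)\right),
\qquad a = \tanh(u).
```

The small constant `1e-6` in the code guards against taking the log of zero when an action saturates at the boundary.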

4.2 Discrete Actions: Categorical Policy

class CategoricalPolicy(nn.Module):
    def __init__(self, latent_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
    def forward(self, z):
        logits = self.network(z)
        return logits
    
    def sample(self, z):
        logits = self(z)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action, log_prob

5. Value Network (Value Function)

5.1 GRU Value Network

class ValueNetwork(nn.Module):
    def __init__(self, latent_dim, hidden_dim=256):
        super().__init__()
        self.gru = nn.GRU(latent_dim, hidden_dim, num_layers=2)
        self.value_head = nn.Linear(hidden_dim, 1)
        
    def forward(self, z_sequence):
        """
        Args:
            z_sequence: (B, T, latent_dim)
        Returns:
            values: (B, T)
        """
        # nn.GRU defaults to batch_first=False, i.e. input shape (T, B, features)
        h, _ = self.gru(z_sequence.transpose(0, 1))
        values = self.value_head(h).squeeze(-1)  # (T, B)
        return values.transpose(0, 1)  # (B, T)

5.2 Transformer Value Network

class TransformerValueNetwork(nn.Module):
    def __init__(self, latent_dim, n_heads=4, n_layers=2):
        super().__init__()
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=n_heads,
            dim_feedforward=latent_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        
        self.value_head = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, 1)
        )
        
    def forward(self, z_sequence):
        h = self.transformer(z_sequence)
        # value estimate from the last state
        value = self.value_head(h[:, -1])
        return value.squeeze(-1)

6. Complete Architecture Example

┌─────────────────────────────────────────────────────────────┐
│              Complete World Model Architecture              │
│                                                             │
│  ┌─────────────┐                                            │
│  │ Observation │                                            │
│  │      o      │                                            │
│  └──────┬──────┘                                            │
│         │                                                   │
│         ▼                                                   │
│  ┌─────────────┐     ┌─────────────┐                        │
│  │   Encoder   │────▶│   Latent    │                        │
│  │  E(o) → z   │     │   state z   │                        │
│  └─────────────┘     └──────┬──────┘                        │
│                             │                               │
│         ┌───────────────────┼───────────────────┐           │
│         │                   │                   │           │
│         ▼                   ▼                   ▼           │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐    │
│  │   Decoder   │     │  Dynamics   │     │   Reward    │    │
│  │  D(z) → o   │     │  p(z'|z,a)  │     │   R(z,a)    │    │
│  └─────────────┘     └──────┬──────┘     └─────────────┘    │
│                             │                               │
│                             ▼                               │
│                      ┌─────────────┐                        │
│                      │   Policy    │                        │
│                      │   π(a|z)    │                        │
│                      └──────┬──────┘                        │
│                             │                               │
│                             ▼                               │
│                      ┌─────────────┐                        │
│                      │    Value    │                        │
│                      │    V(z)     │                        │
│                      └─────────────┘                        │
└─────────────────────────────────────────────────────────────┘
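The diagram above can be sketched as a single module composing the components. The sub-modules here are illustrative Linear stand-ins, and `imagine` rolls the dynamics forward in latent space without observations, the "imagination" phase of Dreamer-style training (the deterministic tanh action is a simplification of the Gaussian policy above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
obs_dim, latent_dim, action_dim = 16, 8, 2

class WorldModel(nn.Module):
    """Illustrative composition of the components in the diagram."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, latent_dim)                   # E(o) -> z
        self.decoder = nn.Linear(latent_dim, obs_dim)                   # D(z) -> o
        self.dynamics = nn.Linear(latent_dim + action_dim, latent_dim)  # p(z'|z,a)
        self.reward = nn.Linear(latent_dim + action_dim, 1)             # R(z,a)
        self.policy = nn.Linear(latent_dim, action_dim)                 # pi(a|z)
        self.value = nn.Linear(latent_dim, 1)                           # V(z)

    def imagine(self, obs, horizon):
        """Roll out the model in latent space from a real observation."""
        z = self.encoder(obs)
        rewards = []
        for _ in range(horizon):
            a = torch.tanh(self.policy(z))  # deterministic action for brevity
            rewards.append(self.reward(torch.cat([z, a], dim=-1)))
            z = self.dynamics(torch.cat([z, a], dim=-1))
        return torch.stack(rewards, dim=1), self.value(z)

wm = WorldModel()
obs = torch.randn(4, obs_dim)
rewards, last_value = wm.imagine(obs, horizon=5)
print(rewards.shape, last_value.shape)  # torch.Size([4, 5, 1]) torch.Size([4, 1])
```

The predicted rewards plus the bootstrap value at the final imagined state are exactly the quantities needed to compute returns for policy and value learning, without ever decoding back to observations.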

Architecture Selection Guide

Scenario                Encoder              Dynamics       Policy
Atari (84x84)           CNN                  RSSM (GRU)     Categorical
Continuous control      CNN/MLP              RSSM (GRU)     Gaussian
High resolution         ViT                  Transformer    Task-dependent
Long-horizon planning   CNN                  SSM/Mamba      Task-dependent
Multimodal              Multimodal encoder   Transformer    Multimodal policy

Design Principles

  1. Information bottleneck: the latent space should be compact enough to encourage generalization
  2. Invertibility: the encoder and decoder should be as close to invertible as practical, so task-relevant information survives the round trip
  3. Uncertainty estimation: the dynamics model should be able to estimate its own uncertainty
  4. Computational efficiency: account for the cost of both training and inference
  5. End-to-end training: all components should be jointly optimizable

Related Topics