世界模型架构
世界模型的架构设计决定了其表示能力、计算效率和泛化性能。本文档详细分析各组件的架构选择和设计原则。
整体架构分类
┌─────────────────────────────────────────────────────────────────┐
│ 世界模型架构分类 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ 潜在世界模型 │ │ 观测世界模型 │ │
│ │ Latent World │ │ Observational │ │
│ │ Model │ │ World Model │ │
│ ├─────────────────┤ ├─────────────────┤ │
│ │ Dreamer 系列 │ │ DreamerV3 │ │
│ │ SimPLe │ │ MuZero Real │ │
│ │ MERLIN │ │ TD-MPC │ │
│ └─────────────────┘ └─────────────────┘ │
│ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ 视频生成模型 │ │ 多模态世界模型 │ │
│ │ Video Gen │ │ Multimodal │ │
│ │ as World Model │ │ World Model │ │
│ ├─────────────────┤ ├─────────────────┤ │
│ │ Sora/Genie │ │ Vision-Lang- │ │
│ │ Stable Video │ │ Action Models │ │
│ └─────────────────┘ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
1. 编码器架构 (Encoder)
编码器负责将高维观测压缩为紧凑的潜在表示。
1.1 卷积神经网络 (CNN)
适用场景: 图像/视频帧输入
class CNNEncoder(nn.Module):
def __init__(self, obs_shape, latent_dim):
super().__init__()
# Atari (84x84x4) -> latent_dim
self.conv = nn.Sequential(
nn.Conv2d(obs_shape[0], 32, 4, stride=2), # 42x42
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2), # 20x20
nn.ReLU(),
nn.Conv2d(64, 128, 4, stride=2), # 9x9
nn.ReLU(),
nn.Conv2d(128, 256, 4, stride=2), # 3x3
nn.ReLU(),
nn.Flatten(),
)
with torch.no_grad():
dummy = torch.zeros(1, *obs_shape)
n_flatten = self.conv(dummy).shape[1]
self.fc = nn.Linear(n_flatten, latent_dim)
def forward(self, x):
x = x / 255.0 # 归一化
h = self.conv(x)
return self.fc(h)设计要点:
- 层数通常 3-4 层
- 步长为 2 的卷积替代池化
- 最终通过全连接层映射到潜在维度
1.2 Vision Transformer (ViT)
适用场景: 高分辨率图像、可变长度序列
class ViTEncoder(nn.Module):
def __init__(self, image_size=224, patch_size=16, latent_dim=512):
super().__init__()
n_patches = (image_size // patch_size) ** 2
self.patch_embed = nn.Conv2d(3, 768, patch_size, stride=patch_size)
self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, 768))
self.blocks = nn.ModuleList([
TransformerBlock(768, 12, 3072)
for _ in range(12)
])
self.norm = nn.LayerNorm(768)
self.proj = nn.Linear(768, latent_dim)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x) # (B, 768, n_patches_h, n_patches_w)
x = x.flatten(2).transpose(1, 2) # (B, n_patches, 768)
x = torch.cat([self.cls_token.expand(B, -1, -1), x], dim=1)
x = x + self.pos_embed
for block in self.blocks:
x = block(x)
x = self.norm(x[:, 0]) # CLS token
return self.proj(x)1.3 变分自编码器 (VAE)
用于学习随机潜在表示:
class VAEEncoder(nn.Module):
def __init__(self, encoder, latent_dim):
super().__init__()
self.encoder = encoder
self.fc_mu = nn.Linear(encoder.out_dim, latent_dim)
self.fc_logvar = nn.Linear(encoder.out_dim, latent_dim)
def forward(self, x):
h = self.encoder(x)
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
# 重参数化
std = torch.exp(0.5 * logvar)
z = mu + std * torch.randn_like(std)
return z, mu, logvar
def kl_loss(self, mu, logvar):
return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())2. 动态模型架构 (Dynamics Model)
动态模型是预测下一个潜在状态的核心组件。
2.1 RSSM (Recurrent State Space Model)
Dreamer 系列的核心组件:
class RSSM(nn.Module):
def __init__(self, action_dim, latent_dim, hidden_dim=512):
super().__init__()
self.latent_dim = latent_dim
# 先验网络: 预测下一个状态
self.prior_net = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.GRU(hidden_dim, hidden_dim, num_layers=2),
nn.Linear(hidden_dim, latent_dim * 2) # mean, std
)
# 后验网络: 基于观测更新
self.posterior_net = nn.Sequential(
nn.Linear(latent_dim + action_dim + obs_dim, hidden_dim),
nn.GRU(hidden_dim, hidden_dim, num_layers=2),
nn.Linear(hidden_dim, latent_dim * 2)
)
def forward(self, z_prev, action, obs=None):
# 先验
prior_input = torch.cat([z_prev, action], dim=-1)
prior_h = self.prior_net(prior_input)
prior_mean, prior_logvar = prior_h.chunk(2, dim=-1)
prior_std = torch.exp(0.5 * prior_logvar)
z_prior = prior_mean + prior_std * torch.randn_like(prior_std)
if obs is not None:
# 后验(使用观测信息)
posterior_input = torch.cat([z_prev, action, obs], dim=-1)
posterior_h = self.posterior_net(posterior_input)
posterior_mean, posterior_logvar = posterior_h.chunk(2, dim=-1)
posterior_std = torch.exp(0.5 * posterior_logvar)
z_posterior = posterior_mean + posterior_std * torch.randn_like(posterior_std)
return z_posterior, (prior_mean, prior_logvar), (posterior_mean, posterior_logvar)
return z_prior, (prior_mean, prior_logvar), None2.2 Transformer 动态模型
处理长程依赖:
class TransformerDynamics(nn.Module):
def __init__(self, latent_dim, action_dim, n_heads=8, n_layers=4):
super().__init__()
self.temporal_pos_encoding = TemporalPositionalEncoding(latent_dim)
encoder_layer = nn.TransformerEncoderLayer(
d_model=latent_dim,
nhead=n_heads,
dim_feedforward=latent_dim * 4,
dropout=0.1,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
# 动作编码
self.action_embed = nn.Linear(action_dim, latent_dim)
# 预测头
self.predict_head = nn.Sequential(
nn.Linear(latent_dim, latent_dim * 2),
nn.GELU(),
nn.Linear(latent_dim * 2, latent_dim * 2) # mean, logvar
)
def forward(self, z_sequence, actions):
"""
Args:
z_sequence: (B, T, latent_dim) 历史潜在状态
actions: (B, T, action_dim) 历史动作
"""
# 添加动作信息
action_emb = self.action_embed(actions)
x = z_sequence + action_emb
# 添加位置编码
x = self.temporal_pos_encoding(x)
# Transformer 处理
h = self.transformer(x)
# 预测下一个状态
next_h = h[:, -1] # 最后一步的表示
pred = self.predict_head(next_h)
mean, logvar = pred.chunk(2, dim=-1)
return mean, logvar2.3 SSM (状态空间模型)
Mamba/RetNet 等高效架构:
class SSMDynamics(nn.Module):
    """Skeleton of a selective state-space (Mamba-style) dynamics model.

    NOTE(review): this is an illustrative sketch — the recurrent state
    update is left unimplemented (``h`` stays zero), and the computed
    selective parameters ``x`` and ``B_t`` are currently unused.
    """

    def __init__(self, latent_dim, state_dim=16):
        super().__init__()
        self.latent_dim = latent_dim
        self.state_dim = state_dim
        # SSM parameters (A, B, C, D).
        self.A = nn.Parameter(torch.randn(state_dim, state_dim))
        self.B = nn.Linear(latent_dim, state_dim)
        self.C = nn.Linear(state_dim, latent_dim)
        self.D = nn.Linear(latent_dim, latent_dim)  # skip / feed-through path
        # Selective mechanism (Mamba): input-dependent SSM parameters.
        self.x_proj = nn.Linear(latent_dim, state_dim * 2)

    def forward(self, z_prev, action):
        """
        Selective SSM forward pass (incomplete sketch).
        """
        u = z_prev + action  # combine state and action into a single input
        # Selective parameters — computed here but not yet consumed below.
        x = self.x_proj(u)
        B_t = self.B(u)
        # SSM discretization (outline):
        # dt = sigmoid(projection)
        # B = B_t * dt
        # state update: h' = A @ h + B @ u
        h = torch.zeros(u.shape[0], self.state_dim, device=u.device)
        # ... state-update logic goes here
        y = self.C(h) + self.D(u)
        return y
3. 奖励预测器 (Reward Predictor)
3.1 简单 MLP
class RewardPredictor(nn.Module):
def __init__(self, latent_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim + action_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 1)
)
def forward(self, z, action):
x = torch.cat([z, action], dim=-1)
return self.net(x)3.2 带注意力机制
class AttentionRewardPredictor(nn.Module):
def __init__(self, latent_dim, action_dim):
super().__init__()
self.query = nn.Linear(latent_dim, latent_dim)
self.key = nn.Linear(action_dim, latent_dim)
self.value = nn.Linear(action_dim, latent_dim)
self.net = nn.Sequential(
nn.Linear(latent_dim * 2, 256),
nn.ReLU(),
nn.Linear(256, 1)
)
def forward(self, z_sequence, current_action):
"""
Args:
z_sequence: (B, T, latent_dim)
current_action: (B, action_dim) 或 (B, T, action_dim)
"""
if current_action.dim() == 2:
current_action = current_action.unsqueeze(1).expand(-1, z_sequence.shape[1], -1)
q = self.query(z_sequence) # (B, T, latent)
k = self.key(current_action) # (B, T, latent)
v = self.value(current_action) # (B, T, latent)
# 注意力权重
attn = torch.softmax(q @ k.transpose(-2, -1) / (latent_dim ** 0.5), dim=-1)
context = attn @ v # (B, T, latent)
# 使用最后一个状态预测奖励
z_last = z_sequence[:, -1]
context_last = context[:, -1]
x = torch.cat([z_last, context_last], dim=-1)
return self.net(x)4. 策略网络 (Policy)
4.1 连续动作:Gaussian 策略
class GaussianPolicy(nn.Module):
def __init__(self, latent_dim, action_dim, hidden_dim=256):
super().__init__()
self.network = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
self.mean_net = nn.Linear(hidden_dim, action_dim)
self.log_std_net = nn.Linear(hidden_dim, action_dim)
# 初始化
nn.init.orthogonal_(self.mean_net.weight, gain=0.01)
nn.init.orthogonal_(self.log_std_net.weight, gain=0.01)
def forward(self, z):
h = self.network(z)
mean = self.mean_net(h)
log_std = torch.tanh(self.log_std_net(h)) # 限制在 [-1, 1]
log_std = log_std * 2 - 1 # scale to [-3, 1] roughly
std = torch.exp(log_std)
return mean, std
def sample(self, z):
mean, std = self(z)
dist = torch.distributions.Normal(mean, std)
action = dist.rsample()
log_prob = dist.log_prob(action).sum(dim=-1)
# Squashing (Tanh)
action = torch.tanh(action)
log_prob -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=-1)
return action, log_prob4.2 离散动作:Categorical 策略
class CategoricalPolicy(nn.Module):
def __init__(self, latent_dim, action_dim, hidden_dim=256):
super().__init__()
self.network = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
def forward(self, z):
logits = self.network(z)
return logits
def sample(self, z):
logits = self(z)
dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()
log_prob = dist.log_prob(action)
return action, log_prob5. 价值网络 (Value Function)
5.1 GRU 价值网络
class ValueNetwork(nn.Module):
def __init__(self, latent_dim, hidden_dim=256):
super().__init__()
self.gru = nn.GRU(latent_dim, hidden_dim, num_layers=2)
self.value_head = nn.Linear(hidden_dim, 1)
def forward(self, z_sequence):
"""
Args:
z_sequence: (B, T, latent_dim)
Returns:
values: (B, T)
"""
# GRU 接受 (T, B, hidden)
h, _ = self.gru(z_sequence.transpose(0, 1))
values = self.value_head(h).squeeze(-1) # (T, B)
return values.transpose(0, 1) # (B, T)5.2 Transformer 价值网络
class TransformerValueNetwork(nn.Module):
def __init__(self, latent_dim, n_heads=4, n_layers=2):
super().__init__()
encoder_layer = nn.TransformerEncoderLayer(
d_model=latent_dim,
nhead=n_heads,
dim_feedforward=latent_dim * 4,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
self.value_head = nn.Sequential(
nn.Linear(latent_dim, latent_dim),
nn.ReLU(),
nn.Linear(latent_dim, 1)
)
def forward(self, z_sequence):
h = self.transformer(z_sequence)
# 使用最后状态的价值估计
value = self.value_head(h[:, -1])
return value.squeeze(-1)6. 完整架构示例
┌─────────────────────────────────────────────────────────────┐
│ 世界模型完整架构 │
│ │
│ ┌─────────────┐ │
│ │ 观测 o │ │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Encoder │────▶│ 潜在 │ │
│ │ E(o) → z │ │ 状态 z │ │
│ └─────────────┘ └──────┬──────┘ │
│ │ │
│ ┌────────────────────┼────────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Decoder │ │ Dynamics │ │ Reward │ │
│ │ D(z) → o │ │ p(z'|z,a) │ │ R(z,a) │ │
│ └─────────────┘ └──────┬──────┘ └─────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Policy │ │
│ │ π(a|z) │ │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Value │ │
│ │ V(z) │ │
│ └─────────────┘ │
└─────────────────────────────────────────────────────────────┘
架构选择指南
| 场景 | 编码器 | 动态模型 | 策略 |
|---|---|---|---|
| Atari (84x84) | CNN | RSSM (GRU) | Categorical |
| 连续控制 | CNN/MLP | RSSM (GRU) | Gaussian |
| 高分辨率 | ViT | Transformer | 任务相关 |
| 长序列规划 | CNN | SSM/Mamba | 任务相关 |
| 多模态 | 多模态编码器 | Transformer | 多模态策略 |
设计原则
- 信息瓶颈:潜在空间应足够紧凑以促进泛化
- 可逆性:编码器和解码器应尽量可逆
- 不确定性估计:动态模型应能估计其不确定性
- 计算效率:考虑训练和推理的计算成本
- 端到端训练:各组件应能联合优化