Astra自回归去噪世界模型
概述
Astra是由清华大学与快手科技联合提出的自回归去噪世界模型。该模型结合了扩散模型的生成能力和自回归建模的长距离依赖捕捉能力，实现了高质量的视频预测和动作条件生成。
Astra核心思想
为什么选择自回归+去噪
| 方法 | 优点 | 缺点 |
|---|---|---|
| 自回归视频模型 | 长依赖建模强 | 计算量大 |
| 扩散视频模型 | 生成质量高 | 难以条件控制 |
| Astra | 两者兼顾 | 训练复杂 |
关键创新
- 去噪自回归：每一步预测都是一个去噪过程
- 动作条件注入：将动作信息有效融入去噪过程
- 高效架构：平衡生成质量与计算效率
技术架构
自回归去噪框架
class Astra:
    """Autoregressive denoising world model (architecture sketch).

    Wires three sub-networks together: a temporal encoder applied to the
    context, an action-condition encoder applied to the action input, and a
    denoiser that consumes the noisy video, the timestep and both feature sets.

    Each component may be injected through the constructor. When one is
    omitted it is built with its no-argument constructor, as before.
    """

    def __init__(self, denoiser=None, action_encoder=None, temporal_encoder=None):
        # Components are created in the same order as the original sketch.
        self.denoiser = denoiser if denoiser is not None else DenoisingTransformer()
        # NOTE(review): ActionConditionEncoder as defined in this document
        # requires (action_dim, embed_dim), so the no-argument default only
        # works if that signature changes; prefer passing a configured encoder.
        self.action_encoder = (
            action_encoder if action_encoder is not None else ActionConditionEncoder()
        )
        self.temporal_encoder = (
            temporal_encoder if temporal_encoder is not None else TemporalTransformer()
        )

    def forward(self, noisy_video, timestep, action_embed, context):
        """Run one conditioned denoising prediction.

        Args:
            noisy_video: noised input, passed straight to the denoiser.
            timestep: diffusion timestep, passed straight to the denoiser.
            action_embed: action input, encoded by the action encoder.
            context: conditioning context, encoded by the temporal encoder.

        Returns:
            Whatever the denoiser returns for these inputs.
        """
        # 1. Temporal encoding of the context.
        temporal_features = self.temporal_encoder(context)
        # 2. Action-condition encoding. ActionConditionEncoder exposes
        #    ``encode``; calling the object directly fails unless it is callable.
        action_features = self.action_encoder.encode(action_embed)
        # 3. Denoising prediction conditioned on both feature sets.
        return self.denoiser(
            noisy_video,
            timestep,
            temporal_features,
            action_features,
        )
去噪自回归机制
class DenoisingAutoregressiveStep:
    """One autoregressive step: lightly re-noise the current video, then denoise it."""

    def __init__(self, astra_model):
        # Expected to expose ``denoise(x, timestep=..., action_embed=...)``.
        # NOTE(review): the Astra sketch in this document only defines
        # ``forward``; confirm where ``denoise`` is provided.
        self.model = astra_model

    def step(self, current_video, action, num_denoise_steps=50):
        """Run one autoregressive denoising step.

        Args:
            current_video: the current video/state to roll forward.
            action: action conditioning, passed to the model as ``action_embed``.
            num_denoise_steps: number of reverse-diffusion iterations; the
                timesteps visited are ``num_denoise_steps - 1`` down to ``0``.

        Returns:
            The sample after the final denoising iteration. If
            ``num_denoise_steps <= 0`` no denoising happens and the lightly
            noised input is returned.
        """
        # Add a small amount of noise as the autoregressive starting signal.
        sample = add_light_noise(current_video, sigma=0.01)
        # Feed each output into the next iteration so the chain actually
        # progresses from the noisiest timestep down to t = 0; re-denoising
        # the same input every time would make all but the last call useless.
        for t in reversed(range(num_denoise_steps)):
            sample = self.model.denoise(
                sample,
                timestep=t,
                action_embed=action,
            )
        return sample
动作条件编码
class ActionConditionEncoder(nn.Module):
    """Embed an action sequence and mix it along time with a 1-D convolution.

    Subclassing ``nn.Module`` registers the linear and convolution layers as
    sub-modules, so their parameters are returned by ``parameters()`` and move
    with ``.to(device)``. It also makes the encoder callable, e.g.
    ``encoder(actions)``.
    """

    def __init__(self, action_dim, embed_dim):
        super().__init__()
        # Per-timestep projection on the last axis: action_dim -> embed_dim.
        self.action_embedding = nn.Linear(action_dim, embed_dim)
        # kernel_size=3 with padding=1 keeps the sequence length unchanged.
        self.temporal_conv = nn.Conv1d(embed_dim, embed_dim, 3, padding=1)

    def forward(self, action_sequence):
        """Encode an action sequence.

        Args:
            action_sequence: tensor of shape (batch, time, action_dim). It must
                be 3-D because of the ``transpose(1, 2)`` below.

        Returns:
            Tensor of shape (batch, time, embed_dim).
        """
        # Initial embedding: (B, T, action_dim) -> (B, T, E).
        embedded = self.action_embedding(action_sequence)
        # Conv1d expects channels first, so convolve over (B, E, T).
        temporal_features = self.temporal_conv(embedded.transpose(1, 2))
        # Back to (B, T, E).
        return temporal_features.transpose(1, 2)

    def encode(self, action_sequence):
        """Encode ``action_sequence``; kept for backward compatibility (same as calling the module)."""
        return self(action_sequence)
视频预测能力
多步预测
def multi_step_prediction(model, initial_frames, action_sequence, prediction_horizon):
    """Roll the world model forward ``prediction_horizon`` steps.

    At step ``i`` the model predicts from the current frame window and
    ``action_sequence[i]``. The window then drops its oldest entry along
    dim 1 and appends the prediction.

    Returns:
        List of the model's predictions, one per step, in order.
    """
    window = initial_frames
    rollout = []
    for idx in range(prediction_horizon):
        frame = model.predict_next(window, action_sequence[idx])
        rollout.append(frame)
        # Slide the context window: drop the oldest frame, append the newest.
        window = torch.cat((window[:, 1:], frame), dim=1)
    return rollout
条件生成
def conditional_video_generation(
    model,
    initial_frame,
    action_trajectory,
    style=None,
):
    """Generate a video from an initial frame, conditioned on actions and an optional style.

    Args:
        model: object exposing ``action_encoder`` and ``generate``, plus
            ``style_encoder`` when a style is supplied.
        initial_frame: frame the generation starts from.
        action_trajectory: actions, encoded via ``model.action_encoder``.
        style: optional style input; ``None`` means action conditioning only.

    Returns:
        Whatever ``model.generate(initial_frame, condition)`` returns.
    """
    # Encode the conditions.
    action_features = model.action_encoder(action_trajectory)
    # Compare against None explicitly: ``if style:`` raises for multi-element
    # tensors and silently ignores falsy-but-valid styles such as 0.
    if style is not None:
        style_features = model.style_encoder(style)
        condition = combine_conditions(action_features, style_features)
    else:
        condition = action_features
    # Generate.
    return model.generate(initial_frame, condition)
应用场景
自动驾驶视频预测
class AutonomousDrivingPrediction:
    """Roll a planned driving trajectory through the world model, one frame per waypoint."""

    def __init__(self):
        # World model used to predict upcoming frames.
        self.model = Astra()
        # Maps (current frame window, waypoint) to a control action.
        self.vehicle_controller = VehicleController()

    def predict_trajectory(self, current_video, planned_trajectory):
        """Predict one frame for every waypoint in ``planned_trajectory``.

        For each waypoint the controller computes an action from the current
        frame window, the world model predicts the next frame, and the window
        slides forward by dropping its oldest entry along dim 1 and appending
        the prediction.

        Returns:
            List of predicted frames in waypoint order.
        """
        window = current_video
        predicted = []
        for wp in planned_trajectory:
            control = self.vehicle_controller.compute_action(window, wp)
            frame = self.model.predict_next(window, control)
            predicted.append(frame)
            window = torch.cat((window[:, 1:], frame), dim=1)
        return predicted
机器人控制
class RobotWorldModel:
    """Imagine a robot's motion by rolling a planned action sequence through the world model."""

    def __init__(self):
        # World model; the attribute name ``astras`` is kept for existing callers.
        self.astras = Astra()
        # Produces the action sequence from the initial to the goal state.
        self.planner = MotionPlanner()

    def imagine_motion(self, initial_state, goal_state, robot_urdf):
        """Plan actions, then predict the state after each one.

        Each predicted state becomes the input to the next prediction; there
        is no frame window here.

        Returns:
            List of predicted states, one per planned action.
        """
        plan = self.planner.plan(initial_state, goal_state, robot_urdf)
        imagined = []
        state = initial_state
        for step_action in plan:
            state = self.astras.predict_next(state, step_action)
            imagined.append(state)
        return imagined
世界探索
class WorldExploration:
    """Exploration loop that scores world-model predictions with an exploration policy."""

    def __init__(self):
        self.world_model = Astra()
        # NOTE(review): ``explorer`` is constructed but not used by ``explore``.
        self.explorer = FrontierExplorer()

    def explore(self, current_observation, exploration_policy):
        """Yield exploration steps until ``exploration_policy.is_complete()`` is true.

        This is a generator. Each step:
        1. selects an action for the current observation;
        2. predicts the resulting observation with the world model;
        3. scores the prediction with ``exploration_policy.evaluate``;
        4. yields a dict with keys 'action', 'predicted' and 'value'.
        The next step then starts from the predicted observation.
        """
        obs = current_observation
        while True:
            if exploration_policy.is_complete():
                return
            act = exploration_policy.select_action(obs)
            imagined = self.world_model.predict_next(obs, act)
            score = exploration_policy.evaluate(imagined)
            yield {'action': act, 'predicted': imagined, 'value': score}
            obs = imagined
实验评估
视频预测质量
| 指标 | Astra | 竞争方法 |
|---|---|---|
| FVD ↓ | 95 | 120 |
| SSIM ↑ | 0.94 | 0.89 |
| LPIPS ↓ | 0.06 | 0.10 |
动作跟随精度
| 任务 | 精度 | 平滑度 |
|---|---|---|
| 轨迹跟随 | 92% | 0.94 |
| 目标到达 | 88% | - |
| 避障 | 95% | 0.91 |
技术细节
模型配置
# Astra model configuration (values as given in this document).
config = {
    # Transformer backbone.
    "model": {
        "hidden_dim": 1024,
        "num_layers": 24,
        "num_heads": 16,
        "temporal_window": 16,
    },
    # Optimisation settings.
    # NOTE(review): unclear whether batch_size is per GPU or global — confirm.
    "training": {
        "learning_rate": 1e-4,
        "batch_size": 16,
        "num_gpus": 32,
        "warmup_steps": 5000,
    },
    # Diffusion process settings.
    # NOTE(review): snr_gamma is presumably the SNR-weighting constant used by
    # the training loss — confirm against the loss implementation.
    "diffusion": {
        "num_timesteps": 1000,
        "beta_schedule": "linear",
        "snr_gamma": 0.5,
    },
}
训练策略
def train_astras(dataset):
    """Run one pass of noise-prediction training for Astra over ``dataset``.

    Each item of ``dataset`` is a ``(video, actions)`` pair. For every batch:
    timesteps are sampled, the video is noised, the model predicts the noise,
    and the loss from ``snr_weighted_mse`` is back-propagated and applied.

    Relies on module-level names not defined in this snippet: ``astra``,
    ``optimizer``, ``sample_timestep``, ``sample_noise``, ``add_noise`` and
    ``snr_weighted_mse``.
    """
    for video, actions in dataset:
        # ``batch_size`` is not defined anywhere in this snippet, so derive it
        # from the data. Assumes dim 0 of ``video`` is the batch dimension.
        batch_size = len(video)
        # 1. Sample timesteps.
        t = sample_timestep(batch_size)
        # 2. Add noise.
        noise = sample_noise(video)
        noisy_video = add_noise(video, noise, t)
        # 3. Predict the noise.
        # NOTE(review): Astra.forward as sketched in this document takes
        # ``action_embed`` and ``context`` rather than ``actions`` — confirm.
        pred_noise = astra(
            noisy_video,
            timestep=t,
            actions=actions,
        )
        # 4. Loss and update. Clear gradients from the previous iteration
        #    first; otherwise ``backward()`` accumulates them across batches.
        optimizer.zero_grad()
        loss = snr_weighted_mse(pred_noise, noise, t)
        loss.backward()
        optimizer.step()
局限性与改进
当前局限
- 计算效率：自回归步骤较多
- 长期预测：存在误差累积问题
- 物理约束：对复杂物理过程的模拟不足
未来方向
- 并行去噪：减少自回归步骤
- 层次化预测：从粗到细的生成
- 物理集成：结合物理引擎
参考论文
相关资源