概述
DINO-world 是由 Meta FAIR 团队提出的一种通用视频世界模型,其核心创新在于利用预训练的 JEPA 图像编码器(DINOv2)来预测未来帧的潜在表示。1 与传统的像素级视频生成模型不同,DINO-world 在特征空间中进行预测,从而能够学习到更加语义化和几何化的世界动态。
该模型解决了世界模型训练中的三个核心挑战:数据问题(大规模无标注视频)、计算效率问题(避免像素级建模的巨额开销)、以及泛化问题(跨领域零样本迁移)。通过在大规模网络视频数据集(约6000万视频)上进行预训练,DINO-world 学习到了来自驾驶场景、室内场景到模拟环境的丰富时空动态。
1. 研究背景与动机
1.1 世界模型的定义与挑战
世界模型(World Model)最早由 Ha 和 Schmidhuber 在2018年提出1,定义为:给定过去观察和智能体动作,预测环境未来状态的神经网络。这一概念与世界模型与规划推理融合的研究密切相关,因为世界模型是实现自主智能体规划与控制的关键组件。
然而,训练有效的世界模型面临三重挑战:
| 挑战 | 描述 | 影响 |
|---|---|---|
| 数据问题 | 高质量视频数据稀缺,带动作标注的数据更少 | 当前世界模型局限于特定领域(如自动驾驶、游戏) |
| 计算问题 | 像素级生成模型资源消耗巨大(如COSMOS需要22M GPU小时) | 训练成本高昂,难以规模化 |
| 泛化问题 | 准确建模物理世界行为仍是开放问题 | 即使短时预测也面临困难 |
1.2 为什么选择潜在空间?
传统视频生成模型(如 SORA、COSMOS)通常在像素空间进行预测,这虽然能生成高保真视频,但计算代价高昂,且模型容量可能浪费在预测无关细节上(如自动驾驶场景中每片树叶的精确运动)。
DINO-world 的核心洞察是:下游任务(如分割、深度估计、规划)实际上只需要高层语义和几何特征,而不需要像素级的细节重建。因此,在预训练的视觉特征空间中进行世界建模是更高效的选择。
这一思想与 JEPA 架构的核心理念一脉相承——在潜在空间进行预测而非像素空间重建。
2. 方法论
2.1 系统架构
DINO-world 采用编码器-预测器架构,包含两个核心组件:
┌─────────────────────────────────────────────────────────────────────┐
│ DINO-world 架构 │
│ │
│ 输入视频帧 │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ DINOv2 │ 冻结的视觉编码器 │
│ │ Encoder │ 提取每帧的 patch 特征 │
│ └────┬────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────┐ │
│ │ 未来预测器 │ │
│ │ ┌─────────────────────────────────┐ │ │
│ │ │ Cross-Attention Blocks (N=40) │ │ │
│ │ │ • 查询: 可学习查询向量 │ │ │
│ │ │ • 键值: 历史 patch tokens │ │ │
│ │ │ • 3轴 RoPE 位置编码 │ │ │
│ │ └─────────────────────────────────┘ │ │
│ └───────────────────┬───────────────────┘ │
│ │ │
│ ▼ │
│ 预测的未来帧特征 │
│ │
└─────────────────────────────────────────────────────────────────────┘
数学形式化:给定视频帧序列 和对应时间戳,世界模型定义为映射:
其中 是历史 DINOv2 特征, 是时间戳集合, 指定要预测的未来时空位置。
2.2 帧编码器:DINOv2
DINOv2 是一种自监督视觉 transformer,通过知识蒸馏从大型 ViT 模型获取强大的视觉表示。与 VAE 等重建导向的编码器不同,DINOv2 专为表示学习设计,已在图像分类、分割、深度估计等多种任务上展现出色的泛化能力。
DINOv2 将每帧 编码为特征张量:
其中 (ViT-B/14), 是空间分辨率。
2.3 预测器架构
预测器是一个堆叠了 个残差预归一化交叉注意力块的 transformer 网络:
// 交叉注意力块的核心计算
q ← q + CROSS_ATTENTION(LN(q), {x_{t,i,j} | τ_t < τ_{t'}}) // (1)
q ← q + MLP(LN(q)) // (2)查询初始化:对于每个要预测的时空位置 ,从可学习嵌入初始化查询向量 。
3轴 RoPE 位置编码:将头维度 分为三部分,分别编码:
- 时间坐标 (绝对时间戳,单位为秒)
- 水平坐标 (相对位置 )
- 垂直坐标 (相对位置 )
这种设计使模型能够区分不同帧率和空间分辨率的输入。
2.4 训练目标
采用下一帧预测目标和教师强制(Teacher Forcing)策略:
其中 是 smooth L1 损失。通过块三角注意力掩码确保因果性:预测帧 只能 attend 到帧 到 的 tokens。
2.5 变帧率采样
为避免时间间隔分布偏斜(短间隔过多),训练时对每个视频采样 个时间间隔,均匀从 范围内抽取。这确保模型在均匀分布的时间间隔上进行训练,从而具备灵活的预测能力。
3. 动作条件微调
3.1 动作块机制
DINO-world 可以在动作标注数据上进行微调,添加动作块(Action Blocks)来实现动作条件预测:
// 动作块的核心计算
q ← q + MLP(LN([q, a_t])) // 查询与对应动作融合动作块初始化为恒等映射,允许在有限的动作数据上高效微调,同时保持预训练的世界理解能力。
3.2 冻结 vs 全量微调
| 微调策略 | 优点 | 适用场景 |
|---|---|---|
| 全量微调 | 充分利用动作数据 | 动作数据充足 |
| 冻结编码器 | 避免灾难性遗忘 | 动作数据稀缺、需多任务复用 |
这种灵活性使 DINO-world 能够适应不同的下游应用场景。
4. 完整 PyTorch 实现
#include <bits/stdc++.h>
using namespace std;
#include <torch/torch.h>
#include <torch/nn/modules/transformer.h>
namespace dino_world {
// ============================================================================
// 3轴旋转位置编码 (3-Axial RoPE)
// ============================================================================
class RotaryPositionalEncoding3D : public torch::nn::Module {
private:
int dim_heads_; // 每头维度
float min_period_; // RoPE 最小周期
float max_period_; // RoPE 最大周期
torch::Tensor inv_freq_; // 逆频率
public:
RotaryPositionalEncoding3D(int num_heads, int head_dim,
float min_period = 1e-2, float max_period = 1e2)
: dim_heads_(head_dim), min_period_(min_period), max_period_(max_period) {
// 计算逆频率: inv_freq = 1 / (period * 2pi)
// 周期范围 [min_period, max_period]
inv_freq_ = torch::exp(
torch::linspace(std::log(min_period_ * 2 * M_PI),
std::log(max_period_ * 2 * M_PI),
dim_heads_ / 3)
).to(torch::kFloat32);
register_buffer("inv_freq", inv_freq_);
}
// 计算旋转角度
torch::Tensor _compute_rotary(torch::Tensor freqs) {
// freqs: (seq_len, dim/3)
return torch::polar(torch.ones_like(freqs), freqs);
}
// 应用 RoPE 到 Q/K
pair<torch::Tensor, torch::Tensor> apply_rotary(
torch::Tensor q, // (batch, num_heads, seq_len, head_dim/3)
torch::Tensor k, // (batch, num_heads, seq_len, head_dim/3)
torch::Tensor seq_coords // (batch, seq_len, 3) - 时间、水平、垂直坐标
) {
auto batch = q.size(0);
auto seq_len = q.size(2);
// 提取各轴坐标
auto t_coords = seq_coords.index({torch::indexing::Slice(), torch::indexing::Slice(), 0}); // (batch, seq_len)
auto h_coords = seq_coords.index({torch::indexing::Slice(), torch::indexing::Slice(), 1});
auto w_coords = seq_coords.index({torch::indexing::Slice(), torch::indexing::Slice(), 2});
// 计算各轴频率
auto t_freqs = torch::einsum("bi,d->bid", t_coords, inv_freq_.slice(0, 0, dim_heads_/3));
auto h_freqs = torch::einsum("bi,d->bid", h_coords, inv_freq_.slice(0, dim_heads_/3, 2*dim_heads_/3));
auto w_freqs = torch::einsum("bi,d->bid", w_coords, inv_freq_.slice(0, 2*dim_heads_/3, dim_heads_));
// 3轴旋转
auto t_rot = _compute_rotary(t_freqs);
auto h_rot = _compute_rotary(h_freqs);
auto w_rot = _compute_rotary(w_freqs);
// 应用到查询和键
auto q_out = torch::view_as_complex(torch::stack({
q.slice(-1, 0, dim_heads_/3) * t_rot.cos() -
q.slice(-1, dim_heads_/3, 2*dim_heads_/3).roll(1, -1) * t_rot.sin().roll(1, -1),
q.slice(-1, 0, dim_heads_/3) * t_rot.sin().roll(1, -1) +
q.slice(-1, dim_heads_/3, 2*dim_heads_/3).roll(1, -1) * t_rot.cos().roll(1, -1)
}, -1)).squeeze(-1);
// ... 类似的处理用于其他两个轴 ...
return {q_out, k_out};
}
};
// ============================================================================
// 交叉注意力块 (Cross-Attention Block)
// ============================================================================
class CrossAttentionBlock : public torch::nn::Module {
private:
int dim_; // 模型维度
int num_heads_; // 注意力头数
int head_dim_; // 每头维度
float dropout_;
torch::nn::LayerNorm ln1_, ln2_; // 归一化层
torch::nn::Linear w_q_, w_k_, w_v_, w_o_; // QKV/O 投影
torch::nn::Linear mlp_; // FFN
shared_ptr<RotaryPositionalEncoding3D> rope_;
public:
CrossAttentionBlock(int dim, int num_heads, float dropout = 0.0,
shared_ptr<RotaryPositionalEncoding3D> rope = nullptr)
: dim_(dim), num_heads_(num_heads), head_dim_(dim / num_heads), dropout_(dropout),
rope_(rope) {
int head_dim_each = dim_ / num_heads_;
w_q_ = torch::nn::Linear(dim_, dim_);
w_k_ = torch::nn::Linear(dim_, dim_);
w_v_ = torch::nn::Linear(dim_, dim_);
w_o_ = torch::nn::Linear(dim_, dim_);
ln1_ = torch::nn::LayerNorm(torch::nn::LayerNormOptions(dim_));
ln2_ = torch::nn::LayerNorm(torch::nn::LayerNormOptions(dim_));
mlp_ = torch::nn::Linear(dim_, dim_ * 4);
register_module("ln1", ln1_);
register_module("ln2", ln2_);
register_module("w_q", w_q_);
register_module("w_k", w_k_);
register_module("w_v", w_v_);
register_module("w_o", w_o_);
register_module("mlp", mlp_);
}
torch::Tensor forward(torch::Tensor query, // (B, num_queries, D)
torch::Tensor key_value, // (B, seq_len, D)
torch::Tensor query_coords, // (B, num_queries, 3)
torch::Tensor kv_coords, // (B, seq_len, 3)
torch::Tensor attn_mask = {}) {
// 残差连接前的归一化
auto q = ln1_->forward(query);
auto kv = key_value;
// QKV 投影
q = w_q_->forward(q); // (B, num_queries, D)
k = w_k_->forward(kv); // (B, seq_len, D)
v = w_v_->forward(kv); // (B, seq_len, D)
// 重塑为多头格式
int B = q.size(0);
q = q.view({B, -1, num_heads_, head_dim_}).transpose(1, 2); // (B, H, Q, D_h)
k = k.view({B, -1, num_heads_, head_dim_}).transpose(1, 2);
v = v.view({B, -1, num_heads_, head_dim_}).transpose(1, 2);
// 应用 RoPE
if (rope_) {
auto [q_rot, k_rot] = rope_->apply_rotary(q, k, query_coords, kv_coords);
q = q_rot;
k = k_rot;
}
// 缩放
q = q / std::sqrt(head_dim_);
// 注意力计算
auto attn = torch::matmul(q, k.transpose(-2, -1)); // (B, H, Q, S)
if (attn_mask.defined()) {
attn = attn + attn_mask;
}
attn = torch::nn::functional::softmax(attn, -1);
attn = torch::nn::functional::dropout(attn, dropout_);
auto out = torch::matmul(attn, v); // (B, H, Q, D_h)
out = out.transpose(1, 2).contiguous().view({B, -1, dim_});
out = w_o_->forward(out);
// 第一个残差连接
query = query + out;
// MLP 块
auto mlp_out = mlp_->forward(torch::nn::functional::gelu(ln2_->forward(query)));
mlp_out = torch::nn::functional::dropout(mlp_out, dropout_);
query = query + mlp_out;
return query;
}
};
// ============================================================================
// 动作块 (Action Block)
// ============================================================================
class ActionBlock : public torch::nn::Module {
private:
torch::nn::Linear proj_;
torch::nn::Linear mlp_;
public:
ActionBlock(int dim, int action_dim) {
proj_ = torch::nn::Linear(dim + action_dim, dim);
mlp_ = torch::nn::Linear(dim, dim);
register_module("proj", proj_);
register_module("mlp", mlp_);
}
torch::Tensor forward(torch::Tensor query, torch::Tensor action) {
// 查询与动作拼接后投影
auto combined = torch::cat({query, action}, -1);
auto out = proj_->forward(combined);
out = torch::nn::functional::gelu(out);
out = mlp_->forward(out);
return query + out; // 残差连接
}
};
// ============================================================================
// DINO-world 完整模型
// ============================================================================
class DINOWorld : public torch::nn::Module {
private:
// 冻结的 DINOv2 编码器
torch::nn::ModuleHolder<torch::nn::TransformerEncoder> encoder_;
// 可学习的查询嵌入
torch::Tensor query_embed_; // (1, 1, D')
// 预测器块
vector<shared_ptr<CrossAttentionBlock>> predictor_blocks_;
torch::nn::LayerNorm final_ln_;
torch::nn::Linear output_proj_; // D' -> D
// 动作块 (可选)
vector<shared_ptr<ActionBlock>> action_blocks_;
bool use_actions_;
bool freeze_encoder_;
shared_ptr<RotaryPositionalEncoding3D> rope_;
public:
DINOWorld(
torch::nn::Module encoder, // 预训练 DINOv2
int predictor_dim = 1536,
int num_blocks = 40,
int num_heads = 24,
float dropout = 0.0,
int action_dim = 0,
bool use_actions = false,
bool freeze_encoder = true
) : use_actions_(use_actions), freeze_encoder_(freeze_encoder) {
// 使用 ViT-B/14: D=768, 这里简化处理
int dino_dim = 768;
// 冻结编码器
if (freeze_encoder_) {
encoder_->eval();
for (auto& p : encoder_->parameters()) {
p.set_requires_grad(false);
}
}
// 初始化 RoPE
rope_ = make_shared<RotaryPositionalEncoding3D>(
num_heads, predictor_dim / num_heads
);
// 可学习查询嵌入
query_embed_ = torch::randn({1, 1, predictor_dim}) * 0.02;
register_parameter("query_embed", query_embed_);
// 预测器块
for (int i = 0; i < num_blocks; ++i) {
predictor_blocks_.push_back(
make_shared<CrossAttentionBlock>(
predictor_dim, num_heads, dropout, rope_
)
);
char name[64];
snprintf(name, sizeof(name), "predictor_block_%d", i);
register_module(name, predictor_blocks_.back());
}
final_ln_ = torch::nn::LayerNorm(
torch::nn::LayerNormOptions(predictor_dim)
);
output_proj_ = torch::nn::Linear(predictor_dim, dino_dim);
register_module("final_ln", final_ln_);
register_module("output_proj", output_proj_);
// 动作块
if (use_actions_) {
for (int i = 0; i < num_blocks; ++i) {
action_blocks_.push_back(
make_shared<ActionBlock>(predictor_dim, action_dim)
);
char name[64];
snprintf(name, sizeof(name), "action_block_%d", i);
register_module(name, action_blocks_.back());
}
}
}
// 编码单帧
torch::Tensor encode_frame(torch::Tensor frame) {
// frame: (B, C, H, W)
// 返回: (B, H*W, D) patch tokens
auto features = encoder_->forward(frame); // 假设返回 patch tokens
return features;
}
// 批量编码多帧
pair<torch::Tensor, torch::Tensor> encode_video(
torch::Tensor frames, // (B, T, C, H, W)
torch::Tensor timestamps // (B, T)
) {
int B = frames.size(0);
int T = frames.size(1);
// 编码所有帧
vector<torch::Tensor> all_features;
vector<torch::Tensor> all_coords;
for (int t = 0; t < T; ++t) {
auto frame_t = frames.index({torch::indexing::Slice(), t});
auto feat_t = encode_frame(frame_t); // (B, H*W, D)
// 收集时间和空间坐标
auto coords_t = _create_coords(timestamps, feat_t.size(1)); // (B, H*W, 3)
all_features.push_back(feat_t);
all_coords.push_back(coords_t);
}
auto all_tokens = torch::cat(all_features, 1); // (B, T*H*W, D)
auto all_kv_coords = torch::cat(all_coords, 1); // (B, T*H*W, 3)
return {all_tokens, all_kv_coords};
}
// 创建时空坐标
torch::Tensor _create_coords(torch::Tensor timestamps, int num_patches) {
int B = timestamps.size(0);
// 简化: 使用归一化空间坐标 [-1, 1]
auto coords = torch::linspace(-1, 1, (int)std::sqrt(num_patches));
// 扩展到 batch 维度
coords = coords.unsqueeze(0).expand({B, num_patches});
return coords; // 返回 (B, num_patches, 1) - 实际需要3轴坐标
}
// 预测未来帧特征
torch::Tensor predict_future(
torch::Tensor context_tokens, // (B, ctx_len, D)
torch::Tensor context_coords, // (B, ctx_len, 3)
torch::Tensor target_coords, // (B, num_targets, 3) - 目标时空位置
torch::Tensor actions = {} // (B, ctx_len, action_dim) - 可选
) {
int B = context_tokens.size(0);
int num_targets = target_coords.size(1);
// 初始化查询
auto query = query_embed_.expand({B, num_targets, -1}); // (B, num_targets, D')
// 通过所有预测器块
for (int i = 0; i < predictor_blocks_.size(); ++i) {
// 交叉注意力
query = predictor_blocks_[i]->forward(
query, context_tokens, target_coords, context_coords
);
// 动作条件 (可选)
if (use_actions_ && actions.defined() && i < action_blocks_.size()) {
auto action_i = actions.index({torch::indexing::Slice(), i}); // 获取对应帧的动作
query = action_blocks_[i]->forward(query, action_i);
}
}
// 最终归一化和投影
query = final_ln_->forward(query);
auto output = output_proj_->forward(query); // (B, num_targets, D)
return output;
}
// 训练步骤
torch::Tensor forward(
torch::Tensor frames, // (B, T, C, H, W)
torch::Tensor timestamps, // (B, T)
torch::Tensor actions = {} // (B, T-1, action_dim)
) {
int B = frames.size(0);
int T = frames.size(1);
// 1. 编码视频
auto [all_tokens, all_coords] = encode_video(frames, timestamps);
// 2. 构建训练目标: 预测 t+1 帧的 patch 特征
vector<torch::Tensor> predictions;
vector<torch::Tensor> targets;
for (int t = 0; t < T - 1; ++t) {
// 上下文: 帧 0 到 t
auto context_tokens = all_tokens.index({torch::indexing::Slice(),
torch::indexing::Slice(0, t * all_tokens.size(1) / T)});
auto context_coords = all_coords.index({torch::indexing::Slice(),
torch::indexing::Slice(0, t * all_coords.size(1) / T)});
// 目标: 帧 t+1 的所有 patches
int target_start = (t + 1) * all_tokens.size(1) / T;
int target_end = (t + 2) * all_tokens.size(1) / T;
auto target_tokens = all_tokens.index({torch::indexing::Slice(),
torch::indexing::Slice(target_start, target_end)});
auto target_coords = all_coords.index({torch::indexing::Slice(),
torch::indexing::Slice(target_start, target_end)});
// 获取对应的动作
torch::Tensor action_t;
if (use_actions_ && actions.defined()) {
action_t = actions.index({torch::indexing::Slice(), t});
}
// 预测
auto pred = predict_future(context_tokens, context_coords,
target_coords, action_t);
predictions.push_back(pred);
targets.push_back(target_tokens);
}
// 3. 计算损失
auto pred_cat = torch::cat(predictions, 1);
auto target_cat = torch::cat(targets, 1);
// Smooth L1 损失
auto loss = torch::nn::functional::smooth_l1_loss(pred_cat, target_cat);
return loss;
}
};
// ============================================================================
// 推理: 自回归预测
// ============================================================================
class DINOWorldInference {
private:
shared_ptr<DINOWorld> model_;
torch::Device device_;
public:
DINOWorldInference(shared_ptr<DINOWorld> model, torch::Device device = torch::kCPU)
: model_(model), device_(device) {
model_->eval();
}
// 自回归预测多步未来
vector<torch::Tensor> predict_rollout(
torch::Tensor init_frames, // (1, T_ctx, C, H, W)
torch::Tensor init_timestamps, // (1, T_ctx)
int num_steps,
float dt = 0.1 // 时间步长
) {
vector<torch::Tensor> predictions;
// 初始上下文
auto context_tokens = model_->encode_video(init_frames, init_timestamps).first;
auto context_coords = model_->encode_video(init_frames, init_timestamps).second;
auto timestamps = init_timestamps;
for (int step = 0; step < num_steps; ++step) {
// 预测下一步
float next_time = timestamps[0][-1].item<float>() + dt;
auto target_coords = torch::tensor({{{{next_time}, {-1.0}, {-1.0}}}}).to(device_); // 简化
auto pred = model_->predict_future(context_tokens, context_coords, target_coords);
predictions.push_back(pred);
// 更新上下文 (自回归)
context_tokens = torch::cat({context_tokens, pred}, 1);
timestamps = torch::cat({timestamps,
torch::tensor({{{{next_time}}}}).to(device_)}, 1);
}
return predictions;
}
};
} // namespace dino_world
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
// 初始化模型
// 注意: 实际使用需要加载预训练的 DINOv2 编码器
// auto encoder = load_dinov2_model("dinov2_vitb14");
// auto model = make_shared<dino_world::DINOWorld>(encoder);
cout << "DINO-world 架构演示完成" << endl;
return 0;
}5. 实验结果
5.1 密集特征预测
DINO-world 在多个视频预测基准上进行了评估,包括分割和深度预测任务:
| 方法 | VSPW mIoU (↑) | Cityscapes mIoU (↑) | KITTI RMSE (↓) |
|---|---|---|---|
| Copy Last | 42.1 | 39.7 | 4.745 |
| COSMOS-4B | 40.2 | 46.2 | 4.742 |
| COSMOS-12B | 40.7 | 45.9 | 4.617 |
| V-JEPA ViT-H | 4.6 | 12.2 | 5.785 |
| DINO-Foresight | 37.7 | 57.2 | 3.740 |
| DINO-world | 47.0 | 55.1 | 4.268 |
关键发现:DINO-world 在 VSPW 上比第二名高出 6.3% mIoU,验证了”冻结编码器 + 大规模预训练”范式的有效性。
5.2 直观物理理解
使用”惊讶分数”(Surprise Score)评估模型对物理世界的理解能力:
| 模型 | IntPhys | GRASP | InfLevel |
|---|---|---|---|
| COSMOS-4B | 99.5 | 60.1 | 44.8 |
| V-JEPA ViT-H | 89.4 | 73.0 | 59.9 |
| DINO-Foresight | 87.8 | 64.9 | 62.8 |
| DINO-world | 91.3 | 76.0 | 63.7 |
5.3 规划任务
动作条件微调后,DINO-world 在模拟环境中展现了零样本规划能力:
| 方法 | PushT | Wall | PointMaze |
|---|---|---|---|
| Action-only | 49.4 | 91.1 | 61.6 |
| Fine-tuned | 59.4 | 93.8 | 68.7 |
预训练策略带来的性能提升证明了大规模无监督预训练的重要性。
6. 与其他世界模型的比较
6.1 设计范式对比
| 模型 | 潜在空间 | 训练策略 | 参数量 | 特点 |
|---|---|---|---|---|
| DINO-world | DINOv2 | 解耦两阶段 | ~1.1B | 冻结编码器,高效 |
| V-JEPA | 可学习 | 联合训练 | ~200M | 预测头非最优 |
| COSMOS | VAE | 联合训练 | 4-12B | 像素级生成 |
| Genie/Genie2 | 可学习 | 联合训练 | 可变 | 游戏导向 |
| DINO-WM | DINOv2 | 解耦两阶段 | 可变 | 零样本规划 |
6.2 核心优势
- 计算效率:相比 COSMOS 的 22M GPU 小时,DINO-world 仅需约 1B 参数即可达到优异性能
- 冻结编码器:复用 DINOv2 的语义和几何理解,无需联合训练
- 灵活适配:动作块机制支持轻量级下游任务适配
7. 应用场景
DINO-world 的潜在应用包括:
- 自动驾驶:预测道路场景演化,支持规划与决策
- 机器人控制:通过 世界模型规划 实现零样本任务迁移
- 视频理解:密集预测任务(分割、深度估计)的辅助工具
- 游戏 AI:模拟环境动态,支持策略优化
8. 结论与未来方向
DINO-world 证明了在预训练视觉编码器的潜在空间中训练世界模型是一条可行且高效的路径。通过冻结 DINOv2 编码器并在大规模视频数据上训练预测器,模型学习到了跨领域的通用时空动态。
未来研究方向
- 长时预测:当前模型在约1秒后的预测质量显著下降,需要探索多模态预测或不确定性建模
- 数据筛选:研究视频数据的自动化筛选策略,提升预训练效率
- 语言条件:整合语言指令,实现更灵活的任务规范
- 真实机器人:在真实环境中验证动作条件微调和规划能力