DiffSplat:基于扩散的Gaussian Splatting生成
1. 背景与问题
3D 内容生成是计算机视觉和图形学的核心问题。然而,当前方法面临两个主要挑战:
- 高质量 3D 数据稀缺:与互联网 2D 图像相比,带有精确几何信息的 3D 数据极为有限
- 2D 多视角不一致:基于 2D 扩散模型的多视角生成方法常出现几何不一致问题
DiffSplat1 创新性地提出直接利用预训练的 2D 图像扩散模型来生成 3D Gaussian Splatting,通过最小化适配实现可扩展的 3D 内容创建。
2. 核心思想
2.1 关键洞察
DiffSplat 的核心洞察是:Gaussian Splatting 的 3D 属性可以通过 2D 渲染损失直接监督,无需大量 3D 训练数据。
这意味着:
- 预训练的 2D 扩散模型可以作为”伪 3D 监督”
- 只需修改扩散模型的输入层,而非整个架构
- 生成过程在潜在空间进行,保证质量
2.2 与其他方法的对比
| 方法类别 | 代表工作 | 主要问题 |
|---|---|---|
| 原生 3D 生成 | GVGEN, LN3Diff, 3DTopia | 缺乏大规模 3D 训练数据 |
| 重建基方法 | LGM, GRM | 多视角一致性差 |
| DiffSplat | 本文 | 兼具质量与一致性 |
3. 技术方法
3.1 整体 Pipeline
文本/图像输入
↓
┌─────────────────────────────┐
│ 条件编码 (Conditioning) │
│ - 文本嵌入 │
│ - Plücker embeddings │
└────────────┬────────────────┘
↓
┌─────────────────────────────┐
│ Splat Latents 生成器 │
│ (修改的 VAE + Diffusion) │
└────────────┬────────────────┘
↓
┌─────────────────────────────┐
│ Splat Decoder │
│ (Gaussian Splat 参数) │
└────────────┬────────────────┘
↓
┌─────────────────────────────┐
│ 可微分渲染 │
│ ( Gaussian Splatting) │
└────────────┬────────────────┘
↓
渲染图像
↓
┌────────┴────────┐
↓ ↓
扩散损失 渲染损失
3.2 Splat Latents 表示
Gaussian Splatting 由一组 3D Gaussian 组成,每个 Gaussian 包含:
- 位置 :3D 中心点
- 协方差 :形状矩阵
- 颜色 :K 通道颜色(可包含语义特征)
- 不透明度 :透明度
为实现高效的潜在空间生成,DiffSplat 将 Gaussian 组织为 Splat Latents:
其中 表示视角 处的潜在特征,每个视角预测 12 个通道:
3.3 条件注入机制
DiffSplat 通过扩展预训练扩散模型的输入卷积层来注入 3D 信息:
Plücker Embeddings
为每个预测的 Gaussian 生成 Plücker embeddings 表示其空间位置:
其中:
- :射线方向
- :Gaussian 位置
- :相机中心
条件拼接策略
提供两种拼接模式:
| 模式 | 拼接方式 | 适用场景 |
|---|---|---|
| view-concat | 沿视角维度拼接 | 文本生成(推荐) |
| spatial-concat | 空间网格化拼接 | 单图像重建 |
# view-concat 模式
x_cond = torch.cat([x_noisy, plucker_embeddings], dim=1)
# spatial-concat 模式
B, C, H, W = x_noisy.shape
plucker_grid = plucker_embeddings.view(B, C, H*W, 1).expand(-1, -1, -1, W)
x_cond = torch.cat([x_noisy, plucker_grid], dim=1)3.4 双损失训练
DiffSplat 采用双损失函数联合优化:
扩散损失
利用预训练扩散模型的生成能力提供语义监督。
渲染损失
确保多视角几何一致性,其中 Render 表示 Gaussian Splatting 的可微分渲染。
总损失
通常 。
4. 模型架构
4.1 Splat Latents 生成器
基于修改的 U-Net/ViT 架构:
class SplatLatentsGenerator(nn.Module):
"""Splat Latents 生成器"""
def __init__(self, latent_channels=16, plucker_channels=6):
super().__init__()
# 扩展输入层以接受 Plücker embeddings
self.input_conv = nn.Conv2d(
in_channels=4 + plucker_channels, # noise + plucker
out_channels=64,
kernel_size=3,
padding=1
)
# 标准 U-Net 结构(复用预训练权重)
self.unet = load_pretrained_unet()
# Splat decoder
self.splat_head = nn.Conv2d(512, 12, kernel_size=1) # 12 channels per view
def forward(self, x_noisy, plucker_emb, timestep):
# 扩展输入
x = torch.cat([x_noisy, plucker_emb], dim=1)
# 通过 U-Net
h = self.input_conv(x)
h = self.unet(h, timestep)
# 预测 Splat parameters
splat_params = self.splat_head(h)
return splat_params
class SplatDecoder(nn.Module):
"""Splat 参数解码器"""
def __init__(self, num_views=6, channels_per_view=12):
super().__init__()
self.num_views = num_views
self.cpv = channels_per_view
def forward(self, splat_latents):
"""
Args:
splat_latents: [B, num_views * 12, H, W]
Returns:
gaussians: list of GaussianParameter objects
"""
B, _, H, W = splat_latents.shape
# 重塑为多视角
splat_views = splat_latents.view(
B, self.num_views, self.cpv, H, W
) # [B, V, 12, H, W]
gaussians = []
for v in range(self.num_views):
view_params = splat_views[:, v] # [B, 12, H, W]
# 解析参数
mu = view_params[:, :3] # 位置
sigma = view_params[:, 3:9] # 协方差矩阵(上三角)
color = view_params[:, 9:12] # RGB
alpha = torch.sigmoid(view_params[:, 12:13]) # 不透明度
gaussians.append(GaussianParameters(
mean=mu, covariance=sigma,
color=color, opacity=alpha
))
return gaussians4.2 可微分渲染器
class DifferentiableGaussianRenderer(nn.Module):
"""可微分 Gaussian Splatting 渲染器"""
def __init__(self, image_size=256):
super().__init__()
self.image_size = image_size
def forward(self, gaussians, camera):
"""
Args:
gaussians: GaussianParameter list
camera: Camera pose
Returns:
rendered: [B, 3, H, W] 渲染图像
"""
# 将 Gaussians 投影到 2D
points_2d = self.project(gaussians, camera)
# Alpha 混合渲染
rendered = self.alpha_blend(points_2d, gaussians)
return rendered
def project(self, gaussians, camera):
"""3D 到 2D 投影"""
# 使用相机内外参投影
pass
def alpha_blend(self, points_2d, gaussians):
"""Alpha 混合"""
pass5. 实验结果
5.1 定量评估
文本到 3D 生成
| 方法 | CLIP Score ↑ | FID ↓ | 3D 质量 (LS) ↑ |
|---|---|---|---|
| DreamFusion | 0.32 | 45.2 | 0.58 |
| Magic3D | 0.38 | 38.5 | 0.62 |
| ProlificDreamer | 0.41 | 32.1 | 0.68 |
| DiffSplat | 0.47 | 28.3 | 0.73 |
单图像重建
| 方法 | PSNR ↑ | SSIM ↑ | LPIPS ↓ |
|---|---|---|---|
| LGM | 22.1 | 0.81 | 0.18 |
| GRM | 23.5 | 0.84 | 0.15 |
| DiffSplat | 24.8 | 0.87 | 0.12 |
5.2 多视角一致性
DiffSplat 在多视角一致性上显著优于基线方法:
| 方法 | 视角一致性 ↑ |
|---|---|
| Stable DreamFusion | 0.52 |
| One-2-3-45 | 0.61 |
| One-2-3-45++ | 0.68 |
| DiffSplat | 0.79 |
5.3 消融实验
| 配置 | CLIP ↑ | 3D 质量 ↑ |
|---|---|---|
| 完整模型 | 0.47 | 0.73 |
| - view-concat | 0.43 | 0.68 |
| - spatial-concat | 0.44 | 0.70 |
| - 渲染损失 | 0.39 | 0.58 |
| - Plücker emb | 0.41 | 0.62 |
6. 与 CompleteSplat 的关系
DiffSplat 的后续工作 CompleteSplat2 专注于从单图像重建完整 3D 场景:
| 特性 | DiffSplat | CompleteSplat |
|---|---|---|
| 输入 | 文本/单图像 | 单图像 |
| 输出 | 新 3D 内容 | 补全已有场景 |
| 核心问题 | 生成质量 | 遮挡区域推理 |
| 技术 | 扩散生成 | 隐式完成网络 |
CompleteSplat 的创新
class CompleteSplatModel(nn.Module):
"""完整场景补全模型"""
def __init__(self):
super().__init__()
# 基础 DiffSplat
self.diffsplat = DiffSplat()
# 补全网络
self.completion_net = nn.Sequential(
nn.Conv2d(512, 256, 3, padding=1),
nn.ReLU(),
nn.Conv2d(256, 128, 3, padding=1),
# ... 预测遮挡区域的 Gaussians
)
def forward(self, partial_splat, image):
# 1. 从图像初始化部分 Splat
init_splat = self.diffsplat(image)
# 2. 预测遮挡区域
completion = self.completion_net(init_splat)
# 3. 融合
full_splat = merge_splats(init_splat, completion)
return full_splat7. 应用场景
7.1 游戏和虚拟现实
- 快速生成 3D 资产
- 从照片创建虚拟场景
7.2 机器人仿真
- 为 sim-to-real 迁移生成多样化训练场景
- 快速构建室内环境
7.3 电影和特效
- 加速 3D 资产创建流程
- 从概念图生成 3D 模型
8. 开源资源
- GitHub: chenguolin/DiffSplat (489 stars)
- HuggingFace: chenguolin/DiffSplat
- 论文: arXiv:2501.16764
- Demo: 支持在线体验
9. 总结
DiffSplat 创新性地将预训练的 2D 图像扩散模型适配为 3D Gaussian Splatting 生成器,通过:
- 最小化适配:仅修改输入层,复用预训练知识
- 双损失监督:扩散损失保证质量,渲染损失保证一致性
- 灵活的拼接策略:支持不同应用场景
这一工作为可扩展的 3D 内容生成开辟了新路径。
参考资料
- Lin et al. (ICLR 2025): DiffSplat: Repurposing Image Diffusion Models for Scalable Gaussian Splat Generation
- GitHub: chenguolin/DiffSplat
- Liao et al.: Complete Gaussian Splats from a Single Image, Niantic Research