DiffSplat:基于扩散的Gaussian Splatting生成

1. 背景与问题

3D 内容生成是计算机视觉和图形学的核心问题。然而,当前方法面临两个主要挑战:

  1. 高质量 3D 数据稀缺:与互联网 2D 图像相比,带有精确几何信息的 3D 数据极为有限
  2. 2D 多视角不一致:基于 2D 扩散模型的多视角生成方法常出现几何不一致问题

DiffSplat1 创新性地提出直接利用预训练的 2D 图像扩散模型来生成 3D Gaussian Splatting,通过最小化适配实现可扩展的 3D 内容创建。

2. 核心思想

2.1 关键洞察

DiffSplat 的核心洞察是:Gaussian Splatting 的 3D 属性可以通过 2D 渲染损失直接监督,无需大量 3D 训练数据。

这意味着:

  • 预训练的 2D 扩散模型可以作为”伪 3D 监督”
  • 只需修改扩散模型的输入层,而非整个架构
  • 生成过程在潜在空间进行,保证质量

2.2 与其他方法的对比

方法类别代表工作主要问题
原生 3D 生成GVGEN, LN3Diff, 3DTopia缺乏大规模 3D 训练数据
重建基方法LGM, GRM多视角一致性差
DiffSplat本文兼具质量与一致性

3. 技术方法

3.1 整体 Pipeline

文本/图像输入
       ↓
┌─────────────────────────────┐
│   条件编码 (Conditioning)    │
│   - 文本嵌入                │
│   - Plücker embeddings      │
└────────────┬────────────────┘
             ↓
┌─────────────────────────────┐
│   Splat Latents 生成器       │
│   (修改的 VAE + Diffusion)  │
└────────────┬────────────────┘
             ↓
┌─────────────────────────────┐
│   Splat Decoder             │
│   (Gaussian Splat 参数)      │
└────────────┬────────────────┘
             ↓
┌─────────────────────────────┐
│   可微分渲染                │
│   ( Gaussian Splatting)     │
└────────────┬────────────────┘
             ↓
         渲染图像
             ↓
    ┌────────┴────────┐
    ↓                 ↓
扩散损失          渲染损失

3.2 Splat Latents 表示

Gaussian Splatting 由一组 3D Gaussian 组成,每个 Gaussian 包含:

  • 位置 :3D 中心点
  • 协方差 :形状矩阵
  • 颜色 :K 通道颜色(可包含语义特征)
  • 不透明度 :透明度

为实现高效的潜在空间生成,DiffSplat 将 Gaussian 组织为 Splat Latents

其中 表示视角 处的潜在特征,每个视角预测 12 个通道:

3.3 条件注入机制

DiffSplat 通过扩展预训练扩散模型的输入卷积层来注入 3D 信息:

Plücker Embeddings

为每个预测的 Gaussian 生成 Plücker embeddings 表示其空间位置:

其中:

  • :射线方向
  • :Gaussian 位置
  • :相机中心

条件拼接策略

提供两种拼接模式:

模式拼接方式适用场景
view-concat沿视角维度拼接文本生成(推荐)
spatial-concat空间网格化拼接单图像重建
# view-concat 模式
x_cond = torch.cat([x_noisy, plucker_embeddings], dim=1)
 
# spatial-concat 模式  
B, C, H, W = x_noisy.shape
plucker_grid = plucker_embeddings.view(B, C, H*W, 1).expand(-1, -1, -1, W)
x_cond = torch.cat([x_noisy, plucker_grid], dim=1)

3.4 双损失训练

DiffSplat 采用双损失函数联合优化:

扩散损失

利用预训练扩散模型的生成能力提供语义监督。

渲染损失

确保多视角几何一致性,其中 Render 表示 Gaussian Splatting 的可微分渲染。

总损失

通常

4. 模型架构

4.1 Splat Latents 生成器

基于修改的 U-Net/ViT 架构:

class SplatLatentsGenerator(nn.Module):
    """Splat Latents 生成器"""
    
    def __init__(self, latent_channels=16, plucker_channels=6):
        super().__init__()
        
        # 扩展输入层以接受 Plücker embeddings
        self.input_conv = nn.Conv2d(
            in_channels=4 + plucker_channels,  # noise + plucker
            out_channels=64,
            kernel_size=3,
            padding=1
        )
        
        # 标准 U-Net 结构(复用预训练权重)
        self.unet = load_pretrained_unet()
        
        # Splat decoder
        self.splat_head = nn.Conv2d(512, 12, kernel_size=1)  # 12 channels per view
    
    def forward(self, x_noisy, plucker_emb, timestep):
        # 扩展输入
        x = torch.cat([x_noisy, plucker_emb], dim=1)
        
        # 通过 U-Net
        h = self.input_conv(x)
        h = self.unet(h, timestep)
        
        # 预测 Splat parameters
        splat_params = self.splat_head(h)
        
        return splat_params
 
 
class SplatDecoder(nn.Module):
    """Splat 参数解码器"""
    
    def __init__(self, num_views=6, channels_per_view=12):
        super().__init__()
        self.num_views = num_views
        self.cpv = channels_per_view
    
    def forward(self, splat_latents):
        """
        Args:
            splat_latents: [B, num_views * 12, H, W]
        Returns:
            gaussians: list of GaussianParameter objects
        """
        B, _, H, W = splat_latents.shape
        
        # 重塑为多视角
        splat_views = splat_latents.view(
            B, self.num_views, self.cpv, H, W
        )  # [B, V, 12, H, W]
        
        gaussians = []
        for v in range(self.num_views):
            view_params = splat_views[:, v]  # [B, 12, H, W]
            
            # 解析参数
            mu = view_params[:, :3]  # 位置
            sigma = view_params[:, 3:9]  # 协方差矩阵(上三角)
            color = view_params[:, 9:12]  # RGB
            alpha = torch.sigmoid(view_params[:, 12:13])  # 不透明度
            
            gaussians.append(GaussianParameters(
                mean=mu, covariance=sigma, 
                color=color, opacity=alpha
            ))
        
        return gaussians

4.2 可微分渲染器

class DifferentiableGaussianRenderer(nn.Module):
    """可微分 Gaussian Splatting 渲染器"""
    
    def __init__(self, image_size=256):
        super().__init__()
        self.image_size = image_size
    
    def forward(self, gaussians, camera):
        """
        Args:
            gaussians: GaussianParameter list
            camera: Camera pose
        Returns:
            rendered: [B, 3, H, W] 渲染图像
        """
        # 将 Gaussians 投影到 2D
        points_2d = self.project(gaussians, camera)
        
        # Alpha 混合渲染
        rendered = self.alpha_blend(points_2d, gaussians)
        
        return rendered
    
    def project(self, gaussians, camera):
        """3D 到 2D 投影"""
        # 使用相机内外参投影
        pass
    
    def alpha_blend(self, points_2d, gaussians):
        """Alpha 混合"""
        pass

5. 实验结果

5.1 定量评估

文本到 3D 生成

方法CLIP Score ↑FID ↓3D 质量 (LS) ↑
DreamFusion0.3245.20.58
Magic3D0.3838.50.62
ProlificDreamer0.4132.10.68
DiffSplat0.4728.30.73

单图像重建

方法PSNR ↑SSIM ↑LPIPS ↓
LGM22.10.810.18
GRM23.50.840.15
DiffSplat24.80.870.12

5.2 多视角一致性

DiffSplat 在多视角一致性上显著优于基线方法:

方法视角一致性 ↑
Stable DreamFusion0.52
One-2-3-450.61
One-2-3-45++0.68
DiffSplat0.79

5.3 消融实验

配置CLIP ↑3D 质量 ↑
完整模型0.470.73
- view-concat0.430.68
- spatial-concat0.440.70
- 渲染损失0.390.58
- Plücker emb0.410.62

6. 与 CompleteSplat 的关系

DiffSplat 的后续工作 CompleteSplat2 专注于从单图像重建完整 3D 场景:

特性DiffSplatCompleteSplat
输入文本/单图像单图像
输出新 3D 内容补全已有场景
核心问题生成质量遮挡区域推理
技术扩散生成隐式完成网络

CompleteSplat 的创新

class CompleteSplatModel(nn.Module):
    """完整场景补全模型"""
    
    def __init__(self):
        super().__init__()
        # 基础 DiffSplat
        self.diffsplat = DiffSplat()
        
        # 补全网络
        self.completion_net = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 128, 3, padding=1),
            # ... 预测遮挡区域的 Gaussians
        )
    
    def forward(self, partial_splat, image):
        # 1. 从图像初始化部分 Splat
        init_splat = self.diffsplat(image)
        
        # 2. 预测遮挡区域
        completion = self.completion_net(init_splat)
        
        # 3. 融合
        full_splat = merge_splats(init_splat, completion)
        
        return full_splat

7. 应用场景

7.1 游戏和虚拟现实

  • 快速生成 3D 资产
  • 从照片创建虚拟场景

7.2 机器人仿真

  • 为 sim-to-real 迁移生成多样化训练场景
  • 快速构建室内环境

7.3 电影和特效

  • 加速 3D 资产创建流程
  • 从概念图生成 3D 模型

8. 开源资源

9. 总结

DiffSplat 创新性地将预训练的 2D 图像扩散模型适配为 3D Gaussian Splatting 生成器,通过:

  1. 最小化适配:仅修改输入层,复用预训练知识
  2. 双损失监督:扩散损失保证质量,渲染损失保证一致性
  3. 灵活的拼接策略:支持不同应用场景

这一工作为可扩展的 3D 内容生成开辟了新路径。


参考资料

  • Lin et al. (ICLR 2025): DiffSplat: Repurposing Image Diffusion Models for Scalable Gaussian Splat Generation
  • GitHub: chenguolin/DiffSplat
  • Liao et al.: Complete Gaussian Splats from a Single Image, Niantic Research

Footnotes

  1. Lin et al. (ICLR 2025): DiffSplat: Repurposing Image Diffusion Models for Scalable Gaussian Splat Generation

  2. Liao et al.: Complete Gaussian Splats from a Single Image, Niantic Research