CompleteSplat:3D场景补全的Gaussian Splatting

1. 背景与挑战

3D Gaussian Splatting (3DGS) 在新视角合成和场景渲染方面展现了卓越的性能,但其应用受限于一个根本性问题:需要密集的多视角观测

1.1 核心问题

问题类型描述影响
遮挡区域相机视角看不到的区域无法重建场景不完整
稀疏视角只有少数几张照片深度模糊
自遮挡物体自身遮挡部分孔洞问题

1.2 现有方法的局限

  • 基于 NeRF 的方法:需要大量视角,训练时间长
  • 基于多视角扩散的方法:依赖 3D 数据或 SLAM 系统
  • 前馈 3D 生成方法:生成新场景,而非补全已有场景

CompleteSplat1 创新性地提出从单张图像重建完整 3D Gaussian Splatting,通过隐式完成网络推理遮挡区域的几何结构。

2. 核心思想

2.1 关键洞察

CompleteSplat 的核心洞察是:3D 场景的遮挡区域可以通过学习场景的”完形”结构来推理

这借鉴了人类的视觉完形原理:给定部分视图,大脑能够推断出被遮挡区域的合理结构。

2.2 任务定义

给定:

  • 单张输入图像
  • 相机参数(可选)

输出:

  • 完整的 3D Gaussian Splatting 场景 ,包含所有可见和推断的 Gaussian

3. 技术方法

3.1 整体 Pipeline

单张图像
    ↓
┌─────────────────────────────────┐
│     图像编码器 (DINOv2 ViT)      │
└────────────────┬────────────────┘
                 ↓
┌─────────────────────────────────┐
│     深度估计分支                 │
│     (深度图 D)                   │
└────────────────┬────────────────┘
                 ↓
┌─────────────────────────────────┐
│     补全网络 (Completion Net)    │
│     (推断遮挡区域几何)            │
└────────────────┬────────────────┘
                 ↓
┌─────────────────────────────────┐
│     Gaussian 参数解码            │
│     (μ, Σ, c, α)                │
└────────────────┬────────────────┘
                 ↓
      完整 3D Gaussian Splatting

3.2 图像编码器

使用预训练的 DINOv2 ViT 作为图像编码器:

class ImageEncoder(nn.Module):
    """图像编码器 - 使用 DINOv2"""
    
    def __init__(self, model_size='vitl14'):
        super().__init__()
        # 加载预训练 DINOv2
        self.encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_{model_size}')
        
        # 冻结参数
        for param in self.encoder.parameters():
            param.requires_grad = False
    
    def forward(self, image):
        """
        Args:
            image: [B, 3, H, W] 输入图像
        Returns:
            features: [B, D, H//14, W//14] 密集特征
        """
        # 提取多层特征
        features = self.encoder.get_intermediate_layers(image, n=4)
        # 融合多层特征
        fused = self.fuse_features(features)
        return fused

3.3 深度估计分支

class DepthBranch(nn.Module):
    """深度估计分支"""
    
    def __init__(self, feature_dim):
        super().__init__()
        self.depth_head = nn.Sequential(
            nn.Conv2d(feature_dim, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 1, 1),  # 深度值
        )
    
    def forward(self, features):
        depth_logit = self.depth_head(features)
        # 转换为正值深度
        depth = torch.exp(depth_logit)  # D = exp(log_D)
        return depth

3.4 补全网络 (Completion Net)

这是 CompleteSplat 的核心创新。补全网络学习从可见区域推理被遮挡区域的 3D 结构:

class CompletionNet(nn.Module):
    """遮挡补全网络"""
    
    def __init__(self, feature_dim=1024, num_completion_tokens=256):
        super().__init__()
        
        # 可见区域特征提取
        self.visible_encoder = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.GELU(),
            nn.Linear(512, 256)
        )
        
        # Completion tokens(可学习)
        self.completion_tokens = nn.Parameter(
            torch.randn(num_completion_tokens, 256)
        )
        
        # Cross-attention Transformer
        self.completion_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=256, nhead=8, dim_feedforward=1024,
                dropout=0.1, activation='gelu'
            ),
            num_layers=6
        )
        
        # 输出头
        self.output_head = nn.Sequential(
            nn.Linear(256, 512),
            nn.GELU(),
            nn.Linear(512, 3 + 6 + 3 + 1)  # pos(3) + cov(6) + color(3) + alpha(1)
        )
    
    def forward(self, visible_features, depth_map):
        """
        Args:
            visible_features: [B, N, D] 可见区域特征
            depth_map: [B, 1, H, W] 估计的深度图
        Returns:
            completion_gaussians: [B, M, 13] 补全的 Gaussian 参数
        """
        # 编码可见特征
        visible_emb = self.visible_encoder(visible_features)  # [B, N, 256]
        
        # 添加深度感知
        depth_emb = self.depth_to_embedding(depth_map)  # [B, N, 256]
        visible_emb = visible_emb + depth_emb
        
        # 拼接 completion tokens
        B = visible_emb.shape[0]
        completion_tokens = self.completion_tokens.unsqueeze(0).expand(B, -1, -1)
        
        # Cross-attention:completion tokens 关注可见区域
        combined = torch.cat([visible_emb, completion_tokens], dim=1)
        transformed = self.completion_transformer(combined)
        
        # 只取 completion tokens 的输出
        completion_hidden = transformed[:, -completion_tokens.shape[1]:]
        
        # 预测补全 Gaussian 参数
        completion_gaussians = self.output_head(completion_hidden)
        
        return completion_gaussians

3.5 Gaussian 参数解码

将补全的 tokens 解码为标准 Gaussian 参数:

class GaussianDecoder(nn.Module):
    """Gaussian 参数解码器"""
    
    def __init__(self):
        super().__init__()
        # 用于组织补全 tokens 的空间结构
        self.spatial_layout = nn.Parameter(
            torch.randn(16, 16, 256)  # 16x16 网格
        )
    
    def forward(self, completion_tokens, depth_map):
        """
        Args:
            completion_tokens: [B, M, 13] 补全参数
            depth_map: [B, 1, H, W] 深度图
        Returns:
            gaussians: GaussianParameter list
        """
        B, M, _ = completion_tokens.shape
        
        # 重塑为空间网格
        grid_size = int(math.sqrt(M))
        tokens_grid = completion_tokens.view(B, grid_size, grid_size, -1)
        
        # 结合空间先验
        spatial_prior = self.spatial_layout.unsqueeze(0)
        tokens_grid = tokens_grid + spatial_prior
        
        # 转换为 Gaussian 参数
        gaussians = []
        for b in range(B):
            for i in range(grid_size):
                for j in range(grid_size):
                    params = tokens_grid[b, i, j]
                    
                    # 位置:基于深度反投影
                    depth = depth_map[b, 0, i, j].item()
                    x_3d = self.backproject(i, j, depth)
                    
                    # 协方差:基于局部几何
                    cov = self.estimate_covariance(params[3:9])
                    
                    # 颜色和不透明度
                    color = torch.sigmoid(params[9:12])
                    alpha = torch.sigmoid(params[12:13])
                    
                    gaussians.append(GaussianParameter(
                        mean=x_3d,
                        covariance=cov,
                        color=color,
                        opacity=alpha
                    ))
        
        return gaussians

3.6 训练目标

CompleteSplat 使用多目标训练:

重建损失

确保渲染图像与输入图像一致。

深度损失

如果有深度 GT,则使用深度监督。

几何正则

鼓励平滑深度和物理合理的几何。

总损失

4. 与 DiffSplat 的关系

CompleteSplat 和 DiffSplat 是互补的工作:

维度DiffSplatCompleteSplat
任务生成新的 3D 内容补全/重建已有场景
输入文本/单图像单图像
输出生成的 Gaussian重建的完整 Gaussian
核心问题质量与多样性遮挡与完整性
方法扩散模型生成补全网络推理

5. 实验结果

5.1 场景补全质量

方法补全完整度 ↑渲染质量 ↑深度误差 ↓
LGM (基线)0.4522.10.182
GRM0.5223.50.165
One-2-3-45++0.5821.80.178
CompleteSplat0.7324.20.098

5.2 遮挡区域重建

特别评估遮挡区域的重建质量:

区域类型PSNR ↑SSIM ↑
可见区域26.80.89
轻度遮挡24.30.84
重度遮挡21.50.76

5.3 与 DiffSplat 的联合使用

# CompleteSplat + DiffSplat 联合 pipeline
def generate_and_complete(text_prompt):
    # 1. 使用 DiffSplat 生成初始场景
    initial_splat = diffsplat.generate(text_prompt)
    
    # 2. 使用 CompleteSplat 补全缺失区域
    complete_splat = complete_splat.fill_gaps(initial_splat)
    
    return complete_splat

6. 应用场景

6.1 Niantic 现实世界扫描

  • 从 AR 设备捕获的单张图像重建完整场景
  • 用于构建真实世界的 3D 数字孪生

6.2 机器人导航

  • 从车辆前置摄像头重建环境
  • 补全被遮挡的障碍物区域

6.3 文化遗产保护

  • 从珍贵文物的单张照片重建 3D 模型
  • 补全风化/破损部分

7. 总结

CompleteSplat 通过学习场景的完形结构,实现了从单图像到完整 3D Gaussian Splatting 的重建,结合:

  1. DINOv2 图像编码:提供丰富的语义特征
  2. 深度估计分支:提供几何约束
  3. Completion Net:学习推理遮挡区域
  4. 多目标训练:平衡重建质量和几何完整性

这一工作为稀疏视角 3D 重建开辟了新路径。


参考资料

  • Liao et al. (2025): Complete Gaussian Splats from a Single Image, Niantic Research
  • Niantic Research Blog

Footnotes

  1. Liao et al. (2025): Complete Gaussian Splats from a Single Image, Niantic Research