NeRF: Neural Radiance Fields

Overview

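NeRF (Neural Radiance Field), proposed by Mildenhall et al. at ECCV 2020 [1], is a method for 3D scene representation and novel view synthesis. It uses a neural network as a continuous function that assigns a color and a density to every point in space, and it renders photorealistic views from arbitrary viewpoints through volume rendering.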

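The core idea of NeRF can be written as:

$$F_\Theta : (\mathbf{x}, \mathbf{d}) \rightarrow (\mathbf{c}, \sigma)$$

where:

  • $\mathbf{x} = (x, y, z)$ is a position in 3D space
  • $\mathbf{d} = (\theta, \phi)$ is the viewing direction (spherical coordinates)
  • $\mathbf{c} = (r, g, b)$ is the color at that point
  • $\sigma$ is the volume density (how strongly the point blocks light)
  • $\Theta$ are the network parameters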

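As a prominent application of implicit neural representations, NeRF shows how well neural networks can represent continuous 3D signals. Unlike traditional approaches (such as multi-view geometry and voxel grids), NeRF learns a continuous field that can be queried at any position, including positions between the training observations, and produces highly realistic renderings.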

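Core Principles

5D Radiance Field Definition

NeRF represents a scene as a continuous 5D function:

$$F_\Theta(x, y, z, \theta, \phi) = (r, g, b, \sigma)$$

The volume density $\sigma$ depends only on the position $\mathbf{x}$; $\sigma(\mathbf{x})$ can be read as the differential probability that a ray terminates at an infinitesimal particle located at $\mathbf{x}$. The color $\mathbf{c}$ depends on both the position and the viewing direction, which lets NeRF express view-dependent appearance such as specular highlights.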
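Although the function is written with two angles, implementations usually pass the viewing direction as a 3D unit vector. The conversion can be sketched as follows; the function name `spherical_to_unit` and the angle convention (polar angle measured from +z, azimuth from +x) are assumptions made for illustration.

```python
import math

def spherical_to_unit(theta, phi):
    """Map a viewing direction (theta, phi) to a 3D unit vector.

    Assumed convention: theta is the polar angle from +z,
    phi is the azimuth measured from +x in the xy-plane.
    """
    return (
        math.sin(theta) * math.cos(phi),
        math.sin(theta) * math.sin(phi),
        math.cos(theta),
    )

# A direction lying in the xy-plane and pointing along +x.
d = spherical_to_unit(math.pi / 2, 0.0)
```

The result always has unit length, so it can be fed directly to a direction encoder that expects normalized input.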

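Volume Rendering Equation

Given a camera ray $\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}$, the pixel color is obtained by integrating along the ray:

$$C(\mathbf{r}) = \int_{t_n}^{t_f} T(t)\,\sigma(\mathbf{r}(t))\,\mathbf{c}(\mathbf{r}(t), \mathbf{d})\,dt$$

where:

  • $T(t) = \exp\left(-\int_{t_n}^{t} \sigma(\mathbf{r}(s))\,ds\right)$ is the accumulated transmittance
  • $t_n$ and $t_f$ are the near and far bounds of the ray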
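As a sanity check on the integral, consider a ray that crosses a medium with constant density $\sigma$ and constant color $\mathbf{c}$ between the near bound $t_n$ and the far bound $t_f$. This special case is chosen here for illustration; it gives closed forms for both the transmittance $T(t)$ and the pixel color $C(\mathbf{r})$:

```latex
T(t) = e^{-\sigma (t - t_n)}, \qquad
C(\mathbf{r}) = \mathbf{c} \int_{t_n}^{t_f} \sigma\, e^{-\sigma (t - t_n)}\, dt
             = \mathbf{c}\left(1 - e^{-\sigma (t_f - t_n)}\right)
```

With $\sigma \to 0$ the pixel receives no color from the medium, and as $\sigma (t_f - t_n)$ grows the pixel color approaches $\mathbf{c}$.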

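Discrete Approximation

The continuous integral has to be approximated numerically. Sampling $N$ points along the ray gives:

$$\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i \left(1 - \exp(-\sigma_i \delta_i)\right) \mathbf{c}_i$$

where:

  • $\delta_i = t_{i+1} - t_i$ is the distance between adjacent samples
  • $T_i = \exp\left(-\sum_{j=1}^{i-1} \sigma_j \delta_j\right)$ is the transmittance accumulated up to the $i$-th sample

Defining the alpha value

$$\alpha_i = 1 - \exp(-\sigma_i \delta_i)$$

the rendering equation reduces to alpha compositing:

$$\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i\,\alpha_i\,\mathbf{c}_i, \qquad T_i = \prod_{j=1}^{i-1} (1 - \alpha_j)$$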
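The compositing sum can be traced by hand on a tiny example. Below is a minimal pure-Python sketch for a single ray; the function name `composite` and the sample values are illustrative assumptions, not part of any NeRF codebase.

```python
import math

def composite(sigmas, colors, deltas):
    """Alpha-composite samples ordered front to back along one ray.

    sigmas: densities sigma_i; colors: RGB tuples c_i; deltas: spacings delta_i.
    Returns the pixel color and the per-sample weights T_i * alpha_i.
    """
    transmittance = 1.0  # T_1 = 1: nothing blocks the first sample
    pixel = [0.0, 0.0, 0.0]
    weights = []
    for sigma, color, delta in zip(sigmas, colors, deltas):
        alpha = 1.0 - math.exp(-sigma * delta)
        weight = transmittance * alpha
        weights.append(weight)
        for k in range(3):
            pixel[k] += weight * color[k]
        transmittance *= 1.0 - alpha  # T_{i+1} = T_i * (1 - alpha_i)
    return pixel, weights

# Empty space, then a nearly opaque red surface, then a blue one behind it.
pixel, weights = composite(
    sigmas=[0.0, 50.0, 50.0],
    colors=[(0, 1, 0), (1, 0, 0), (0, 0, 1)],
    deltas=[0.1, 0.1, 0.1],
)
```

The empty sample gets zero weight, the red surface takes about 99% of the weight, and the blue sample behind it is almost fully occluded, so the pixel comes out nearly pure red.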

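Positional Encoding

Why Is Positional Encoding Needed?

A standard MLP struggles to represent high-frequency detail such as textures and edges, because:

  1. Neural networks are biased toward smooth, low-frequency functions
  2. Gradient-based optimization tends to over-smooth the result

Mildenhall et al. found that mapping the input coordinates to a higher-dimensional space of high-frequency functions before passing them to the network largely resolves this problem.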

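Fourier Feature Encoding

Each coordinate $p$ of the 3D position is encoded as:

$$\gamma(p) = \left(\sin(2^0 \pi p), \cos(2^0 \pi p), \ldots, \sin(2^{L-1} \pi p), \cos(2^{L-1} \pi p)\right)$$

where $L$ is the number of frequency bands, typically $L = 10$ for positions.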
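A quick count of the output sizes helps when wiring the encoders to the network. The encoding is applied to each of the 3 coordinates separately, and each coordinate yields a sine and a cosine per frequency band. The values of $L$ below are the ones commonly used for positions and directions:

```latex
\dim \gamma(\mathbf{x}) = 3 \cdot 2L = 60 \quad (L = 10), \qquad
\dim \gamma(\mathbf{d}) = 3 \cdot 2L = 24 \quad (L = 4)
```

Implementations that also concatenate the raw 3D input to the encoding get 63 and 27 instead.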

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """NeRF positional encoding."""
    def __init__(self, in_dim, num_frequencies=10, include_input=True):
        super().__init__()
        self.in_dim = in_dim
        self.num_frequencies = num_frequencies
        self.include_input = include_input

        # Frequencies 2^0, 2^1, ..., 2^(L-1); a buffer moves with the module across devices
        freq_bands = 2.0 ** torch.arange(num_frequencies, dtype=torch.float32)
        self.register_buffer("freq_bands", freq_bands, persistent=False)

    def forward(self, x):
        """
        Args:
            x: (..., in_dim) input coordinates
        Returns:
            encoded: (..., in_dim * 2 * num_frequencies), plus in_dim if include_input
        """
        # Optionally keep the raw input
        encoded = [x] if self.include_input else []

        # Sine and cosine at every frequency
        for freq in self.freq_bands:
            encoded.append(torch.sin(freq * torch.pi * x))
            encoded.append(torch.cos(freq * torch.pi * x))

        return torch.cat(encoded, dim=-1)

    @property
    def output_dim(self):
        dim = self.in_dim * 2 * self.num_frequencies
        if self.include_input:
            dim += self.in_dim
        return dim

View Direction Encoding

The viewing direction is encoded in the same way, but with fewer frequency bands ($L = 4$):

class DirectionEncoding(nn.Module):
    """View direction encoding."""
    def __init__(self, num_frequencies=4):
        super().__init__()
        self.num_frequencies = num_frequencies
        # Frequencies 2^0, ..., 2^(L-1), stored as a buffer so they follow the module's device
        freq_bands = 2.0 ** torch.arange(num_frequencies, dtype=torch.float32)
        self.register_buffer("freq_bands", freq_bands, persistent=False)

    def forward(self, d):
        """
        Args:
            d: (..., 3) normalized viewing direction
        Returns:
            encoded: (..., 2 * 3 * num_frequencies)
        """
        encoded = []
        for freq in self.freq_bands:
            encoded.append(torch.sin(freq * torch.pi * d))
            encoded.append(torch.cos(freq * torch.pi * d))
        return torch.cat(encoded, dim=-1)

    @property
    def output_dim(self):
        return 2 * 3 * self.num_frequencies

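Network Architecture

Standard NeRF Architecture

NeRF trains two networks:

  1. Coarse network: evaluated at stratified samples to predict a rough density distribution
  2. Fine network: evaluated at additional samples drawn by importance sampling from the coarse network's weights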

import torch
import torch.nn as nn

class NeRF(nn.Module):
    """NeRF MLP (simplified; the paper uses 8 hidden layers with a skip connection)."""
    def __init__(self, pos_dim=63, dir_dim=24, hidden_dim=256):
        """
        Args:
            pos_dim: encoded position size (63 = 3 + 3 * 2 * 10, raw input included)
            dir_dim: encoded direction size (24 = 3 * 2 * 4)
            hidden_dim: hidden layer width
        """
        super().__init__()

        # Position trunk: encoded position -> hidden features
        self.fc1 = nn.Linear(pos_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim + 1)  # feature vector + sigma

        # View-dependent color branch
        self.fc6 = nn.Linear(hidden_dim + dir_dim, hidden_dim // 2)
        self.fc7 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
        self.fc8 = nn.Linear(hidden_dim // 4, 3)  # RGB color

        self.relu = nn.ReLU()

    def forward(self, pos_enc, dir_enc):
        """
        Args:
            pos_enc: (N, pos_dim) encoded positions
            dir_enc: (N, dir_dim) encoded directions
        Returns:
            rgb: (N, 3) color
            sigma: (N,) volume density
        """
        # Position trunk
        x = self.relu(self.fc1(pos_enc))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.relu(self.fc4(x))
        x = self.fc5(x)

        # Split the feature vector from the density
        feature = x[:, :-1]  # (N, hidden_dim)
        sigma = x[:, -1]     # (N,)
        sigma = torch.relu(sigma)  # density must be non-negative

        # Color depends on the feature and the viewing direction
        x = torch.cat([feature, dir_enc], dim=-1)
        x = self.relu(self.fc6(x))
        x = self.relu(self.fc7(x))
        rgb = torch.sigmoid(self.fc8(x))  # squash colors into [0, 1]

        return rgb, sigma

Full Model Wrapper

class NeRFModel(nn.Module):
    """Complete NeRF model holding the coarse and fine networks."""
    def __init__(self, pos_dim=63, dir_dim=24, hidden_dim=256):
        super().__init__()
        self.coarse = NeRF(pos_dim, dir_dim, hidden_dim)
        self.fine = NeRF(pos_dim, dir_dim, hidden_dim)

    def forward(self, pos_enc, dir_enc):
        # Evaluates both networks on the same inputs. During rendering the two
        # networks are queried at different sample points via model.coarse / model.fine.
        rgb_c, sigma_c = self.coarse(pos_enc, dir_enc)
        rgb_f, sigma_f = self.fine(pos_enc, dir_enc)
        return (rgb_c, sigma_c), (rgb_f, sigma_f)

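Training

Hierarchical Volume Sampling

To estimate the integral efficiently, NeRF uses a hierarchical sampling strategy:

  1. Coarse sampling: draw $N_c$ stratified (jittered uniform) samples along the ray
  2. Weight computation: compute per-sample weights from the coarse density predictions
  3. Fine sampling: draw $N_f$ additional samples from those weights by inverse transform sampling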
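
def sample_along_ray(rng, near, far, num_samples):
    """Stratified sampling between near and far.

    Args:
        rng: any tensor; only its device is used
        near, far: ray bounds (floats)
        num_samples: number of samples
    Returns:
        t_vals: (num_samples,) sorted distances, shared by all rays
    """
    t_vals = torch.linspace(near, far, num_samples, device=rng.device)
    # Build one bin around each evenly spaced point, then jitter inside the bin
    mids = 0.5 * (t_vals[1:] + t_vals[:-1])
    upper = torch.cat([mids, t_vals[-1:]])
    lower = torch.cat([t_vals[:1], mids])
    t_rand = torch.rand_like(upper)
    return lower + (upper - lower) * t_rand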
 
def sample_pdf(weights, bins, num_samples):
    """Inverse transform (importance) sampling from a piecewise-constant PDF.

    Args:
        weights: (N, M) non-negative weight of each bin
        bins: (N, M + 1) bin edges along each ray, sorted ascending
        num_samples: number of samples to draw per ray
    Returns:
        t_fine: (N, num_samples) sampled distances
    """
    weights = weights + 1e-5  # avoid an all-zero PDF
    pdf = weights / weights.sum(dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # (N, M + 1)

    # Uniform samples in [0, 1)
    u = torch.rand(*cdf.shape[:-1], num_samples, device=cdf.device)

    # Invert the CDF: find the bin that contains each u
    indices = torch.searchsorted(cdf, u, right=True)
    below = (indices - 1).clamp(min=0)
    above = indices.clamp(max=cdf.shape[-1] - 1)

    cdf_lo, cdf_hi = torch.gather(cdf, -1, below), torch.gather(cdf, -1, above)
    t_lo, t_hi = torch.gather(bins, -1, below), torch.gather(bins, -1, above)

    # Linear interpolation inside the bin; guard against empty bins
    denom = cdf_hi - cdf_lo
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    frac = ((u - cdf_lo) / denom).clamp(0, 1)
    return t_lo + (t_hi - t_lo) * frac

Rendering Functions

def render_rays(model, rays_o, rays_d, near, far, num_coarse=64, num_fine=64,
                pos_enc=None, dir_enc=None, return_coarse=False):
    """
    Render rays with coarse-to-fine sampling.

    Args:
        model: NeRFModel exposing `coarse` and `fine` networks
        rays_o: (N, 3) ray origins
        rays_d: (N, 3) ray directions
        near, far: near and far bounds (floats)
        num_coarse, num_fine: samples per ray in each pass
        pos_enc, dir_enc: encoders; default to PositionalEncoding(3, 10) and DirectionEncoding(4)
        return_coarse: also return the coarse network's color (needed for the loss)
    Returns:
        rgb_map: (N, 3) rendered color
        depth_map: (N,) expected depth
    """
    if pos_enc is None:
        pos_enc = PositionalEncoding(3, 10).to(rays_o.device)
    if dir_enc is None:
        dir_enc = DirectionEncoding(4).to(rays_o.device)
    n = rays_o.shape[0]
    view_dirs = rays_d / rays_d.norm(dim=-1, keepdim=True)

    def query(net, t_vals):
        """Evaluate one network at the points r(t) = o + t * d."""
        pts = rays_o[:, None, :] + rays_d[:, None, :] * t_vals[..., None]  # (N, S, 3)
        dirs = view_dirs[:, None, :].expand_as(pts)
        rgb, sigma = net(pos_enc(pts.reshape(-1, 3)), dir_enc(dirs.reshape(-1, 3)))
        return rgb.reshape(n, -1, 3), sigma.reshape(n, -1)

    # Coarse pass: stratified samples, shared by all rays
    t_coarse = sample_along_ray(rays_o, near, far, num_coarse).expand(n, num_coarse)
    rgb_c, sigma_c = query(model.coarse, t_coarse)
    weights_c = compute_weights(sigma_c, t_coarse, rays_d)

    # Fine pass: importance-sample between coarse midpoints, then merge and sort
    t_mid = 0.5 * (t_coarse[..., 1:] + t_coarse[..., :-1])
    t_fine = sample_pdf(weights_c[..., 1:-1].detach(), t_mid, num_fine)
    t_all, _ = torch.sort(torch.cat([t_coarse, t_fine], dim=-1), dim=-1)
    rgb_f, sigma_f = query(model.fine, t_all)

    rgb_map, depth_map = volumetric_render(rgb_f, sigma_f, t_all, rays_d)
    if return_coarse:
        rgb_coarse = (weights_c[..., None] * rgb_c).sum(dim=-2)
        return rgb_map, depth_map, rgb_coarse
    return rgb_map, depth_map

def compute_weights(sigma, t_vals, rays_d):
    """Compositing weights w_i = T_i * alpha_i for sigma (N, S) at sorted t_vals (N, S)."""
    delta = t_vals[..., 1:] - t_vals[..., :-1]
    # The last interval is treated as unbounded
    delta = torch.cat([delta, torch.full_like(delta[..., :1], 1e10)], dim=-1)
    # Scale by |d| so that delta is a distance in world units
    delta = delta * rays_d.norm(dim=-1, keepdim=True)
    alpha = 1.0 - torch.exp(-sigma * delta)
    # T_i = prod_{j < i} (1 - alpha_j)
    T = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), 1.0 - alpha[..., :-1]], dim=-1), dim=-1)
    return T * alpha

def volumetric_render(rgb, sigma, t_vals, rays_d):
    """Volume rendering: composite colors (N, S, 3) into pixel color and depth."""
    weights = compute_weights(sigma, t_vals, rays_d)
    rgb_map = (weights[..., None] * rgb).sum(dim=-2)
    depth_map = (weights * t_vals).sum(dim=-1)
    return rgb_map, depth_map

Loss Function

Training uses a simple mean squared error loss on the rendered colors and optimizes the coarse and fine networks jointly:

class NeRFLoss(nn.Module):
    """NeRF loss function."""
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, rgb_coarse, rgb_fine, target_rgb):
        """
        Args:
            rgb_coarse: (N, 3) colors rendered by the coarse network
            rgb_fine: (N, 3) colors rendered by the fine network
            target_rgb: (N, 3) ground-truth colors
        """
        loss_coarse = self.mse(rgb_coarse, target_rgb)
        loss_fine = self.mse(rgb_fine, target_rgb)
        return loss_coarse + loss_fine

Full Training Loop

def train_nerf(model, images, poses, intrinsics, optimizer, device,
               num_iterations=100000, near=0.0, far=1.0, rays_per_batch=1024):
    """
    Train NeRF.

    Args:
        model: NeRFModel (coarse and fine networks)
        images: (V, H, W, 3) input images with values in [0, 1]
        poses: (V, 4, 4) camera-to-world matrices
        intrinsics: (V, 3, 3) camera intrinsics
        optimizer: optimizer over model.parameters()
        device: compute device
        near, far: ray bounds; the defaults mirror the call in novel_view_synthesis
        rays_per_batch: rays sampled per iteration
    """
    model.train()
    loss_fn = NeRFLoss()
    h, w = images.shape[1:3]

    for iteration in range(num_iterations):
        # Pick one training image at random
        img_idx = torch.randint(0, len(images), (1,)).item()

        # Generate one ray per pixel
        rays_o, rays_d = get_rays(h, w, intrinsics[img_idx], poses[img_idx])

        # Sample a random batch of pixels and their rays
        pixel_indices = torch.randint(0, h * w, (rays_per_batch,))
        rays_o = rays_o.reshape(-1, 3)[pixel_indices].to(device)
        rays_d = rays_d.reshape(-1, 3)[pixel_indices].to(device)
        target_rgb = images[img_idx].reshape(-1, 3)[pixel_indices].to(device)

        # Volume-render the batch; sample points are encoded per point, not per ray
        rgb_f, _, rgb_c = render_rays(model, rays_o, rays_d, near, far, return_coarse=True)

        # Loss on both the coarse and the fine rendering
        loss = loss_fn(rgb_c, rgb_f, target_rgb)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodic logging
        if iteration % 10000 == 0:
            print(f"Iteration {iteration}, Loss: {loss.item():.6f}")

    return model
 
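def get_rays(h, w, intrinsics, pose):
    """Generate one ray per pixel for a (3, 3) intrinsics matrix and a (4, 4) camera-to-world pose."""
    device = pose.device
    i, j = torch.meshgrid(torch.arange(w, device=device),
                          torch.arange(h, device=device), indexing='xy')
    i = i.float()
    j = j.float()

    # Pixel coordinates -> camera-space directions (camera looks down -z, y points up)
    dirs = torch.stack([(i - intrinsics[0, 2]) / intrinsics[0, 0],
                       -(j - intrinsics[1, 2]) / intrinsics[1, 1],
                       -torch.ones_like(i)], dim=-1)

    # Rotate directions into world space; every ray starts at the camera center
    rays_d = torch.sum(dirs[..., None, :] * pose[:3, :3], dim=-1)
    rays_o = pose[:3, 3].expand(rays_d.shape)

    return rays_o, rays_d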

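Improvements and Variants

Mip-NeRF: Anti-Aliasing

Mip-NeRF [2] traces cones instead of rays and achieves anti-aliasing with an integrated positional encoding (IPE), the expected positional encoding over a Gaussian that approximates each conical frustum:

$$\gamma(\boldsymbol{\mu}, \boldsymbol{\Sigma}) = \mathbb{E}_{\mathbf{x} \sim \mathcal{N}(\boldsymbol{\mu}_\gamma, \boldsymbol{\Sigma}_\gamma)}\left[\gamma(\mathbf{x})\right] = \begin{bmatrix} \sin(\boldsymbol{\mu}_\gamma) \circ \exp\left(-\tfrac{1}{2}\operatorname{diag}(\boldsymbol{\Sigma}_\gamma)\right) \\ \cos(\boldsymbol{\mu}_\gamma) \circ \exp\left(-\tfrac{1}{2}\operatorname{diag}(\boldsymbol{\Sigma}_\gamma)\right) \end{bmatrix}$$

where $\boldsymbol{\mu}_\gamma$ and $\boldsymbol{\Sigma}_\gamma$ are the mean and covariance of the Gaussian after scaling by the encoding frequencies.

IPE lets the MLP reason about a region of the scene with finite extent rather than an infinitesimal point, which addresses aliasing.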
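The effect of integrating the encoding over a region can be seen for a single scalar coordinate. The sketch below relies on the identity E[sin(a x)] = sin(a μ) exp(-a² σ² / 2) for x ~ N(μ, σ²), and likewise for cosine. The function name `integrated_pos_enc` is hypothetical, and the frequencies a = 2^k π are an assumption chosen to match the positional encoding in this article.

```python
import math

def integrated_pos_enc(mean, var, num_freqs):
    """Expected sin/cos features of x ~ N(mean, var) for one scalar coordinate.

    Each frequency a = 2^k * pi contributes
    sin(a * mean) * exp(-0.5 * a^2 * var) and the matching cosine term.
    """
    features = []
    for k in range(num_freqs):
        a = (2.0 ** k) * math.pi
        damping = math.exp(-0.5 * a * a * var)  # shrinks high frequencies first
        features.append(math.sin(a * mean) * damping)
        features.append(math.cos(a * mean) * damping)
    return features

point = integrated_pos_enc(0.3, 0.0, 4)    # zero variance: the plain encoding
region = integrated_pos_enc(0.3, 0.05, 4)  # wide region: high frequencies fade out
```

With zero variance the result equals the ordinary positional encoding. As the variance grows, the highest-frequency features shrink toward zero first, so the network is not shown detail finer than the region it is asked about.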

Instant NGP: Multiresolution Hash Encoding

Instant NGP [3] uses multiresolution hash tables to speed up training and reach real-time rendering:

class MultiResolutionHash(nn.Module):
    """Multiresolution hash encoding (illustrative PyTorch version, not the fused CUDA kernels)."""
    # Per-axis primes of the spatial hash used in the Instant NGP paper
    PRIMES = (1, 2654435761, 805459861)

    def __init__(self, level_dim=2, num_levels=16, hash_size=2**19):
        super().__init__()
        self.num_levels = num_levels
        self.level_dim = level_dim
        self.hash_size = hash_size

        # Learnable hash tables, one per level
        self.tables = nn.ParameterList([
            nn.Parameter(torch.randn(hash_size, level_dim) * 0.01)
            for _ in range(num_levels)
        ])

        # Grid resolution per level (doubling at each level)
        self.resolutions = [16 * (2 ** i) for i in range(num_levels)]

        # Integer offsets of the 8 corners of a grid cell
        offsets = torch.tensor([[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)])
        self.register_buffer("offsets", offsets, persistent=False)

    def forward(self, x):
        """
        Args:
            x: (N, 3) coordinates normalized to [0, 1]
        Returns:
            features: (N, num_levels * level_dim)
        """
        features = []
        for level in range(self.num_levels):
            # Scale coordinates to this level's grid
            x_scaled = x * self.resolutions[level]
            x0 = torch.floor(x_scaled).long()   # (N, 3) lower cell corner
            x_frac = x_scaled - x0.float()      # (N, 3) position inside the cell

            # Hash the 8 cell corners and look up their features
            corners = x0[:, None, :] + self.offsets        # (N, 8, 3)
            corner_feats = self.tables[level][self._hash(corners)]  # (N, 8, level_dim)

            # Trilinear weights: x_frac on the far side of each axis, 1 - x_frac on the near side
            w = torch.where(self.offsets.bool(), x_frac[:, None, :], 1.0 - x_frac[:, None, :])
            w = w.prod(dim=-1)                             # (N, 8)

            features.append((w[..., None] * corner_feats).sum(dim=1))

        return torch.cat(features, dim=-1)

    def _hash(self, coords):
        """Spatial hash: XOR of the per-axis integer coordinates times large primes."""
        h = torch.zeros_like(coords[..., 0])
        for axis, prime in enumerate(self.PRIMES):
            h = h ^ (coords[..., axis] * prime)
        return h % self.hash_size

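Nerfacto: Production-Grade Reconstruction

Nerfacto combines several techniques:

  • Multiresolution hash encoding (from Instant NGP)
  • Adaptive scene bounds
  • NeRF conditioned on appearance codes
  • Depth supervision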

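Applications

Novel View Synthesis

The most direct application of NeRF is rendering photorealistic views from arbitrary viewpoints, given a limited set of posed input images: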

def novel_view_synthesis(model, target_pose, intrinsics, h, w, device, chunk=4096):
    """Render a new view from the camera pose target_pose."""
    model.eval()
    with torch.no_grad():
        # Generate the rays of the target view
        rays_o, rays_d = get_rays(h, w, intrinsics, target_pose)
        rays_o = rays_o.reshape(-1, 3).to(device)
        rays_d = rays_d.reshape(-1, 3).to(device)

        # Render in chunks of rays to bound peak memory
        rgb_chunks, depth_chunks = [], []
        for start in range(0, rays_o.shape[0], chunk):
            rgb, depth = render_rays(model, rays_o[start:start + chunk],
                                     rays_d[start:start + chunk], 0.0, 1.0)
            rgb_chunks.append(rgb)
            depth_chunks.append(depth)

        rgb = torch.cat(rgb_chunks, dim=0)
        depth = torch.cat(depth_chunks, dim=0)
        return rgb.reshape(h, w, 3), depth.reshape(h, w)

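Scene Editing and Manipulation

The density field makes several kinds of edits possible:

  1. Object removal: set the density of a region to zero
  2. Object insertion: copy the density field from another scene
  3. Scene deformation: apply a transformation to the query positions before they are encoded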
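
def remove_object_from_nerf(model, mask, threshold=0.5):
    """Wrap a NeRF network so that density is zeroed inside a masked region.

    Args:
        model: a NeRF network whose forward returns (rgb, sigma)
        mask: (N,) per-point object scores aligned with the query points
        threshold: points with mask > threshold are removed
    """
    original_forward = model.forward

    def edited_forward(pos_enc, dir_enc, mask=mask):
        rgb, sigma = original_forward(pos_enc, dir_enc)
        # Zero the density wherever the mask marks the object
        sigma = torch.where(mask > threshold, torch.zeros_like(sigma), sigma)
        return rgb, sigma

    return edited_forward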

Semantic Segmentation

NeRF can be extended to multi-task learning so that it also outputs semantic labels:

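class SemanticNeRF(nn.Module):
    """NeRF with an additional semantic head."""
    def __init__(self, num_classes, pos_dim=63, dir_dim=24, hidden_dim=256):
        super().__init__()
        self.nerf = NeRF(pos_dim, dir_dim, hidden_dim)
        self.semantic_head = nn.Sequential(
            nn.Linear(hidden_dim + pos_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )
        self.num_classes = num_classes

    def _trunk_feature(self, pos_enc):
        """Position-only feature: run the NeRF trunk (fc1..fc5) and drop the density channel."""
        nerf = self.nerf
        x = pos_enc
        for layer in (nerf.fc1, nerf.fc2, nerf.fc3, nerf.fc4):
            x = nerf.relu(layer(x))
        return nerf.fc5(x)[:, :-1]

    def forward(self, pos_enc, dir_enc):
        # NeRF feature, color and density
        feature = self._trunk_feature(pos_enc)
        rgb, sigma = self.nerf(pos_enc, dir_enc)

        # Per-class probabilities from the semantic head
        logits = self.semantic_head(torch.cat([feature, pos_enc], dim=-1))
        semantic = torch.softmax(logits, dim=-1)

        return rgb, sigma, semantic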

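Dynamic Scene Reconstruction

Methods such as D-NeRF and Neural Scene Flow handle scenes with motion by querying the field at a time-dependent displaced position:

$$(\mathbf{c}, \sigma) = F_\Theta\left(\mathbf{x} + \Delta\mathbf{x}(\mathbf{x}, t),\ \mathbf{d}\right)$$

where $\Delta\mathbf{x}(\mathbf{x}, t)$ is the displacement field at time $t$.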
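The warp-then-query structure can be illustrated without any trained network. In the sketch below, `deformation` and `canonical_field` are hypothetical stand-ins for the two learned components, replaced by hand-written toy functions: a small ball that drifts along +x over time.

```python
import math

def query_dynamic(x, d, t, deformation, canonical_field):
    """Query a dynamic scene by warping x into a canonical frame.

    deformation(x, t) -> displacement (dx, dy, dz)
    canonical_field(x, d) -> (rgb, sigma)
    Both callables are stand-ins for trained networks.
    """
    dx = deformation(x, t)
    x_canonical = tuple(xi + dxi for xi, dxi in zip(x, dx))
    return canonical_field(x_canonical, d)

def toy_deformation(x, t):
    # Undo a drift of 0.5 units per unit time along +x.
    return (-0.5 * t, 0.0, 0.0)

def toy_field(x, d):
    # A small red ball of radius 0.2 centered at the origin of the canonical frame.
    inside = math.dist(x, (0.0, 0.0, 0.0)) < 0.2
    return ((1.0, 0.0, 0.0), 10.0 if inside else 0.0)

view_dir = (0.0, 0.0, -1.0)
rgb_t1, sigma_t1 = query_dynamic((0.5, 0.0, 0.0), view_dir, 1.0, toy_deformation, toy_field)
rgb_t0, sigma_t0 = query_dynamic((0.5, 0.0, 0.0), view_dir, 0.0, toy_deformation, toy_field)
```

The same world-space point is occupied at t = 1 (the ball has drifted there) and empty at t = 0, even though the canonical field itself never changes.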

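Limitations

Computational Cost

  • Rendering a single frame takes millions of network queries (one per sample along every pixel's ray)
  • Training converges slowly (hours to days)
  • Real-time rendering is difficult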
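The query count follows directly from the image size and the number of samples per ray. The helper below is a back-of-the-envelope sketch; its name and the example resolution are assumptions. It counts the fine network as seeing both the coarse and the fine samples, as in the original paper.

```python
def queries_per_frame(height, width, num_coarse, num_fine):
    """Network evaluations needed to render one frame with hierarchical sampling.

    The coarse network sees num_coarse samples per ray; the fine network
    sees the union of the coarse and fine samples.
    """
    per_ray = num_coarse + (num_coarse + num_fine)
    return height * width * per_ray

# An 800 x 800 image with 64 coarse and 128 fine samples per ray.
n = queries_per_frame(800, 800, 64, 128)
```

For this configuration the count is 163,840,000 evaluations for a single frame, which is why the original formulation is far from real time.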

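Geometric Ambiguity

NeRF struggles with:

  • Specular reflections (view-dependent effects are hard to separate from material properties)
  • Semi-transparent objects
  • Geometric ambiguity under sparse views

Generalization

  • Each scene must be trained independently
  • Generalizing to new scenes requires additional techniques (such as PixelNeRF)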

References


Related links: implicit-neural-representations | siren-networks

Footnotes

  1. Mildenhall, B., Srinivasan, P. P., Tancik, M., Barron, J. T., Ramamoorthi, R., & Ng, R. (2020). NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis. European Conference on Computer Vision (ECCV).

  2. Barron, J. T., Mildenhall, B., Tancik, M., Hedman, P., Martin-Brualla, R., & Srinivasan, P. P. (2021). Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields. IEEE/CVF International Conference on Computer Vision (ICCV).

  3. Müller, T., Evans, A., Schied, C., & Keller, A. (2022). Instant Neural Graphics Primitives with a Multiresolution Hash Encoding. ACM Transactions on Graphics (SIGGRAPH).