SIREN:正弦激活网络
论文概述
SIREN(Sinusoidal Representation Networks)由Sitzmann等人于NeurIPS 2020提出[^1],是隐式神经表示领域的重要里程碑。与使用ReLU、Tanh等传统激活函数的MLP不同,SIREN使用正弦函数作为激活函数,在表示自然信号(图像、音频、视频、3D场景)方面展现出卓越的能力。
核心贡献
- 周期激活函数:利用正弦激活函数的周期性实现高质量隐式表示
- 导数保持性质:正弦函数的导数(余弦)是相移后的正弦函数,因此SIREN的导数仍是一个SIREN,具有与原函数相同的表达能力
- 统一框架:适用于图像、音频、Poisson方程、视频等多样化信号
- 初始化策略:基于激活值分布分析的初始化方法,确保各层激活分布不随网络深度变化
正弦激活函数
数学定义
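SIREN的核心是正弦激活函数,第 $i$ 层为:
$$\phi_i(\mathbf{x}_i) = \sin(\mathbf{W}_i\mathbf{x}_i + \mathbf{b}_i)$$
其中 $\mathbf{W}_i$ 是权重矩阵,$\mathbf{b}_i$ 是偏置向量。对于整个 $n$ 层网络:
$$\Phi(\mathbf{x}) = \mathbf{W}_n\left(\phi_{n-1}\circ\phi_{n-2}\circ\cdots\circ\phi_0\right)(\mathbf{x}) + \mathbf{b}_n$$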
导数保持性质
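正弦函数求导后仍是正弦函数(相差一个常数因子和 $\pi/2$ 的相移):
$$\frac{d}{dx}\sin(\omega x + b) = \omega\cos(\omega x + b) = \omega\sin\left(\omega x + b + \frac{\pi}{2}\right)$$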
这意味着:
- SIREN的梯度(导数)仍由相移的正弦函数构成,其本身也是一个SIREN
- 可以用另一个SIREN表示原函数的导数
- 适合需要导数信息的物理场景
与其他激活函数的对比
| 激活函数 | 表达式 | 导数形式 | 梯度性质 |
|---|---|---|---|
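| ReLU | $\max(0, x)$ | $\mathbb{1}[x > 0]$(指示函数) | 分段常数,二阶导数几乎处处为 0 |
| Tanh | $\tanh(x)$ | $1 - \tanh^2(x)$ | 有界,远离原点时衰减 |
| GELU | $x\,\Phi(x)$ | $\Phi(x) + x\,\varphi(x)$(形式复杂) | 近似 0/1 |
| Sin | $\sin(x)$ | $\cos(x) = \sin(x + \pi/2)$ | 仍是正弦 |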
网络架构
基础SIREN架构
import torch
import torch.nn as nn
class SineActivation(nn.Module):
    """Element-wise sine activation, ``sin(w0 * x)``.

    Args:
        w0: Frequency factor multiplied into the input before the sine. The
            default ``1.0`` gives plain ``sin(x)``. The SIREN paper scales the
            pre-activations by ``omega_0 = 30`` so the network covers a wider
            range of frequencies.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        # sin(1.0 * x) == sin(x); skip the extra multiply in the default case.
        if self.w0 == 1.0:
            return torch.sin(x)
        return torch.sin(self.w0 * x)

    def extra_repr(self):
        return f"w0={self.w0}"
class Siren(nn.Module):
    """Basic SIREN: a stack of ``nn.Linear`` layers with sine activations.

    Every linear layer except the last is followed by ``self.activation``;
    the last one is a plain linear read-out. ``num_layers`` counts all linear
    layers (input, hidden and output); values below 2 still yield one input
    and one output layer.

    Weights keep PyTorch's default ``nn.Linear`` initialization here.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        # Consecutive widths: in -> hidden (repeated) -> out.
        n_hidden_widths = max(num_layers - 1, 1)
        widths = [in_dim, *([hidden_dim] * n_hidden_widths), out_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        self.activation = SineActivation()

    def forward(self, x):
        *body, readout = self.layers
        for linear in body:
            x = self.activation(linear(x))
        return readout(x)


# 跳跃连接SIREN (Skip-connection SIREN)
对于复杂信号,加入跳跃连接:
class SkipConnectionSiren(nn.Module):
    """SIREN whose hidden layers all receive a linear projection of the input.

    ``skip_layer`` projects the raw input to ``hidden_dim`` once per forward
    pass, and that projection is added to the pre-activation of every hidden
    layer. ``num_layers`` counts first + hidden + final linear layers, so there
    are ``num_layers - 2`` hidden layers (none when ``num_layers <= 2``, in
    which case the projection is computed but unused).
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        self.first_layer = nn.Linear(in_dim, hidden_dim)
        self.skip_layer = nn.Linear(in_dim, hidden_dim)  # input -> hidden shortcut
        hidden = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]
        self.hidden_layers = nn.ModuleList(hidden)
        self.final_layer = nn.Linear(hidden_dim, out_dim)
        self.activation = SineActivation()

    def forward(self, x):
        shortcut = self.skip_layer(x)
        h = self.activation(self.first_layer(x))
        for block in self.hidden_layers:
            # Fuse the shortcut into the pre-activation, before the sine.
            pre = block(h) + shortcut
            h = self.activation(pre)
        return self.final_layer(h)


# 多输出SIREN (Multi-output SIREN)
class MultiOutputSiren(nn.Module):
    """SIREN trunk with two linear heads (e.g. density + features, NeRF-style).

    ``num_layers - 1`` shared sine layers feed both ``density_head``
    (``out_dim_spatial`` outputs) and ``feature_head`` (``out_dim_extra``
    outputs). ``forward`` returns the tuple ``(density, features)``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim_spatial, out_dim_extra, num_layers):
        super().__init__()
        trunk = []
        for idx in range(num_layers - 1):
            # Only the first shared layer sees the raw input.
            fan_in = hidden_dim if idx else in_dim
            trunk.append(nn.Linear(fan_in, hidden_dim))
        self.shared_layers = nn.ModuleList(trunk)
        self.density_head = nn.Linear(hidden_dim, out_dim_spatial)
        self.feature_head = nn.Linear(hidden_dim, out_dim_extra)
        self.activation = SineActivation()

    def forward(self, x):
        h = x
        for linear in self.shared_layers:
            h = self.activation(linear(h))
        return self.density_head(h), self.feature_head(h)


# 初始化策略 (Initialization strategy)
SIREN的成功很大程度上依赖于精心设计的初始化。核心思想是:使每层激活值的分布保持一致,不随网络深度增加而坍缩或爆炸。
理论分析
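考虑单层 $y = \sin(\mathbf{w}^\top\mathbf{x} + b)$,其中 $\mathbf{x}\in\mathbb{R}^n$,权重 $w_i \sim \mathcal{U}(-c/\sqrt{n},\ c/\sqrt{n})$。
均匀分布初始化(常用于Tanh网络,即Xavier/Glorot均匀初始化):$w_i \sim \mathcal{U}\left(-\sqrt{6/(n_\text{in}+n_\text{out})},\ \sqrt{6/(n_\text{in}+n_\text{out})}\right)$。
对于前一层正弦的输出 $x_i$(在 $[-1, 1]$ 上近似服从 arcsine 分布),有 $\mathrm{Var}(x_i) = 1/2$。
如果 $w_i$ 服从均匀分布 $\mathcal{U}(-\sqrt{6/n},\ \sqrt{6/n})$(即 $c=\sqrt{6}$),则 $\mathrm{Var}(w_i) = 2/n$,$\mathbf{w}^\top\mathbf{x}$ 的方差约为 $n\cdot\frac{2}{n}\cdot\frac{1}{2} = 1$,即正弦的输入近似服从标准正态分布,输出又近似回到 $[-1, 1]$ 区间上的 arcsine 分布,从而逐层保持稳定。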
SIREN初始化
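SIREN论文提出:利用 $\sin$ 的导数性质及其输入/输出分布来指导初始化。论文中第一层使用 $\sin(\omega_0 \mathbf{W}\mathbf{x} + \mathbf{b})$,其中 $\omega_0 = 30$,以覆盖更宽的频率范围;下面的实现默认不使用 $\omega_0$(相当于 $\omega_0 = 1$)。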
def siren_init_(layer, mode='first', w0=1.0):
    """Initialize ``layer`` in place following the SIREN scheme.

    With ``n = layer.in_features``:
        first layer:  W ~ U(-1/n, 1/n)
        other layers: W ~ U(-sqrt(6/n)/w0, sqrt(6/n)/w0)
    The bias, if the layer has one, is drawn from U(-1, 1).

    Args:
        layer: An ``nn.Linear`` layer.
        mode: ``'first'`` for the input layer; any other value (conventionally
            ``'others'``) selects the hidden/output-layer rule.
        w0: Frequency factor of the sine that consumes this layer's output.
            Hidden weights are divided by it (the paper pairs ``w0 = 30`` with
            ``sin(w0 * z)``); the default ``1.0`` suits plain ``sin(z)``.
    """
    fan_in = layer.in_features
    with torch.no_grad():
        if mode == 'first':
            # U(-1/n, 1/n); replaces nn.Linear's default U(-1/sqrt(n), 1/sqrt(n)).
            bound = 1.0 / fan_in
        else:
            # sqrt(6/n) keeps the next sine's pre-activation ~N(0, 1) when its
            # inputs are sine outputs with variance 1/2 (SIREN paper, Sec. 3.2).
            # Plain float math: torch.sqrt does not accept a Python int.
            bound = (6.0 / fan_in) ** 0.5 / w0
        nn.init.uniform_(layer.weight, -bound, bound)
        # Layers built with bias=False have ``layer.bias is None``.
        if layer.bias is not None:
            nn.init.uniform_(layer.bias, -1, 1)
# Apply the SIREN initialization to a whole model.
def init_siren_weights(model):
    """Run ``siren_init_`` over every ``nn.Linear`` in ``model.layers``.

    The first linear layer is initialized with ``mode='first'`` and every
    later one with ``mode='others'``. ``model`` must expose a ``layers``
    iterable.
    """
    linears = [m for m in model.layers if isinstance(m, nn.Linear)]
    for position, linear in enumerate(linears):
        siren_init_(linear, mode='others' if position else 'first')


# 初始化效果对比 (Initialization comparison)
| 初始化方法 | 第一层权重范围 | 其他层权重范围 | 效果 |
|---|---|---|---|
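| Xavier (Tanh) | $\mathcal{U}(\pm\sqrt{6/(n_\text{in}+n_\text{out})})$ | 同第一层 | 低频主导 |
| SIREN (第一层) | $\mathcal{U}(-1/n,\ 1/n)$ | - | 宽频率响应 |
| SIREN (其他层) | - | $\mathcal{U}(-\sqrt{6/n},\ \sqrt{6/n})$ | 保持激活分布 |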
傅里叶特征增强
问题:高频表示困难
标准SIREN在初始化良好时能覆盖一定频率范围,但对极高频细节仍显不足。
解决方案:傅里叶特征
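在输入端使用傅里叶特征编码(Tancik et al., 2020):
$$\gamma(\mathbf{x}) = \left[\sin(2\pi\,\mathbf{x}\mathbf{B}),\ \cos(2\pi\,\mathbf{x}\mathbf{B})\right]$$
其中 $\mathbf{B}$ 的元素从高斯分布 $\mathcal{N}(0, \sigma^2)$ 中随机采样,$\sigma$ 控制编码覆盖的频率范围。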
class FourierFeatures(nn.Module):
    """Random Fourier feature encoding (Tancik et al., 2020).

    Maps ``x`` of shape ``(..., in_dim)`` to
    ``[x, sin(2*pi*x@B), cos(2*pi*x@B)]`` (the leading ``x`` only when
    ``include_original``), where ``B`` has shape ``(in_dim, num_features // 2)``
    with entries drawn from ``N(0, scale**2)``.

    Args:
        in_dim: Size of the last input dimension.
        num_features: Total number of sin + cos features. An odd value is
            rounded down to an even count; ``out_dim`` reports the true width.
        include_original: Prepend the raw input to the encoding.
        scale: Standard deviation of the random frequencies. Larger values
            emphasise higher frequencies; the default ``1.0`` is plain
            ``torch.randn`` sampling.
    """

    def __init__(self, in_dim, num_features, include_original=True, scale=1.0):
        super().__init__()
        self.include_original = include_original
        num_freqs = num_features // 2
        # Fixed (non-trainable) frequencies; a buffer so they follow .to(device)
        # and are stored in the state dict.
        B = torch.randn(in_dim, num_freqs) * scale
        self.register_buffer('B', B)
        # Actual width of the tensor returned by forward().
        self.out_dim = 2 * num_freqs + (in_dim if include_original else 0)

    def forward(self, x):
        # Project onto the random frequencies.
        x_proj = 2 * torch.pi * x @ self.B
        # Sine/cosine encoding.
        features = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        if self.include_original:
            return torch.cat([x, features], dim=-1)
        return features
class HybridSiren(nn.Module):
    """SIREN fed by a random Fourier feature encoding of its input.

    Args:
        in_dim: Input coordinate dimension.
        hidden_dim: Width of the sine layers.
        out_dim: Output dimension.
        num_fourier: Number of Fourier features requested from
            ``FourierFeatures``, which emits ``2 * (num_fourier // 2)`` of them.
        num_layers: Total number of linear layers (first, hidden and output).
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_fourier, num_layers):
        super().__init__()
        self.fourier = FourierFeatures(in_dim, num_fourier)
        # FourierFeatures yields num_fourier // 2 sin/cos pairs; sizing the first
        # layer with num_fourier itself would break for odd values.
        encoded = 2 * (num_fourier // 2)
        first_dim = in_dim + encoded if self.fourier.include_original else encoded
        self.layers = nn.ModuleList([
            nn.Linear(first_dim, hidden_dim),
            *[nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)],
            nn.Linear(hidden_dim, out_dim),
        ])

    def forward(self, x):
        x = self.fourier(x)
        for layer in self.layers[:-1]:
            x = torch.sin(layer(x))
        return self.layers[-1](x)


# 可学习频率 (Learnable frequencies)
class LearnableFourierFeatures(nn.Module):
    """Sine encoding with a learnable per-dimension scale and learnable phases.

    For ``x`` of shape ``(..., in_dim)`` the output is
    ``cat([x, sin(x * s + phases).flatten(-2)])``, where
    ``s = exp(log_freq_scale)`` has one entry per input dimension (initialized
    to ``init_scale``) and ``phases`` has shape ``(in_dim, num_features // 2)``.
    All bands of one input dimension share the same scale ``s`` and differ
    only in phase. Output width: ``in_dim + in_dim * (num_features // 2)``.
    """

    def __init__(self, in_dim, num_features, init_scale=1.0):
        super().__init__()
        log_init = torch.log(torch.tensor(init_scale))
        # Stored in log space so the effective scale stays positive.
        self.log_freq_scale = nn.Parameter(log_init * torch.ones(in_dim))
        n_bands = num_features // 2
        self.phases = nn.Parameter(torch.rand(in_dim, n_bands) * 2 * torch.pi)

    def forward(self, x):
        scaled = x * self.log_freq_scale.exp()
        # (..., in_dim, 1) + (in_dim, n_bands) -> (..., in_dim, n_bands)
        angles = scaled[..., None] + self.phases
        bands = angles.sin().flatten(start_dim=-2)
        return torch.cat((x, bands), dim=-1)


# 应用示例:图像表示 (Example: image representation)
基本用法
import torchvision
from torchvision.transforms import ToTensor
from PIL import Image

# Load the image as an (H, W, C) float tensor in [0, 1].
img = Image.open('example.jpg')
img_tensor = ToTensor()(img).permute(1, 2, 0)  # (H, W, C)
h, w, c = img_tensor.shape

# Coordinate grid: pixel (row i, col j) -> (j / (w - 1), i / (h - 1)), because
# linspace(0, 1, n) puts its endpoints on the first and last pixel.
y_coords = torch.linspace(0, 1, h)
x_coords = torch.linspace(0, 1, w)
grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing='ij')
coords = torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 2)

# Build the SIREN.
model = Siren(in_dim=2, hidden_dim=256, out_dim=c, num_layers=5)

# Train: regress each pixel's colour from its coordinate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
target = img_tensor.reshape(-1, c)
for step in range(5000):
    optimizer.zero_grad()
    pred = model(coords)
    loss = nn.functional.mse_loss(pred, target)
    loss.backward()
    optimizer.step()
    if step % 500 == 0:
        print(f"Step {step}, Loss: {loss.item():.6f}")

# Query a sub-pixel location (x = column, y = row, in pixel units).
# Normalise exactly like the training grid: divide by (n - 1), not n;
# max(..., 1) guards against 1-pixel-wide/high images.
x_new, y_new = 150.7, 200.3
with torch.no_grad():
    query = torch.tensor([[x_new / max(w - 1, 1), y_new / max(h - 1, 1)]])
    color = model(query)


# 完整训练循环 (Full training loop)
def train_inr_for_image(model, img_tensor, num_steps=5000, lr=1e-4, log_every=500):
    """Fit an implicit neural representation to a single image.

    Args:
        model: Module mapping ``(N, 2)`` coordinates ``(x, y)`` in ``[0, 1]``
            to ``(N, C)`` colours.
        img_tensor: Image of shape ``(H, W, C)``.
        num_steps: Number of full-batch Adam steps.
        lr: Initial learning rate, cosine-annealed over ``num_steps``.
        log_every: Print progress every ``log_every`` steps; ``0`` or ``None``
            disables printing.

    Returns:
        List of per-step MSE losses as Python floats.
    """
    h, w, c = img_tensor.shape
    # Build the grid on the image's device (and float dtype, if any) so GPU or
    # float64 models/images work without a device/dtype mismatch.
    device = img_tensor.device
    dtype = img_tensor.dtype if img_tensor.is_floating_point() else None
    y_coords = torch.linspace(0, 1, h, device=device, dtype=dtype)
    x_coords = torch.linspace(0, 1, w, device=device, dtype=dtype)
    grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing='ij')
    coords = torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 2)
    target = img_tensor.reshape(-1, c)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)

    losses = []
    for step in range(num_steps):
        optimizer.zero_grad()
        pred = model(coords)
        loss = nn.functional.mse_loss(pred, target)
        loss.backward()
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
        if log_every and step % log_every == 0:
            print(f"Step {step}, Loss: {loss.item():.6f}, LR: {scheduler.get_last_lr()[0]:.6f}")
    return losses


# 应用示例:Poisson方程求解 (Example: solving the Poisson equation)
SIREN的导数保持性质使其特别适合物理问题。
问题定义
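求解Poisson方程:
$$\nabla^2 u(\mathbf{x}) = f(\mathbf{x}),\quad \mathbf{x}\in\Omega;\qquad u(\mathbf{x}) = g(\mathbf{x}),\quad \mathbf{x}\in\partial\Omega$$
其中 $f$ 是已知源项,$u$ 是待求解的势函数,$g$ 为边界上给定的函数值。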
SIREN求解器
class PoissonSolver(nn.Module):
    """SIREN-based solver for the Poisson equation ``laplacian(u) = f``.

    Wraps a ``Siren`` with ``in_dim=2`` and ``out_dim=1``. ``compute_laplacian``
    itself works for any number of input coordinates.
    """

    def __init__(self, hidden_dim=128, num_layers=4):
        super().__init__()
        self.siren = Siren(in_dim=2, hidden_dim=hidden_dim,
                           out_dim=1, num_layers=num_layers)

    def forward(self, x):
        return self.siren(x)

    def compute_laplacian(self, x):
        """Return the Laplacian of ``u = self(x)`` at each point, shape ``(N,)``.

        Args:
            x: Points of shape ``(N, D)``.

        Derivatives are taken with ``create_graph=True`` so the result can be
        back-propagated through (needed to train on the PDE residual).
        """
        # Use a grad-enabled view instead of flipping requires_grad on the
        # caller's tensor; inputs that already require grad keep their graph.
        if not x.requires_grad:
            x = x.detach().requires_grad_(True)
        u = self.forward(x)
        # First derivatives: grad_u[:, i] = du/dx_i.
        grad_u = torch.autograd.grad(
            u, x, grad_outputs=torch.ones_like(u), create_graph=True
        )[0]
        # Laplacian = sum_i d2u/dx_i2, over every input coordinate.
        laplacian = torch.zeros_like(grad_u[:, 0])
        for i in range(x.shape[-1]):
            second = torch.autograd.grad(
                grad_u[:, i], x, grad_outputs=torch.ones_like(grad_u[:, i]),
                create_graph=True
            )[0][:, i]
            laplacian = laplacian + second
        return laplacian

    def physics_loss(self, x, f):
        """MSE of the Poisson residual ``laplacian(u) - f``.

        ``f`` may be shaped ``(N,)`` or ``(N, 1)``. It is reshaped to the
        ``(N,)`` Laplacian so ``mse_loss`` cannot silently broadcast to ``(N, N)``.
        """
        laplacian = self.compute_laplacian(x)
        return nn.functional.mse_loss(laplacian, f.reshape(laplacian.shape))

    def boundary_loss(self, x_boundary, u_boundary):
        """MSE between predicted and prescribed boundary values.

        ``u_boundary`` may be shaped ``(M,)`` or ``(M, 1)``. It is reshaped to
        the prediction's shape to avoid the same broadcasting pitfall.
        """
        pred = self.forward(x_boundary)
        return nn.functional.mse_loss(pred, u_boundary.reshape(pred.shape))
def _sample_square_boundary(num_points, low, high):
    """Draw ``num_points`` points uniformly on the boundary of ``[low, high]^2``."""
    points = torch.rand(num_points, 2) * (high - low) + low
    # Pin one randomly chosen coordinate of each point to either low or high.
    pinned_axis = torch.randint(0, 2, (num_points,))
    at_high = torch.randint(0, 2, (num_points,)).to(points.dtype)
    points[torch.arange(num_points), pinned_axis] = low + at_high * (high - low)
    return points


def solve_poisson_with_siren(f_func, boundary_func, domain_bounds, num_points=1000,
                             num_boundary_points=100, num_steps=5000, lr=1e-3):
    """Solve ``laplacian(u) = f`` on the square ``[lo, hi]^2`` with a SIREN.

    Args:
        f_func: Callable mapping ``(N, 2)`` points to source values ``f``.
        boundary_func: Callable mapping ``(M, 2)`` boundary points to the
            prescribed values of ``u``.
        domain_bounds: ``(lo, hi)``; the domain is ``[lo, hi]`` on both axes.
        num_points: Number of interior collocation points.
        num_boundary_points: Number of points sampled on the domain boundary.
        num_steps: Number of Adam steps.
        lr: Adam learning rate.

    Returns:
        The trained ``PoissonSolver``.
    """
    lo, hi = domain_bounds[0], domain_bounds[1]
    # Interior collocation points, uniform in the square.
    x_interior = torch.rand(num_points, 2) * (hi - lo) + lo
    f_values = f_func(x_interior)
    # Boundary-condition points must lie on the boundary of this same domain.
    x_boundary = _sample_square_boundary(num_boundary_points, lo, hi)
    u_boundary = boundary_func(x_boundary)

    model = PoissonSolver()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for step in range(num_steps):
        optimizer.zero_grad()
        physics = model.physics_loss(x_interior, f_values)  # PDE residual
        boundary = model.boundary_loss(x_boundary, u_boundary)  # Dirichlet condition
        loss = physics + boundary
        loss.backward()
        optimizer.step()
        if step % 500 == 0:
            print(f"Step {step}, Physics: {physics.item():.6f}, Boundary: {boundary.item():.6f}")
    return model


# 与其他隐式表示方法的对比 (Comparison with other implicit representations)
| 方法 | 激活函数 | 高频能力 | 导数保持 | 训练稳定性 |
|---|---|---|---|---|
| ReLU-MLP | ReLU | 差 | 否 | 好 |
| Tanh-MLP | Tanh | 中 | 否 | 中 |
| GELU-MLP | GELU | 中 | 否 | 中 |
| SIREN | Sin | 好 | 是 | 中(需特殊初始化) |
| Gaussians | 高斯 | 好 | 否 | 好 |
代码资源
官方实现
GitHub仓库:vsitzmann/siren
核心代码片段
# Core logic of the official SIREN weight initialization (vsitzmann/siren).
def siren_init(model, w0=30.0):
    """Initialize every ``nn.Linear`` in ``model`` with the SIREN scheme.

    The first linear layer (in ``model.modules()`` order) gets
    ``W ~ U(-1/n, 1/n)``; every other linear layer gets
    ``W ~ U(-sqrt(6/n)/w0, sqrt(6/n)/w0)``, with ``n`` the layer's own
    ``in_features``. Biases, when present, are drawn from ``U(-1, 1)``.

    Args:
        model: Any ``nn.Module``; no particular attribute names are required.
        w0: Frequency factor applied by the hidden sine layers, ``sin(w0 * z)``.
            The official implementation uses 30; pass ``1.0`` for networks
            that apply plain ``sin(z)``.
    """
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    with torch.no_grad():
        for index, layer in enumerate(linears):
            fan_in = layer.in_features
            if index == 0:
                bound = 1.0 / fan_in
            else:
                bound = (6.0 / fan_in) ** 0.5 / w0
            nn.init.uniform_(layer.weight, -bound, bound)
            if layer.bias is not None:
                nn.init.uniform_(layer.bias, -1, 1)


# 参考文献 (References)
相关链接:implicit-neural-representations | nerf-neural-radiance-field | neural-odes
Footnotes
-
Sitzmann, V., Martel, J., Bergman, A., Lindell, D., & Wetzstein, G. (2020). Implicit Neural Representations with Periodic Activation Functions. Advances in Neural Information Processing Systems, 33. ↩