Fourier Neural Operator (FNO) 详解

1. 引言

Fourier Neural Operator (FNO) 是由Li等人于2020年提出的一种神经算子架构1。其核心创新在于利用傅里叶变换实现全局积分算子,从而高效学习函数空间之间的映射。FNO在计算流体力学、天气预报等领域展现了卓越的性能。


2. 核心思想

2.1 傅里叶变换与积分算子

考虑积分算子:

其中 是核函数。

关键观察:在傅里叶域中,卷积核的作用变得简单:

2.2 频域线性算子

表示傅里叶变换,则:

这意味着频域乘法等价于空域卷积


3. FNO架构

3.1 整体结构

输入函数 a(x) ∈ ℝⁿᵛˣᴺ
    ↓
特征提升 (Lift): a₀(x) = σ(W·a(x) + b)
    ↓
L层 Fourier层:
    ├─ 傅里叶变换
    ├─ 线性变换 (频域)
    ├─ 逆傅里叶变换
    └─ 局部激活
    ↓
投影 (Project): u(x) = W·a_L(x) + b
    ↓
输出函数 u(x) ∈ ℝⁿᵘˣᴺ

3.2 傅里叶层详解

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
 
class FNO2d(nn.Module):
    def __init__(
        self,
        modes1: int,        # 第一个维度的傅里叶模式数
        modes2: int,        # 第二个维度的傅里叶模式数
        width: int,         # 通道维度
        n_layers: int = 4   # Fourier层数量
    ):
        super().__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        
        # 输入特征提升:输入维度 -> width
        self.fc0 = nn.Linear(in_features=1, out_features=width)
        
        # Fourier层
        self.fourier_layers = nn.ModuleList([
            SpectralConv2d(modes1, modes2, width)
            for _ in range(n_layers)
        ])
        
        # 逐点变换
        self.pointwise = nn.ModuleList([
            nn.Sequential(
                nn.Linear(width, width),
                nn.GELU()
            )
            for _ in range(n_layers)
        ])
        
        # 输出投影
        self.fc1 = nn.Linear(width, 128)
        self.fc2 = nn.Linear(128, 1)
    
    def forward(self, x):
        # x: (batch, x_dim, y_dim, 1) - 输入函数
        x = self.fc0(x)  # (batch, x_dim, y_dim, width)
        
        for fourier_layer, pointwise in zip(
            self.fourier_layers, self.pointwise
        ):
            x = fourier_layer(x)  # 频域变换
            x = pointwise(x)      # 逐点非线性
        
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        
        return x
 
class SpectralConv2d(nn.Module):
    """
    频谱卷积层:二维FNO的核心组件
    """
    def __init__(self, modes1, modes2, width):
        super().__init__()
        self.modes1 = modes1  # 第一个维度保留的模式数
        self.modes2 = modes2  # 第二个维度保留的模式数
        self.width = width
        
        # 可学习的频谱权重
        # 复数权重:(2, modes1, modes2, width, width)
        # 2: 实部和虚部
        self.weights1 = nn.Parameter(
            torch.randn(2, modes1, modes2, width, width, dtype=torch.float32)
        )
        self.weights2 = nn.Parameter(
            torch.randn(2, modes1, modes2, width, width, dtype=torch.float32)
        )
    
    def compl_mul2d(self, input, weights):
        """复数矩阵乘法"""
        # input: (batch, x_modes, y_modes, width)
        # weights: (2, x_modes, y_modes, width, width) -> (x_modes, y_modes, width, width) complex
        weights_complex = torch.complex(weights[0], weights[1])
        
        # 复数乘法并求和
        return torch.einsum('bxyi,ijxy->bxyj', input, weights_complex)
    
    def forward(self, x):
        # x: (batch, x_dim, y_dim, width)
        batch_size = x.shape[0]
        
        # 傅里叶变换
        x_ft = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        
        # 频谱裁剪:只保留低频模式
        # x_ft: (batch, x_dim//2+1, y_dim//2+1, width)
        x_ft = x_ft[:, :self.modes1, :self.modes2, :]
        
        # 频域变换
        out_ft = self.compl_mul2d(x_ft, self.weights1)
        
        # 填充回原始大小
        x_ft_full = torch.zeros(
            batch_size, x.shape[1]//2+1, x.shape[2]//2+1, self.width,
            dtype=x.dtype, device=x.device
        )
        x_ft_full[:, :self.modes1, :self.modes2, :] = out_ft
        
        # 逆变换
        x = torch.fft.irfft2(x_ft_full, s=(x.shape[1], x.shape[2]), norm='ortho')
        
        return x

4. 数学推导

4.1 傅里叶层的前向传播

设输入特征 ,Fourier层输出:

其中:

  • :傅里叶变换
  • :可学习的频谱响应函数
  • :非线性激活

4.2 频谱参数化

全局频谱核

对于每个频率 ,有一个独立的权重矩阵。

低频近似:由于湍流等物理现象的能量主要集中在低频,保留前 个模式:

4.3 计算复杂度

操作复杂度
傅里叶变换
频谱乘法
逆变换
总前向传播

其中 是空间点数, 是层数, 是模式数。


5. 变体与扩展

5.1 AdaptFNO (NeurIPS ML4PS 2025)

核心创新:动态调整频谱模式数,基于输入频率内容自适应。

class AdaptFNO(nn.Module):
    def __init__(self, base_modes, width, n_layers):
        super().__init__()
        self.base_modes = base_modes
        
        # 频率选择器
        self.freq_selector = nn.ModuleList([
            FrequencySelector(base_modes)
            for _ in range(n_layers)
        ])
        
        # 自适应Fourier层
        self.fourier_layers = nn.ModuleList([
            AdaptiveSpectralConv(base_modes, width)
            for _ in range(n_layers)
        ])
    
    def forward(self, x):
        # 选择频率模式
        freq_weights = self.freq_selector(x)
        
        for selector, layer in zip(self.freq_weights, self.fourier_layers):
            x = layer(x, selector)
        
        return x
 
class FrequencySelector(nn.Module):
    """基于交叉注意力的频率选择"""
    def __init__(self, max_modes):
        super().__init__()
        self.query = nn.Linear(1, max_modes)
        self.key = nn.Linear(max_modes, max_modes)
        self.value = nn.Linear(max_modes, max_modes)
    
    def forward(self, x):
        # x: (batch, N, 1)
        # 计算频率重要性
        freq = torch.arange(x.shape[1], device=x.device).float() / x.shape[1]
        freq = freq.unsqueeze(0).unsqueeze(-1)  # (1, N, 1)
        
        q = self.query(freq)
        k = self.key(q)
        v = self.value(q)
        
        # 交叉注意力
        attn = torch.softmax(q @ k.transpose(-2, -1) / np.sqrt(self.max_modes), dim=-1)
        return attn @ v

5.2 多尺度FNO

处理多尺度物理现象:

class MultiScaleFNO(nn.Module):
    def __init__(self, modes_per_scale, width, n_scales):
        super().__init__()
        self.n_scales = n_scales
        
        # 每个尺度独立的FNO
        self.scale_fnns = nn.ModuleList([
            FNO2d(modes, width)
            for modes in modes_per_scale
        ])
        
        # 尺度融合
        self.fusion = nn.Sequential(
            nn.Linear(width * n_scales, width),
            nn.GELU(),
            nn.Linear(width, width)
        )
    
    def forward(self, x):
        scale_outputs = []
        for fnn in self.scale_fnns:
            scale_outputs.append(fnn(x))
        
        # 尺度融合
        fused = torch.cat(scale_outputs, dim=-1)
        return self.fusion(fused)

5.3 时间-空间FNO

处理时间依赖问题:

class TimeSpaceFNO(nn.Module):
    def __init__(self, modes_x, modes_t, width, n_layers):
        super().__init__()
        # 空间Fourier层
        self.spatial_layers = nn.ModuleList([
            SpectralConv2d(modes_x, modes_x, width)
            for _ in range(n_layers)
        ])
        
        # 时间Fourier层
        self.temporal_layers = nn.ModuleList([
            SpectralConv1d(modes_t, width)
            for _ in range(n_layers)
        ])
    
    def forward(self, x):
        # x: (batch, T, X, Y, channels)
        for i in range(self.n_layers):
            # 空间变换
            x = self.spatial_layers[i](x)
            
            # 时间变换
            x = x.permute(0, 2, 3, 4, 1)  # (B, X, Y, C, T)
            x = self.temporal_layers[i](x)
            x = x.permute(0, 4, 1, 2, 3)  # (B, T, X, Y, C)
        
        return x

6. 实践指南

6.1 模式数选择

问题类型推荐模式数理由
简单PDE16-32解光滑
湍流48-64高频分量重要
多尺度32-48需平衡
高分辨率32-64计算效率

6.2 训练技巧

# 学习率调度
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2
)
 
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
# 混合精度
scaler = torch.cuda.amp.GradScaler()

6.3 常见问题

问题原因解决方案
边界不连续周期假设过强使用非周期FNO变体
高频欠拟合模式数不足增加模式数或使用多尺度
训练不稳定学习率过高降低学习率或使用warmup

7. 应用案例

7.1 Darcy流

二维稳态Darcy方程:

任务:从渗透率场 预测速度场

# Darcy流实验配置
config = {
    'modes1': 20,
    'modes2': 20,
    'width': 32,
    'n_layers': 4,
    'grid_size': 128,
    'batch_size': 32,
    'epochs': 100
}

7.2 Navier-Stokes方程

二维湍流N-S方程:

任务:预测未来时刻的流场。

class NavierStokesFNO(TimeSpaceFNO):
    def __init__(self):
        super().__init__(
            modes_x=64, modes_t=20,
            width=64, n_layers=6
        )
        self.output_dim = 1  # 预测单个时间步
    
    def forward(self, u0, t_steps):
        # u0: 初始条件
        # t_steps: 目标时间
        x = self.encode(u0)
        
        for t in t_steps:
            x = self.step(x, t)
        
        return self.decode(x)

8. 性能基准

数据集方法RelL2误差
Darcy FlowFNO0.0147
Darcy FlowU-Net0.0312
Navier-Stokes (Vorticity)FNO0.1558
Navier-Stokes (Vorticity)ResNet0.2847

9. 参考文献


相关主题

Footnotes

  1. Li, Z., et al. (2020). Fourier neural operator for parametric partial differential equations. arXiv:2010.08895.