DeepONet 深度解析

1. 引言

DeepONet (Deep Operator Network) 由 Lu 等人于2019年以预印本形式提出(2021年正式发表)[1],是一种学习非线性算子的深度学习框架。与傅里叶神经算子(Fourier Neural Operator, FNO)不同,DeepONet采用分支-主干(Branch-Trunk)分解,能够学习任意函数空间之间的映射,而不依赖于特定的核结构。


2. 核心思想

2.1 算子表示定理

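定理(Universal Operator Approximation):对于任意连续算子 $G: V \to C(K)$(其中 $V$ 为输入函数 $a$ 所在的紧集,$K$ 为输出位置 $y$ 所在的紧域)以及任意 $\varepsilon > 0$,存在正整数 $p$、传感器点 $x_1, \dots, x_m$ 以及分支网络 $b$ 和主干网络 $t$,使得对所有 $a \in V$、$y \in K$ 有:

$$\left| G(a)(y) - \sum_{i=1}^{p} b_i\big(a(x_1), \dots, a(x_m)\big)\, t_i(y) \right| < \varepsilon$$

其中:

  • $b_i(a)$:分支网络输出的特征
  • $t_i(y)$:主干网络输出的基函数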

2.2 分支-主干分解

输入函数 a(x)         输出位置 y
    ↓                     ↓
分支网络              主干网络
    ↓                     ↓
特征向量 B ∈ ℝᵖ    基函数 T ∈ ℝᵖ
    ↓                     ↓
        \                 /
         ↓               ↓
          内积: Σᵢ Bᵢ·Tᵢ(y)
                    ↓
               输出值 u(y)

3. 数学框架

3.1 问题形式化

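给定训练数据 $\{(a^{(k)}, u^{(k)})\}_{k=1}^{N}$,其中:

  • $a^{(k)} \in \mathcal{A}$:输入函数
  • $u^{(k)} = G(a^{(k)}) \in \mathcal{U}$:输出函数

目标:学习 $G_\theta \approx G$ 使得:

$$G_\theta(a^{(k)})(y) \approx u^{(k)}(y), \quad k = 1, \dots, N$$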

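3.2 网络输出

$$G_\theta(a)(y) = \sum_{i=1}^{p} b_i(a)\, t_i(y) + b_0(y)$$

其中 $b_i(a)$ 为分支网络输出,$t_i(y)$ 为主干网络输出,$b_0(y)$ 为可选的偏置项。

3.3 损失函数

$$\mathcal{L}(\theta) = \frac{1}{N M} \sum_{k=1}^{N} \sum_{j=1}^{M} \left| G_\theta(a^{(k)})(y_j) - u^{(k)}(y_j) \right|^2$$

其中 $y_1, \dots, y_M$ 为输出位置的采样点。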


4. DeepONet架构

4.1 标准架构

import torch
import torch.nn as nn
import numpy as np
 
class DeepONet(nn.Module):
    """Branch-trunk operator network.

    The branch MLP encodes the sampled input function into ``p`` coefficients,
    the trunk MLP encodes each query location into ``p`` basis values, and the
    prediction is their inner product plus an optional location-dependent bias.

    Args:
        branch_layers: Hidden layer widths of the branch MLP.
        trunk_layers: Hidden layer widths of the trunk MLP.
        p: Number of shared latent features (output width of both MLPs).
        activation: One of ``'relu'``, ``'tanh'``, ``'gelu'``, ``'silu'``.
            Unknown names fall back to ReLU.
        branch_input_dim: Number of sensor values per input function
            (``n_input`` in ``forward``). Defaults to 1, the previously
            hard-coded width.
        trunk_input_dim: Dimension of a query coordinate. Defaults to 1, the
            previously hard-coded width.
        use_bias: Whether to build the bias sub-network. Defaults to True,
            matching the previous always-on behavior.
    """

    def __init__(
        self,
        branch_layers: list,
        trunk_layers: list,
        p: int = 128,
        activation: str = 'relu',
        branch_input_dim: int = 1,
        trunk_input_dim: int = 1,
        use_bias: bool = True,
    ):
        super().__init__()
        self.p = p

        # Stored as a class (not an instance) so every layer gets its own module.
        self.activation = self._get_activation(activation)

        # Branch network: consumes the sampled input function.
        branch_dims = [branch_input_dim] + list(branch_layers) + [p]
        self.branch_net = self._build_mlp(branch_dims)

        # Trunk network: consumes the query coordinates.
        trunk_dims = [trunk_input_dim] + list(trunk_layers) + [p]
        self.trunk_net = self._build_mlp(trunk_dims)

        # Optional bias network evaluated at the query coordinates.
        self.bias_net = nn.Sequential(
            nn.Linear(trunk_input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        ) if use_bias else None

    def _get_activation(self, name):
        """Return the activation class for ``name`` (ReLU if unknown)."""
        activations = {
            'relu': nn.ReLU,
            'tanh': nn.Tanh,
            'gelu': nn.GELU,
            'silu': nn.SiLU
        }
        return activations.get(name, nn.ReLU)

    def _build_mlp(self, dims):
        """Build an MLP with widths ``dims``; no activation after the last layer."""
        layers = []
        last = len(dims) - 2
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < last:
                layers.append(self.activation())
        return nn.Sequential(*layers)

    def forward(self, a, y):
        """Evaluate the learned operator.

        Args:
            a: (batch, branch_input_dim) input function values at the sensors.
            y: (batch, n_output, trunk_input_dim) query coordinates.

        Returns:
            (batch, n_output) predicted output function values.
        """
        b = self.branch_net(a)  # (batch, p)
        t = self.trunk_net(y)   # (batch, n_output, p)

        # Inner product over the feature axis p.
        u = torch.einsum('bp,bop->bo', b, t)  # (batch, n_output)

        if self.bias_net is not None:
            bias = self.bias_net(y).squeeze(-1)  # (batch, n_output)
            u = u + bias

        return u

4.2 堆叠DeepONet

增加网络容量:

class StackedDeepONet(nn.Module):
    """Chain several branch-trunk modules with residual layer normalization.

    The first stack receives the input function ``a``; each later stack
    receives the normalized running state produced by the previous ones.

    NOTE(review): ``DeepONetBranchTrunk`` is not defined in the visible
    source. This class assumes it is an ``nn.Module`` called as
    ``module(a, y)`` whose output has last dimension ``p`` (required by
    ``nn.LayerNorm(p)``) and can be fed back in as the next ``a`` — confirm.

    Args:
        n_stacks: Number of stacked modules; must be at least 1.
        branch_dims: Forwarded to every ``DeepONetBranchTrunk``.
        trunk_dims: Forwarded to every ``DeepONetBranchTrunk``.
        p: Feature width, also the normalized dimension.

    Raises:
        ValueError: If ``n_stacks`` is smaller than 1.
    """

    def __init__(self, n_stacks, branch_dims, trunk_dims, p):
        super().__init__()

        if n_stacks < 1:
            raise ValueError(f"n_stacks must be >= 1, got {n_stacks}")

        self.stacks = nn.ModuleList([
            DeepONetBranchTrunk(branch_dims, trunk_dims, p)
            for _ in range(n_stacks)
        ])

        # One normalization between each pair of consecutive stacks.
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(p)
            for _ in range(n_stacks - 1)
        ])

    def forward(self, a, y):
        """Run all stacks in sequence and return the last stack's raw output."""
        x = None
        out = None
        last = len(self.stacks) - 1
        for i, stack in enumerate(self.stacks):
            out = stack(a if x is None else x, y)
            if i < last:
                # There is no running state before the first stack, so the
                # residual connection only starts from the second one.
                residual = out if x is None else x + out
                x = self.layer_norms[i](residual)
        return out

5. 变体与扩展

5.1 随机特征DeepONet

使用随机基函数提高效率:

class RandomFeatureDeepONet(nn.Module):
    """DeepONet variant built on fixed random Fourier features.

    The frequencies ``omega`` and phases ``bias`` are sampled once and stored
    as buffers (not trained). Only the branch weights and the output layer
    are learnable.

    Args:
        p: Number of random features.
        sigma: Scale applied to the normally distributed frequencies.
    """

    def __init__(self, p, sigma=1.0):
        super().__init__()
        self.p = p

        # Fixed random Fourier frequencies and phases.
        self.register_buffer('omega', torch.randn(1, p) * sigma)
        self.register_buffer('bias', torch.rand(1, p) * 2 * np.pi)

        # Learnable branch weights.
        self.branch_weights = nn.Linear(1, p, bias=False)

        # Learnable weighted sum over the p feature products.
        self.output_layer = nn.Linear(p, 1, bias=True)

    def forward(self, a, y):
        """Evaluate the model.

        Args:
            a: (batch, n_input) input function samples. Only their mean over
                ``n_input`` is used by the branch.
            y: (batch, n_output, 1) query coordinates.

        Returns:
            (batch, n_output) predictions.
        """
        # Branch: summarize the input function, then modulate random features.
        a_feat = torch.mean(a, dim=1, keepdim=True)  # (batch, 1)
        b = self.branch_weights(a_feat) * torch.cos(
            self.omega * a_feat + self.bias
        )  # (batch, p)

        # Trunk: random Fourier features of the query coordinates.
        t = torch.cos(
            torch.matmul(y, self.omega) + self.bias
        )  # (batch, n_output, p)

        # output_layer expects p input features, so it must see the
        # per-feature products rather than their already-reduced sum. For
        # p == 1 this is identical to reducing first.
        features = b.unsqueeze(1) * t  # (batch, n_output, p)
        return self.output_layer(features).squeeze(-1)

5.2 物理信息DeepONet (PI-DeepONet)

结合物理约束:

class PhysicsInformedDeepONet(nn.Module):
    """DeepONet wrapper that adds a PDE residual term.

    Args:
        branch_dims: Forwarded to ``DeepONet`` as its branch layer widths.
        trunk_dims: Forwarded to ``DeepONet`` as its trunk layer widths.
        p: Feature width of the wrapped ``DeepONet``.
        pde_fn: Callable ``pde_fn(u, u_y)`` returning the PDE residual, where
            ``u`` is the prediction and ``u_y`` the gradient of ``u.sum()``
            with respect to the query coordinates.
    """

    def __init__(self, branch_dims, trunk_dims, p, pde_fn):
        super().__init__()
        self.deeponet = DeepONet(branch_dims, trunk_dims, p)
        self.pde_fn = pde_fn

    def _predict_with_residual(self, a, y):
        """Return ``(u, residual)`` from a single differentiable forward pass."""
        if not y.requires_grad:
            # Differentiate w.r.t. a detached alias so the caller's tensor is
            # not switched to requires_grad as a side effect.
            y = y.detach().requires_grad_(True)
        # enable_grad keeps the residual computable under an outer no_grad().
        with torch.enable_grad():
            u = self.deeponet(a, y)
            u_y = torch.autograd.grad(u.sum(), y, create_graph=True)[0]
            residual = self.pde_fn(u, u_y)
        return u, residual

    def forward(self, a, y, enforce_pde=True):
        """Predict ``u``; in training mode optionally also the PDE residual.

        Returns:
            ``(u, residual)`` when ``enforce_pde`` is true and the module is
            in training mode, otherwise just ``u``.
        """
        if enforce_pde and self.training:
            return self._predict_with_residual(a, y)
        return self.deeponet(a, y)

    def loss(self, a, y, u_true, a_pde=None, y_pde=None, lambda_pde=1.0):
        """MSE data loss plus, if collocation inputs are given, the PDE loss.

        The residual is computed regardless of train/eval mode; routing it
        through ``forward`` would return a single tensor in eval mode and
        break the tuple unpacking.
        """
        u_pred = self.deeponet(a, y)
        loss_data = nn.MSELoss()(u_pred, u_true)

        if a_pde is not None and y_pde is not None:
            _, residual = self._predict_with_residual(a_pde, y_pde)
            loss_pde = torch.mean(residual**2)
            return loss_data + lambda_pde * loss_pde

        return loss_data

5.3 多算子DeepONet

学习多个算子:

class MultiOperatorDeepONet(nn.Module):
    """Learn several operators with one shared trunk and per-operator branches.

    NOTE(review): ``DeepONetTrunk`` and ``DeepONetBranch`` are not defined in
    the visible source. The einsum in ``forward`` requires the branch output
    to be (batch, p) and the trunk output to be (batch, n_output, p).

    Args:
        n_operators: Number of operators (one branch network each).
        branch_dims: Forwarded to every ``DeepONetBranch``.
        trunk_dims: Forwarded to ``DeepONetTrunk``.
        p: Shared feature width.
    """

    def __init__(self, n_operators, branch_dims, trunk_dims, p):
        super().__init__()
        self.n_operators = n_operators

        # Trunk shared by all operators.
        self.shared_trunk = DeepONetTrunk(trunk_dims, p)

        # One independent branch per operator.
        self.operator_branches = nn.ModuleList([
            DeepONetBranch(branch_dims, p)
            for _ in range(n_operators)
        ])

    def forward(self, a_list, y, operator_idx=None):
        """Evaluate one or all operators at the query locations ``y``.

        Args:
            a_list: Input functions, indexed by operator.
            y: Query coordinates passed to the shared trunk.
            operator_idx: Operator to evaluate; ``None`` evaluates all.

        Returns:
            (batch, n_output) for a single operator, otherwise
            (batch, n_output, n_operators).

        Raises:
            IndexError: If ``operator_idx`` is not a valid operator index.
        """
        n_ops = len(self.operator_branches)
        if operator_idx is not None and not 0 <= operator_idx < n_ops:
            raise IndexError(
                f"operator_idx {operator_idx} out of range for {n_ops} operators"
            )

        # The trunk depends only on y, so evaluate it once for all operators.
        t = self.shared_trunk(y)

        if operator_idx is not None:
            b = self.operator_branches[operator_idx](a_list[operator_idx])
            return torch.einsum('bp,bop->bo', b, t)

        outputs = [
            torch.einsum('bp,bop->bo', branch(a_list[i]), t)
            for i, branch in enumerate(self.operator_branches)
        ]
        return torch.stack(outputs, dim=-1)  # (batch, n_output, n_operators)

5.4 集成DeepONet

提高预测稳定性:

class EnsembleDeepONet(nn.Module):
    """Average the predictions of several independently built DeepONets.

    Args:
        n_models: Number of ensemble members.
        branch_dims: Forwarded to every ``DeepONet``.
        trunk_dims: Forwarded to every ``DeepONet``.
        p: Feature width of every member.
    """

    def __init__(self, n_models, branch_dims, trunk_dims, p):
        super().__init__()
        self.n_models = n_models

        self.models = nn.ModuleList([
            DeepONet(branch_dims, trunk_dims, p)
            for _ in range(n_models)
        ])

    def forward(self, a, y, return_uncertainty=False):
        """Return the ensemble mean, optionally with the member spread.

        This is an inference-only path: every member is switched to eval
        mode and run under ``torch.no_grad()``, so the result carries no
        gradient and the members stay in eval mode afterwards.

        Args:
            a: Input function samples passed to every member.
            y: Query coordinates passed to every member.
            return_uncertainty: If true, also return the standard deviation
                across members.

        Returns:
            ``mean`` or ``(mean, std)``, each shaped like one member's output.
        """
        predictions = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                predictions.append(model(a, y))

        preds = torch.stack(predictions, dim=0)  # (n_models, batch, n_output)
        mean = preds.mean(dim=0)

        if return_uncertainty:
            if preds.shape[0] < 2:
                # The sample standard deviation of a single member is
                # undefined (NaN); report zero spread instead.
                std = torch.zeros_like(mean)
            else:
                std = preds.std(dim=0)
            return mean, std

        return mean

6. 理论分析

6.1 逼近误差

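定理(DeepONet逼近误差)[2]:设 $G$ 是紧域 $K$ 上的连续算子,$G_\theta$ 是具有 $p$ 个特征的DeepONet,则:

$$\| G - G_\theta \| \le \varepsilon_{\text{approx}} + \varepsilon_{\text{opt}}$$

其中:

  • $\varepsilon_{\text{approx}}$:网络逼近误差
  • $\varepsilon_{\text{opt}}$:优化误差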

6.2 复杂度分析

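| 组件 | 参数数量 | 计算复杂度(单个样本) |
| --- | --- | --- |
| 分支网络 | $O(mh + Lh^2 + hp)$ | $O(mh + Lh^2 + hp)$ |
| 主干网络 | $O(dh + Lh^2 + hp)$ | $O\big(n(dh + Lh^2 + hp)\big)$ |
| 总前向传播 | - | $O\big(mh + n\,dh + (n+1)(Lh^2 + hp)\big)$ |

记号:$m$ 为传感器点数,$d$ 为输出坐标维度,$n$ 为输出位置数,$h$ 为隐藏层宽度,$L$ 为隐藏层数,$p$ 为特征维度。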

6.3 表达能力的几何解释

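DeepONet的输出可以视为输出函数在主干基函数 $\{t_i\}_{i=1}^{p}$ 下的展开,展开系数由分支网络给出:

$$G_\theta(a)(y) = \sum_{i=1}^{p} b_i(a)\, t_i(y) = \langle b(a),\, t(y) \rangle$$

这等价于将输入函数 $a$ 映射到特征空间 $\mathbb{R}^p$,然后在该空间与输出位置 $y$ 的表示 $t(y)$ 做内积。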


7. 实践技巧

7.1 传感器点采样

DeepONet需要输入函数在特定点的值:

def sample_sensor_points(n_points, domain='interval'):
    """Return sensor locations in the unit interval or unit square.

    Args:
        n_points: Requested number of points.
        domain: ``'interval'`` for an evenly spaced 1-D grid on [0, 1],
            ``'square'`` for a regular 2-D grid on [0, 1]^2, or ``'random'``
            for uniformly random 2-D points.

    Returns:
        A tensor of shape (n_points, 1) for ``'interval'`` and (n_points, 2)
        for ``'random'``. For ``'square'`` the grid has
        ``int(sqrt(n_points))`` points per side, so fewer than ``n_points``
        rows are returned when ``n_points`` is not a perfect square.

    Raises:
        ValueError: If ``domain`` is not one of the supported names.
    """
    if domain == 'interval':
        return torch.linspace(0, 1, n_points).unsqueeze(1)

    if domain == 'square':
        side = int(np.sqrt(n_points))
        axis = torch.linspace(0, 1, side)
        xx, yy = torch.meshgrid(axis, axis, indexing='ij')
        return torch.stack([xx.flatten(), yy.flatten()], dim=1)

    if domain == 'random':
        return torch.rand(n_points, 2)

    # Previously an unknown domain fell through and returned None.
    raise ValueError(
        f"unknown domain {domain!r}; expected 'interval', 'square' or 'random'"
    )

7.2 训练策略

def train_deeponet(
    model, train_loader, val_loader,
    epochs=500, lr=1e-3, lambda_pde=0.0
):
    """Train ``model`` with AdamW and a cosine-annealed learning rate.

    Each batch from ``train_loader`` is unpacked as ``(a, y, u_true)``. When
    ``lambda_pde`` is positive the model is called with ``enforce_pde=True``
    and must return ``(u_pred, residual)``; the mean squared residual is then
    added to the MSE data loss with weight ``lambda_pde``. Gradients are
    clipped to a total norm of 1.0 before every optimizer step.

    Every 50 epochs (starting at epoch 0) the validation loss and the mean
    training loss are printed.

    NOTE(review): ``validate`` is not defined in the visible source; it is
    called as ``validate(model, val_loader)`` and its result is formatted as
    a float — confirm it exists elsewhere.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs
    )
    mse = nn.MSELoss()
    use_pde = lambda_pde > 0

    for epoch in range(epochs):
        model.train()
        train_loss = 0

        for a, y, u_true in train_loader:
            optimizer.zero_grad()

            if use_pde:
                u_pred, residual = model(a, y, enforce_pde=True)
                loss = mse(u_pred, u_true) + lambda_pde * torch.mean(residual**2)
            else:
                loss = mse(model(a, y), u_true)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()

        # The schedule advances once per epoch, not once per batch.
        scheduler.step()

        if epoch % 50 != 0:
            continue

        # Periodic validation report.
        val_loss = validate(model, val_loader)
        print(f"Epoch {epoch}: Train={train_loss/len(train_loader):.6f}, Val={val_loss:.6f}")

7.3 评估指标

def evaluate_deeponet(model, test_loader, eps=1e-12):
    """Compute per-batch error statistics of ``model`` on ``test_loader``.

    Args:
        model: Module called as ``model(a, y)``; it is switched to eval mode.
        test_loader: Iterable yielding ``(a, y, u_true)`` batches.
        eps: Lower bound for the target norm in the relative error, so an
            all-zero target batch does not produce NaN or inf.

    Returns:
        Dict with the mean and standard deviation over batches of the
        relative L2 error (``mean_rel_l2``, ``std_rel_l2``) and of the
        maximum absolute error (``mean_max_ae``, ``std_max_ae``). Both
        errors are computed over all elements of a batch at once.
    """
    model.eval()

    relative_errors = []
    max_errors = []

    with torch.no_grad():
        for a, y, u_true in test_loader:
            diff = model(a, y) - u_true

            # Relative L2 error over the whole (flattened) batch.
            denom = torch.linalg.vector_norm(u_true).clamp_min(eps)
            rel_l2 = torch.linalg.vector_norm(diff) / denom
            relative_errors.append(rel_l2.item())

            # Maximum absolute error over the whole batch.
            max_errors.append(diff.abs().max().item())

    return {
        'mean_rel_l2': np.mean(relative_errors),
        'std_rel_l2': np.std(relative_errors),
        'mean_max_ae': np.mean(max_errors),
        'std_max_ae': np.std(max_errors)
    }

8. 应用场景

8.1 稳态PDE求解

  • Darcy方程
  • Poisson方程
  • Eikonal方程

8.2 时间依赖PDE

  • 热方程
  • 波方程
  • Burgers方程

8.3 反问题

  • 参数识别
  • 边界条件推断
  • 源项估计

8.4 控制与优化

  • 最优控制
  • 模型预测控制

9. 与FNO的比较

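| 特性 | DeepONet | FNO |
| --- | --- | --- |
| 核结构 | 可学习 | 傅里叶基 |
| 输入处理 | 离散传感器 | 连续函数 |
| 计算效率 | $O(np)$($n$ 为输出位置数) | $O(n \log n)$($n$ 为网格点数) |
| 边界处理 | 灵活 | 周期假设 |
| 理论基础 | 算子展开定理 | 卷积定理 |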

10. 参考文献


相关主题

Footnotes

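  1. Lu, L., Jin, P., Pang, G., Zhang, Z., & Karniadakis, G. E. (2021). Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nature Machine Intelligence, 3(3), 218-229.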

  2. Lanthaler, S., Mishra, S., & Karniadakis, G. E. (2022). Error estimates for DeepONets: A deep learning framework in infinite dimensions. Transactions of Mathematics and Its Applications, 6(1), tnac001.