概述

KAN 2.0(Kolmogorov-Arnold Networks Meet Science)是对原始KAN的重大升级,专注于科学发现应用。1 MIT团队在2024年8月发布了KAN 2.0,提出了三大核心功能:符号公式发现特征重要性分析模块化设计。本文深入解析KAN 2.0的核心原理及其在科学研究中的应用价值。


1. KAN 2.0 vs KAN 1.0:关键改进

核心差异对比

特性KAN 1.0KAN 2.0
激活函数B-样条B-样条 + 符号函数 + 专家混合
可解释性激活函数可视化符号公式提取
模块化单一网络模块化、层级化设计
训练标准反向传播渐进式训练 + 剪枝
科学应用概念验证系统化应用框架

KAN 1.0 回顾

原始KAN将Kolmogorov-Arnold表示定理应用于神经网络,把激活函数从节点移到边上:

其中 是可学习的激活函数(边上的B-样条)。


2. KAN 2.0 的三大核心功能

2.1 符号公式发现

KAN 2.0的核心创新是能够自动提取符号公式。这使得KAN不仅能拟合数据,还能揭示潜在的物理规律。

符号激活函数

class SymbolicActivation(nn.Module):
    """
    符号激活函数库
    
    KAN 2.0 引入了多种可解释的符号函数
    """
    def __init__(self):
        super().__init__()
        
        # 预定义的符号函数
        self.symbolic_functions = {
            'sin': torch.sin,
            'cos': torch.cos,
            'exp': torch.exp,
            'log': torch.log,
            'sqrt': torch.sqrt,
            'abs': torch.abs,
            'square': lambda x: x ** 2,
            'inverse': lambda x: 1 / (x + 1e-8),
        }
        
        # 可学习的符号激活
        self.learnable_symbolic = nn.Parameter(torch.ones(1))
    
    def forward(self, x, func_name='sin'):
        """应用符号函数"""
        if func_name in self.symbolic_functions:
            return self.symbolic_functions[func_name](x)
        else:
            # 默认为自己
            return x
 
 
class SymbolicKANLayer(nn.Module):
    """
    符号 KAN Layer
    
    结合 B-样条和符号函数的混合激活
    """
    def __init__(self, in_features, out_features, 
                 symbolic_funcs=['sin', 'cos', 'exp', 'log'],
                 grid_size=5, spline_order=3):
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        
        # B-样条激活
        self.spline = BSplineActivation(in_features, out_features, 
                                        grid_size, spline_order)
        
        # 符号激活
        self.symbolic = SymbolicActivation()
        self.symbolic_weight = nn.Parameter(torch.zeros(len(symbolic_funcs), 
                                                       out_features, in_features))
        
        # 符号函数选择器
        self.func_names = symbolic_funcs
        
        # 基础激活
        self.base_weight = nn.Parameter(torch.randn(out_features, in_features))
    
    def forward(self, x):
        # B-样条部分
        spline_out = self.spline(x)
        
        # 符号激活部分
        symbolic_out = 0
        for i, func_name in enumerate(self.func_names):
            func_out = self.symbolic(x, func_name)
            symbolic_out += torch.einsum('bi,oi->bo', 
                                        func_out * self.symbolic_weight[i], 
                                        torch.ones(self.out_features))
        
        # 基础激活
        base_out = torch.einsum('bi,oi->bo', x, 
                                torch.nn.functional.silu(self.base_weight))
        
        return spline_out + symbolic_out + base_out

公式提取算法

def extract_symbolic_formula(model, input_names, output_name='f', 
                             threshold=0.01, top_k=5):
    """
    从训练好的 KAN 中提取符号公式
    
    Args:
        model: 训练好的 KAN 模型
        input_names: 输入变量名列表
        output_name: 输出变量名
        threshold: 系数阈值
        top_k: 每层保留的 top-k 激活
    
    Returns:
        symbolic_expression: SymPy 符号表达式
    """
    from sympy import symbols, Function, sin, cos, exp, log, sqrt
    
    # 定义符号
    x = [symbols(name) for name in input_names]
    
    expression = 0
    
    # 遍历每一层
    for layer_idx, layer in enumerate(model.layers):
        layer_expr = 0
        
        # 获取激活函数系数
        coeff = layer.coeff.data.abs().mean(dim=0)  # (in, out)
        
        # 对每个输出单元
        for out_idx in range(layer.out_features):
            out_expr = 0
            
            # 获取最强的输入连接
            strengths = coeff[:, out_idx]
            top_indices = torch.topk(strengths, min(top_k, strengths.numel())).indices
            
            for in_idx in top_indices:
                # 确定激活函数类型(通过可视化或拟合)
                func_type = identify_activation_function(
                    layer.coeff.data[out_idx, in_idx]
                )
                
                # 构建符号表达式
                if func_type == 'linear':
                    coef = layer.coeff.data[out_idx, in_idx].mean().item()
                    out_expr += coef * x[in_idx]
                elif func_type == 'sin':
                    out_expr += symbols(f'a_{layer_idx}_{in_idx}_{out_idx}') * sin(x[in_idx])
                elif func_type == 'exp':
                    out_expr += symbols(f'b_{layer_idx}_{in_idx}_{out_idx}') * exp(x[in_idx])
                # ... 其他函数类型
            
            layer_expr += out_expr
        
        expression = expression + layer_expr
    
    return expression
 
 
def identify_activation_function(coeff, grid):
    """
    识别激活函数的类型
    
    通过拟合不同函数来确定类型
    """
    import numpy as np
    from scipy.optimize import curve_fit
    
    # 定义候选函数
    candidates = {
        'linear': lambda x, a, b: a * x + b,
        'sin': lambda x, a, b, c: a * np.sin(b * x + c),
        'cos': lambda x, a, b, c: a * np.cos(b * x + c),
        'exp': lambda x, a, b: a * np.exp(b * x),
        'log': lambda x, a, b: a * np.log(np.abs(x) + 1e-8) + b,
    }
    
    # 简化的识别方法
    x = np.linspace(0, 1, len(coeff))
    y = coeff.numpy()
    
    # 计算拟合误差
    best_func = 'linear'
    best_error = float('inf')
    
    for name, func in candidates.items():
        try:
            popt, _ = curve_fit(func, x, y, maxfev=5000)
            y_pred = func(x, *popt)
            error = np.mean((y - y_pred) ** 2)
            
            if error < best_error:
                best_error = error
                best_func = name
        except:
            continue
    
    return best_func

2.2 特征重要性分析

KAN 2.0提供了系统化的特征重要性评估方法,帮助理解哪些输入特征对输出最重要。

跳跃(Jumping)机制

class JumpKANLayer(nn.Module):
    """
    带跳跃连接的 KAN Layer
    
    允许信息直接从输入跳到输出层,提高效率
    """
    def __init__(self, in_features, out_features, grid_size=5, 
                 spline_order=3, jump_strength=0.1):
        super().__init__()
        
        # 标准 KAN 激活
        self.kan_activation = KANLayer(in_features, out_features, 
                                       grid_size, spline_order)
        
        # 跳跃连接:直接从输入到输出
        self.jump_weight = nn.Parameter(
            torch.randn(out_features, in_features) * jump_strength
        )
        
        # 跳跃连接的重要性权重
        self.jump_importance = nn.Parameter(torch.ones(in_features))
    
    def forward(self, x):
        # 标准 KAN 激活
        kan_out = self.kan_activation(x)
        
        # 跳跃连接(可解释的线性贡献)
        jump_out = torch.einsum('bi,oi,oi->bo', 
                                x, self.jump_weight, self.jump_importance)
        
        return kan_out + jump_out
 
 
def compute_feature_importance(model, x, y_true=None):
    """
    计算输入特征的重要性
    
    基于跳跃连接权重和激活强度
    """
    importance_scores = []
    
    with torch.no_grad():
        for name, module in model.named_modules():
            if hasattr(module, 'jump_weight'):
                # 跳跃连接权重表示直接贡献
                jump_imp = module.jump_weight.data.abs().mean(dim=0)
                jump_imp = jump_imp * module.jump_importance.data
                importance_scores.append(jump_imp)
            
            if hasattr(module, 'coeff'):
                # B-样条系数表示非线性贡献
                spline_imp = module.coeff.data.abs().mean(dim=(0, 2))
                importance_scores.append(spline_imp)
    
    # 聚合所有层的贡献
    total_importance = torch.zeros_like(importance_scores[0])
    for imp in importance_scores:
        total_importance = total_importance + imp
    
    # 归一化
    total_importance = total_importance / total_importance.sum()
    
    return total_importance.numpy()

2.3 模块化设计

KAN 2.0引入了模块化概念,允许组合多个KAN子网络形成更大的系统。

class KANModule(nn.Module):
    """
    KAN 模块
    
    一个可重复使用的 KAN 子模块
    """
    def __init__(self, in_features, hidden_features, out_features,
                 num_layers=2, grid_size=5, spline_order=3):
        super().__init__()
        
        # 构建模块内部结构
        layer_dims = [in_features] + [hidden_features] * num_layers + [out_features]
        self.kan = KAN(layer_dims, grid_size, spline_order)
        
        # 模块的输入/输出名
        self.input_names = None
        self.output_names = None
    
    def set_names(self, input_names, output_names):
        """设置模块的输入输出名称"""
        self.input_names = input_names
        self.output_names = output_names
    
    def forward(self, x):
        return self.kan(x)
 
 
class ModularKAN(nn.Module):
    """
    模块化 KAN
    
    由多个 KANModule 组成的层级系统
    """
    def __init__(self):
        super().__init__()
        
        # 定义模块
        self.module_A = KANModule(in_features=2, hidden_features=5, 
                                  out_features=3, num_layers=2)
        self.module_B = KANModule(in_features=3, hidden_features=5, 
                                  out_features=1, num_layers=2)
        
        # 模块间的连接
        self.connection_A_to_B = nn.Parameter(torch.randn(3, 3))
        
        # 设置模块名称
        self.module_A.set_names(['x1', 'x2'], ['h1', 'h2', 'h3'])
        self.module_B.set_names(['h1', 'h2', 'h3'], ['y'])
    
    def forward(self, x):
        # 模块 A
        h = self.module_A(x)
        
        # 模块 B
        y = self.module_B(h)
        
        return y
    
    def get_module_graph(self):
        """
        获取模块依赖图(用于可视化和分析)
        """
        return {
            'modules': ['module_A', 'module_B'],
            'connections': [
                ('module_A', 'module_B', self.connection_A_to_B.shape)
            ],
            'input_names': self.module_A.input_names,
            'output_names': self.module_B.output_names
        }

3. 科学发现应用案例

3.1 物理定律发现

KAN 2.0在发现物理定律方面展现了强大能力。MIT团队展示了KAN如何从数据中恢复出已知物理公式。

def discover_physics_laws():
    """
    示例:使用 KAN 发现物理定律
    
    假设我们要发现开普勒第三定律:T² ∝ a³
    """
    import torch
    import numpy as np
    
    # 生成数据:行星轨道周期和半长轴
    np.random.seed(42)
    a = np.random.uniform(0.5, 10, 100)  # 半长轴 (AU)
    # 开普勒第三定律: T² = a³ (归一化常数)
    T = np.sqrt(a ** 3) + np.random.normal(0, 0.1, 100)  # 周期
    
    # 准备数据
    x_train = torch.tensor(a.reshape(-1, 1), dtype=torch.float32)
    y_train = torch.tensor(T.reshape(-1, 1), dtype=torch.float32)
    
    # 标准化
    x_mean, x_std = x_train.mean(), x_train.std()
    y_mean, y_std = y_train.mean(), y_train.std()
    x_train = (x_train - x_mean) / x_std
    y_train = (y_train - y_mean) / y_std
    
    # 构建 KAN
    model = KAN([1, 5, 5, 1], grid_size=5, spline_order=3)
    
    # 训练
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(500):
        optimizer.zero_grad()
        pred = model(x_train)
        loss = F.mse_loss(pred, y_train)
        loss.backward()
        optimizer.step()
    
    # 提取公式
    formula = extract_symbolic_formula(
        model, 
        input_names=['a'],
        output_name='T'
    )
    
    print(f"发现的公式: T = {formula}")
    # 预期: T = a^(3/2) 或类似形式
 
 
def discover_lorentz_force():
    """
    示例:发现洛伦兹力
    
    F = q(E + v × B)
    """
    # 生成合成数据
    q = torch.randn(1000)  # 电荷
    E = torch.randn(1000, 3)  # 电场
    v = torch.randn(1000, 3)  # 速度
    B = torch.randn(1000, 3)  # 磁场
    
    # 计算力(简化模型)
    F = q.unsqueeze(-1) * (E + torch.cross(v, B))
    
    # 构建 KAN 模型
    model = KAN([10, 20, 20, 3], grid_size=5, spline_order=3)
    
    # 训练并提取公式
    # ...

3.2 模块化科学系统

class ScientificKAN(nn.Module):
    """
    科学应用的模块化 KAN 系统
    
    用于组合多个物理子系统
    """
    def __init__(self):
        super().__init__()
        
        # 动力学模块
        self.dynamics = KANModule(
            in_features=4,  # [x, y, vx, vy]
            hidden_features=10,
            out_features=2,  # [ax, ay]
            num_layers=3
        )
        self.dynamics.set_names(
            ['x', 'y', 'v_x', 'v_y'],
            ['a_x', 'a_y']
        )
        
        # 势能模块
        self.potential = KANModule(
            in_features=2,  # [x, y]
            hidden_features=10,
            out_features=1,  # V(x,y)
            num_layers=3
        )
        self.potential.set_names(['x', 'y'], ['V'])
        
        # 梯度模块(用于保守力)
        self.gradient = GradientKAN(in_features=2, out_features=2)
        
        # 连接模块
        self.conservation_law = nn.Parameter(torch.eye(2))  # 能量守恒约束
    
    def forward(self, x, y, vx, vy):
        """
        前向传播:计算加速度
        """
        state = torch.stack([x, y, vx, vy], dim=-1)
        
        # 方法1:直接动力学预测
        acc_direct = self.dynamics(state)
        
        # 方法2:通过势能梯度计算(保守力)
        V = self.potential(torch.stack([x, y], dim=-1))
        acc_gradient = -self.gradient(V)
        
        # 结合两种方法
        acc = 0.5 * acc_direct + 0.5 * acc_gradient
        
        return acc

4. KAN 2.0 的技术细节

4.1 混合激活函数设计

class HybridKANLayer(nn.Module):
    """
    混合激活 KAN Layer
    
    结合多种激活函数类型的优势
    """
    def __init__(self, in_features, out_features, 
                 grid_size=5, spline_order=3):
        super().__init__()
        
        # B-样条激活
        self.spline = BSplineActivation(in_features, out_features, 
                                        grid_size, spline_order)
        
        # 专家混合激活
        self.symbolic_experts = nn.ModuleList([
            SymbolicKANLayer(in_features, out_features)
            for _ in range(4)  # 4个专家
        ])
        
        # 路由器(决定使用哪个专家)
        self.router = nn.Linear(in_features, 4)
        
        # 注意力权重
        self.attention = nn.Sequential(
            nn.Linear(in_features, 1),
            nn.Softmax(dim=-1)
        )
        
        # 门控
        self.spline_gate = nn.Parameter(torch.tensor(0.5))
        self.expert_gate = nn.Parameter(torch.tensor(0.5))
    
    def forward(self, x):
        # B-样条输出
        spline_out = self.spline(x)
        
        # 专家混合输出
        expert_outs = [expert(x) for expert in self.symbolic_experts]
        expert_outs = torch.stack(expert_outs, dim=0)  # (4, B, out)
        
        # 计算注意力权重
        attn_weights = self.attention(x)  # (B, 4)
        attn_weights = attn_weights.unsqueeze(1)  # (B, 1, 4)
        
        # 加权平均专家输出
        expert_out = (attn_weights * expert_outs.transpose(0, 1)).sum(dim=-1)
        
        # 门控混合
        out = (self.spline_gate * spline_out + 
               self.expert_gate * expert_out)
        
        return out

4.2 物理约束注入

class PhysicsInformedKAN(nn.Module):
    """
    物理信息 KAN (PIKAN)
    
    将物理定律作为硬约束注入网络
    """
    def __init__(self, physics_constraints):
        super().__init__()
        
        # 可学习的 KAN 部分
        self.kan = KAN([2, 10, 10, 1])
        
        # 物理约束
        self.physics_constraints = physics_constraints  # e.g., ['conservation', 'symmetry']
    
    def apply_constraints(self, x, pred):
        """
        应用物理约束修正预测
        """
        constrained_pred = pred.clone()
        
        for constraint in self.physics_constraints:
            if constraint == 'conservation':
                # 能量守恒:总能量应该恒定
                # 修正预测以满足守恒
                pass
            
            elif constraint == 'symmetry':
                # 对称性约束:f(x,y) = f(y,x)
                x_flipped = torch.flip(x, dims=[-1])
                pred_flipped = self.kan(x_flipped)
                constrained_pred = (constrained_pred + pred_flipped) / 2
            
            elif constraint == 'positivity':
                # 正性约束:某些物理量必须为正
                constrained_pred = torch.clamp(constrained_pred, min=0)
        
        return constrained_pred
    
    def forward(self, x):
        pred = self.kan(x)
        constrained_pred = self.apply_constraints(x, pred)
        return constrained_pred

5. 与传统方法的对比

5.1 vs Symbolic Regression

特性Symbolic RegressionKAN 2.0
搜索策略离散搜索(遗传编程)连续优化(梯度下降)
公式复杂度适合简单公式适合复杂非线性
可解释性中-高
计算效率低(组合爆炸)
准确性

5.2 vs PINNs

特性PINNsKAN 2.0
物理约束软约束(损失项)软约束(损失项)
公式发现
可解释性
训练稳定性
适用场景PDE求解公式发现

6. 实践指南

6.1 KAN 2.0 训练策略

def train_kan_2(model, train_loader, val_loader=None, epochs=500, lr=1e-3):
    """
    KAN 2.0 训练策略
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=50, T_mult=2
    )
    
    best_loss = float('inf')
    patience = 50
    patience_counter = 0
    
    for epoch in range(epochs):
        # 训练阶段
        model.train()
        train_loss = 0
        for x, y in train_loader:
            optimizer.zero_grad()
            
            # 输入标准化到 [0, 1]
            x = standardize_to_unit_range(x)
            
            pred = model(x)
            loss = F.mse_loss(pred, y)
            
            # 添加稀疏性正则化
            sparsity_loss = compute_sparsity_loss(model)
            total_loss = loss + 0.01 * sparsity_loss
            
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item()
        
        scheduler.step()
        
        # 验证阶段
        if val_loader is not None:
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for x, y in val_loader:
                    x = standardize_to_unit_range(x)
                    pred = model(x)
                    val_loss += F.mse_loss(pred, y).item()
            
            # 早停
            if val_loss < best_loss:
                best_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch}")
                    break
        
        if epoch % 50 == 0:
            print(f"Epoch {epoch}, Train Loss: {train_loss/len(train_loader):.6f}")
 
 
def standardize_to_unit_range(x):
    """将输入标准化到 [0, 1] 范围"""
    x_min = x.min(dim=1, keepdim=True)[0]
    x_max = x.max(dim=1, keepdim=True)[0]
    x_std = (x - x_min) / (x_max - x_min + 1e-8)
    return x_std.clamp(0, 1)
 
 
def compute_sparsity_loss(model, lambda_l1=0.01):
    """
    计算稀疏性损失
    
    促使网络学习更简单的激活函数
    """
    loss = 0
    for layer in model.layers:
        if hasattr(layer, 'coeff'):
            # L1 正则化促进稀疏性
            loss += lambda_l1 * torch.abs(layer.coeff).mean()
        
        if hasattr(layer, 'symbolic_weight'):
            loss += lambda_l1 * torch.abs(layer.symbolic_weight).mean()
    
    return loss

6.2 公式提取的最佳实践

def best_practices_formula_extraction():
    """
    公式提取的最佳实践
    """
    
    # 1. 训练要足够长,但不要过拟合
    # 使用验证集选择最佳模型
    
    # 2. 使用适当的网络宽度
    # 太宽:容易过拟合,公式复杂
    # 太窄:表达能力不足
    
    # 3. 激活函数选择
    # 物理问题:sin, cos, exp, log
    # 一般问题:B-样条
    
    # 4. 网格大小选择
    # 简单函数:grid_size=3, spline_order=2
    # 复杂函数:grid_size=10, spline_order=3
    
    # 5. 后处理
    # 使用 SymPy 简化提取的公式
    from sympy import simplify, expand, trigsimp
    
    def simplify_extracted_formula(expr):
        """简化提取的公式"""
        expr = simplify(expr)
        expr = expand(expr)
        expr = trigsimp(expr)
        return expr
    
    pass

7. 总结与展望

KAN 2.0 的主要贡献

  1. 符号公式发现:将KAN从拟合工具升级为科学发现工具
  2. 特征重要性分析:通过跳跃连接提供系统化的重要性评估
  3. 模块化设计:支持构建复杂科学系统
  4. 物理约束注入:支持将先验知识融入网络

局限性

  1. 计算成本:B-样条激活的计算量仍然较大
  2. 规模化挑战:在大型数据集上的效率不如Transformer
  3. 激活函数选择:需要领域知识选择合适的符号函数

未来方向

  1. 与LLM结合:使用语言模型辅助公式解释
  2. 自动化模块设计:自动发现最优模块结构
  3. 多物理场耦合:处理多物理场耦合问题

参考


相关阅读

Footnotes

  1. Liu, Z., et al. (2024). “KAN 2.0: Kolmogorov-Arnold Networks Meet Science”. arXiv:2408.10205.