概述
KAN 2.0(Kolmogorov-Arnold Networks Meet Science)是对原始KAN的重大升级,专注于科学发现应用。1 MIT团队在2024年8月发布了KAN 2.0,提出了三大核心功能:符号公式发现、特征重要性分析和模块化设计。本文深入解析KAN 2.0的核心原理及其在科学研究中的应用价值。
1. KAN 2.0 vs KAN 1.0:关键改进
核心差异对比
| 特性 | KAN 1.0 | KAN 2.0 |
|---|---|---|
| 激活函数 | B-样条 | B-样条 + 符号函数 + 专家混合 |
| 可解释性 | 激活函数可视化 | 符号公式提取 |
| 模块化 | 单一网络 | 模块化、层级化设计 |
| 训练 | 标准反向传播 | 渐进式训练 + 剪枝 |
| 科学应用 | 概念验证 | 系统化应用框架 |
KAN 1.0 回顾
原始KAN将Kolmogorov-Arnold表示定理应用于神经网络,把激活函数从节点移到边上:
其中 是可学习的激活函数(边上的B-样条)。
2. KAN 2.0 的三大核心功能
2.1 符号公式发现
KAN 2.0的核心创新是能够自动提取符号公式。这使得KAN不仅能拟合数据,还能揭示潜在的物理规律。
符号激活函数
class SymbolicActivation(nn.Module):
"""
符号激活函数库
KAN 2.0 引入了多种可解释的符号函数
"""
def __init__(self):
super().__init__()
# 预定义的符号函数
self.symbolic_functions = {
'sin': torch.sin,
'cos': torch.cos,
'exp': torch.exp,
'log': torch.log,
'sqrt': torch.sqrt,
'abs': torch.abs,
'square': lambda x: x ** 2,
'inverse': lambda x: 1 / (x + 1e-8),
}
# 可学习的符号激活
self.learnable_symbolic = nn.Parameter(torch.ones(1))
def forward(self, x, func_name='sin'):
"""应用符号函数"""
if func_name in self.symbolic_functions:
return self.symbolic_functions[func_name](x)
else:
# 默认为自己
return x
class SymbolicKANLayer(nn.Module):
"""
符号 KAN Layer
结合 B-样条和符号函数的混合激活
"""
def __init__(self, in_features, out_features,
symbolic_funcs=['sin', 'cos', 'exp', 'log'],
grid_size=5, spline_order=3):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# B-样条激活
self.spline = BSplineActivation(in_features, out_features,
grid_size, spline_order)
# 符号激活
self.symbolic = SymbolicActivation()
self.symbolic_weight = nn.Parameter(torch.zeros(len(symbolic_funcs),
out_features, in_features))
# 符号函数选择器
self.func_names = symbolic_funcs
# 基础激活
self.base_weight = nn.Parameter(torch.randn(out_features, in_features))
def forward(self, x):
# B-样条部分
spline_out = self.spline(x)
# 符号激活部分
symbolic_out = 0
for i, func_name in enumerate(self.func_names):
func_out = self.symbolic(x, func_name)
symbolic_out += torch.einsum('bi,oi->bo',
func_out * self.symbolic_weight[i],
torch.ones(self.out_features))
# 基础激活
base_out = torch.einsum('bi,oi->bo', x,
torch.nn.functional.silu(self.base_weight))
return spline_out + symbolic_out + base_out公式提取算法
def extract_symbolic_formula(model, input_names, output_name='f',
threshold=0.01, top_k=5):
"""
从训练好的 KAN 中提取符号公式
Args:
model: 训练好的 KAN 模型
input_names: 输入变量名列表
output_name: 输出变量名
threshold: 系数阈值
top_k: 每层保留的 top-k 激活
Returns:
symbolic_expression: SymPy 符号表达式
"""
from sympy import symbols, Function, sin, cos, exp, log, sqrt
# 定义符号
x = [symbols(name) for name in input_names]
expression = 0
# 遍历每一层
for layer_idx, layer in enumerate(model.layers):
layer_expr = 0
# 获取激活函数系数
coeff = layer.coeff.data.abs().mean(dim=0) # (in, out)
# 对每个输出单元
for out_idx in range(layer.out_features):
out_expr = 0
# 获取最强的输入连接
strengths = coeff[:, out_idx]
top_indices = torch.topk(strengths, min(top_k, strengths.numel())).indices
for in_idx in top_indices:
# 确定激活函数类型(通过可视化或拟合)
func_type = identify_activation_function(
layer.coeff.data[out_idx, in_idx]
)
# 构建符号表达式
if func_type == 'linear':
coef = layer.coeff.data[out_idx, in_idx].mean().item()
out_expr += coef * x[in_idx]
elif func_type == 'sin':
out_expr += symbols(f'a_{layer_idx}_{in_idx}_{out_idx}') * sin(x[in_idx])
elif func_type == 'exp':
out_expr += symbols(f'b_{layer_idx}_{in_idx}_{out_idx}') * exp(x[in_idx])
# ... 其他函数类型
layer_expr += out_expr
expression = expression + layer_expr
return expression
def identify_activation_function(coeff, grid):
"""
识别激活函数的类型
通过拟合不同函数来确定类型
"""
import numpy as np
from scipy.optimize import curve_fit
# 定义候选函数
candidates = {
'linear': lambda x, a, b: a * x + b,
'sin': lambda x, a, b, c: a * np.sin(b * x + c),
'cos': lambda x, a, b, c: a * np.cos(b * x + c),
'exp': lambda x, a, b: a * np.exp(b * x),
'log': lambda x, a, b: a * np.log(np.abs(x) + 1e-8) + b,
}
# 简化的识别方法
x = np.linspace(0, 1, len(coeff))
y = coeff.numpy()
# 计算拟合误差
best_func = 'linear'
best_error = float('inf')
for name, func in candidates.items():
try:
popt, _ = curve_fit(func, x, y, maxfev=5000)
y_pred = func(x, *popt)
error = np.mean((y - y_pred) ** 2)
if error < best_error:
best_error = error
best_func = name
except:
continue
return best_func2.2 特征重要性分析
KAN 2.0提供了系统化的特征重要性评估方法,帮助理解哪些输入特征对输出最重要。
跳跃(Jumping)机制
class JumpKANLayer(nn.Module):
"""
带跳跃连接的 KAN Layer
允许信息直接从输入跳到输出层,提高效率
"""
def __init__(self, in_features, out_features, grid_size=5,
spline_order=3, jump_strength=0.1):
super().__init__()
# 标准 KAN 激活
self.kan_activation = KANLayer(in_features, out_features,
grid_size, spline_order)
# 跳跃连接:直接从输入到输出
self.jump_weight = nn.Parameter(
torch.randn(out_features, in_features) * jump_strength
)
# 跳跃连接的重要性权重
self.jump_importance = nn.Parameter(torch.ones(in_features))
def forward(self, x):
# 标准 KAN 激活
kan_out = self.kan_activation(x)
# 跳跃连接(可解释的线性贡献)
jump_out = torch.einsum('bi,oi,oi->bo',
x, self.jump_weight, self.jump_importance)
return kan_out + jump_out
def compute_feature_importance(model, x, y_true=None):
"""
计算输入特征的重要性
基于跳跃连接权重和激活强度
"""
importance_scores = []
with torch.no_grad():
for name, module in model.named_modules():
if hasattr(module, 'jump_weight'):
# 跳跃连接权重表示直接贡献
jump_imp = module.jump_weight.data.abs().mean(dim=0)
jump_imp = jump_imp * module.jump_importance.data
importance_scores.append(jump_imp)
if hasattr(module, 'coeff'):
# B-样条系数表示非线性贡献
spline_imp = module.coeff.data.abs().mean(dim=(0, 2))
importance_scores.append(spline_imp)
# 聚合所有层的贡献
total_importance = torch.zeros_like(importance_scores[0])
for imp in importance_scores:
total_importance = total_importance + imp
# 归一化
total_importance = total_importance / total_importance.sum()
return total_importance.numpy()2.3 模块化设计
KAN 2.0引入了模块化概念,允许组合多个KAN子网络形成更大的系统。
class KANModule(nn.Module):
"""
KAN 模块
一个可重复使用的 KAN 子模块
"""
def __init__(self, in_features, hidden_features, out_features,
num_layers=2, grid_size=5, spline_order=3):
super().__init__()
# 构建模块内部结构
layer_dims = [in_features] + [hidden_features] * num_layers + [out_features]
self.kan = KAN(layer_dims, grid_size, spline_order)
# 模块的输入/输出名
self.input_names = None
self.output_names = None
def set_names(self, input_names, output_names):
"""设置模块的输入输出名称"""
self.input_names = input_names
self.output_names = output_names
def forward(self, x):
return self.kan(x)
class ModularKAN(nn.Module):
"""
模块化 KAN
由多个 KANModule 组成的层级系统
"""
def __init__(self):
super().__init__()
# 定义模块
self.module_A = KANModule(in_features=2, hidden_features=5,
out_features=3, num_layers=2)
self.module_B = KANModule(in_features=3, hidden_features=5,
out_features=1, num_layers=2)
# 模块间的连接
self.connection_A_to_B = nn.Parameter(torch.randn(3, 3))
# 设置模块名称
self.module_A.set_names(['x1', 'x2'], ['h1', 'h2', 'h3'])
self.module_B.set_names(['h1', 'h2', 'h3'], ['y'])
def forward(self, x):
# 模块 A
h = self.module_A(x)
# 模块 B
y = self.module_B(h)
return y
def get_module_graph(self):
"""
获取模块依赖图(用于可视化和分析)
"""
return {
'modules': ['module_A', 'module_B'],
'connections': [
('module_A', 'module_B', self.connection_A_to_B.shape)
],
'input_names': self.module_A.input_names,
'output_names': self.module_B.output_names
}3. 科学发现应用案例
3.1 物理定律发现
KAN 2.0在发现物理定律方面展现了强大能力。MIT团队展示了KAN如何从数据中恢复出已知物理公式。
def discover_physics_laws():
"""
示例:使用 KAN 发现物理定律
假设我们要发现开普勒第三定律:T² ∝ a³
"""
import torch
import numpy as np
# 生成数据:行星轨道周期和半长轴
np.random.seed(42)
a = np.random.uniform(0.5, 10, 100) # 半长轴 (AU)
# 开普勒第三定律: T² = a³ (归一化常数)
T = np.sqrt(a ** 3) + np.random.normal(0, 0.1, 100) # 周期
# 准备数据
x_train = torch.tensor(a.reshape(-1, 1), dtype=torch.float32)
y_train = torch.tensor(T.reshape(-1, 1), dtype=torch.float32)
# 标准化
x_mean, x_std = x_train.mean(), x_train.std()
y_mean, y_std = y_train.mean(), y_train.std()
x_train = (x_train - x_mean) / x_std
y_train = (y_train - y_mean) / y_std
# 构建 KAN
model = KAN([1, 5, 5, 1], grid_size=5, spline_order=3)
# 训练
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(500):
optimizer.zero_grad()
pred = model(x_train)
loss = F.mse_loss(pred, y_train)
loss.backward()
optimizer.step()
# 提取公式
formula = extract_symbolic_formula(
model,
input_names=['a'],
output_name='T'
)
print(f"发现的公式: T = {formula}")
# 预期: T = a^(3/2) 或类似形式
def discover_lorentz_force():
"""
示例:发现洛伦兹力
F = q(E + v × B)
"""
# 生成合成数据
q = torch.randn(1000) # 电荷
E = torch.randn(1000, 3) # 电场
v = torch.randn(1000, 3) # 速度
B = torch.randn(1000, 3) # 磁场
# 计算力(简化模型)
F = q.unsqueeze(-1) * (E + torch.cross(v, B))
# 构建 KAN 模型
model = KAN([10, 20, 20, 3], grid_size=5, spline_order=3)
# 训练并提取公式
# ...3.2 模块化科学系统
class ScientificKAN(nn.Module):
"""
科学应用的模块化 KAN 系统
用于组合多个物理子系统
"""
def __init__(self):
super().__init__()
# 动力学模块
self.dynamics = KANModule(
in_features=4, # [x, y, vx, vy]
hidden_features=10,
out_features=2, # [ax, ay]
num_layers=3
)
self.dynamics.set_names(
['x', 'y', 'v_x', 'v_y'],
['a_x', 'a_y']
)
# 势能模块
self.potential = KANModule(
in_features=2, # [x, y]
hidden_features=10,
out_features=1, # V(x,y)
num_layers=3
)
self.potential.set_names(['x', 'y'], ['V'])
# 梯度模块(用于保守力)
self.gradient = GradientKAN(in_features=2, out_features=2)
# 连接模块
self.conservation_law = nn.Parameter(torch.eye(2)) # 能量守恒约束
def forward(self, x, y, vx, vy):
"""
前向传播:计算加速度
"""
state = torch.stack([x, y, vx, vy], dim=-1)
# 方法1:直接动力学预测
acc_direct = self.dynamics(state)
# 方法2:通过势能梯度计算(保守力)
V = self.potential(torch.stack([x, y], dim=-1))
acc_gradient = -self.gradient(V)
# 结合两种方法
acc = 0.5 * acc_direct + 0.5 * acc_gradient
return acc4. KAN 2.0 的技术细节
4.1 混合激活函数设计
class HybridKANLayer(nn.Module):
"""
混合激活 KAN Layer
结合多种激活函数类型的优势
"""
def __init__(self, in_features, out_features,
grid_size=5, spline_order=3):
super().__init__()
# B-样条激活
self.spline = BSplineActivation(in_features, out_features,
grid_size, spline_order)
# 专家混合激活
self.symbolic_experts = nn.ModuleList([
SymbolicKANLayer(in_features, out_features)
for _ in range(4) # 4个专家
])
# 路由器(决定使用哪个专家)
self.router = nn.Linear(in_features, 4)
# 注意力权重
self.attention = nn.Sequential(
nn.Linear(in_features, 1),
nn.Softmax(dim=-1)
)
# 门控
self.spline_gate = nn.Parameter(torch.tensor(0.5))
self.expert_gate = nn.Parameter(torch.tensor(0.5))
def forward(self, x):
# B-样条输出
spline_out = self.spline(x)
# 专家混合输出
expert_outs = [expert(x) for expert in self.symbolic_experts]
expert_outs = torch.stack(expert_outs, dim=0) # (4, B, out)
# 计算注意力权重
attn_weights = self.attention(x) # (B, 4)
attn_weights = attn_weights.unsqueeze(1) # (B, 1, 4)
# 加权平均专家输出
expert_out = (attn_weights * expert_outs.transpose(0, 1)).sum(dim=-1)
# 门控混合
out = (self.spline_gate * spline_out +
self.expert_gate * expert_out)
return out4.2 物理约束注入
class PhysicsInformedKAN(nn.Module):
"""
物理信息 KAN (PIKAN)
将物理定律作为硬约束注入网络
"""
def __init__(self, physics_constraints):
super().__init__()
# 可学习的 KAN 部分
self.kan = KAN([2, 10, 10, 1])
# 物理约束
self.physics_constraints = physics_constraints # e.g., ['conservation', 'symmetry']
def apply_constraints(self, x, pred):
"""
应用物理约束修正预测
"""
constrained_pred = pred.clone()
for constraint in self.physics_constraints:
if constraint == 'conservation':
# 能量守恒:总能量应该恒定
# 修正预测以满足守恒
pass
elif constraint == 'symmetry':
# 对称性约束:f(x,y) = f(y,x)
x_flipped = torch.flip(x, dims=[-1])
pred_flipped = self.kan(x_flipped)
constrained_pred = (constrained_pred + pred_flipped) / 2
elif constraint == 'positivity':
# 正性约束:某些物理量必须为正
constrained_pred = torch.clamp(constrained_pred, min=0)
return constrained_pred
def forward(self, x):
pred = self.kan(x)
constrained_pred = self.apply_constraints(x, pred)
return constrained_pred5. 与传统方法的对比
5.1 vs Symbolic Regression
| 特性 | Symbolic Regression | KAN 2.0 |
|---|---|---|
| 搜索策略 | 离散搜索(遗传编程) | 连续优化(梯度下降) |
| 公式复杂度 | 适合简单公式 | 适合复杂非线性 |
| 可解释性 | 高 | 中-高 |
| 计算效率 | 低(组合爆炸) | 中 |
| 准确性 | 中 | 高 |
5.2 vs PINNs
| 特性 | PINNs | KAN 2.0 |
|---|---|---|
| 物理约束 | 软约束(损失项) | 软约束(损失项) |
| 公式发现 | 否 | 是 |
| 可解释性 | 低 | 高 |
| 训练稳定性 | 中 | 高 |
| 适用场景 | PDE求解 | 公式发现 |
6. 实践指南
6.1 KAN 2.0 训练策略
def train_kan_2(model, train_loader, val_loader=None, epochs=500, lr=1e-3):
"""
KAN 2.0 训练策略
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=50, T_mult=2
)
best_loss = float('inf')
patience = 50
patience_counter = 0
for epoch in range(epochs):
# 训练阶段
model.train()
train_loss = 0
for x, y in train_loader:
optimizer.zero_grad()
# 输入标准化到 [0, 1]
x = standardize_to_unit_range(x)
pred = model(x)
loss = F.mse_loss(pred, y)
# 添加稀疏性正则化
sparsity_loss = compute_sparsity_loss(model)
total_loss = loss + 0.01 * sparsity_loss
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
train_loss += loss.item()
scheduler.step()
# 验证阶段
if val_loader is not None:
model.eval()
val_loss = 0
with torch.no_grad():
for x, y in val_loader:
x = standardize_to_unit_range(x)
pred = model(x)
val_loss += F.mse_loss(pred, y).item()
# 早停
if val_loss < best_loss:
best_loss = val_loss
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= patience:
print(f"Early stopping at epoch {epoch}")
break
if epoch % 50 == 0:
print(f"Epoch {epoch}, Train Loss: {train_loss/len(train_loader):.6f}")
def standardize_to_unit_range(x):
"""将输入标准化到 [0, 1] 范围"""
x_min = x.min(dim=1, keepdim=True)[0]
x_max = x.max(dim=1, keepdim=True)[0]
x_std = (x - x_min) / (x_max - x_min + 1e-8)
return x_std.clamp(0, 1)
def compute_sparsity_loss(model, lambda_l1=0.01):
"""
计算稀疏性损失
促使网络学习更简单的激活函数
"""
loss = 0
for layer in model.layers:
if hasattr(layer, 'coeff'):
# L1 正则化促进稀疏性
loss += lambda_l1 * torch.abs(layer.coeff).mean()
if hasattr(layer, 'symbolic_weight'):
loss += lambda_l1 * torch.abs(layer.symbolic_weight).mean()
return loss6.2 公式提取的最佳实践
def best_practices_formula_extraction():
"""
公式提取的最佳实践
"""
# 1. 训练要足够长,但不要过拟合
# 使用验证集选择最佳模型
# 2. 使用适当的网络宽度
# 太宽:容易过拟合,公式复杂
# 太窄:表达能力不足
# 3. 激活函数选择
# 物理问题:sin, cos, exp, log
# 一般问题:B-样条
# 4. 网格大小选择
# 简单函数:grid_size=3, spline_order=2
# 复杂函数:grid_size=10, spline_order=3
# 5. 后处理
# 使用 SymPy 简化提取的公式
from sympy import simplify, expand, trigsimp
def simplify_extracted_formula(expr):
"""简化提取的公式"""
expr = simplify(expr)
expr = expand(expr)
expr = trigsimp(expr)
return expr
pass7. 总结与展望
KAN 2.0 的主要贡献
- 符号公式发现:将KAN从拟合工具升级为科学发现工具
- 特征重要性分析:通过跳跃连接提供系统化的重要性评估
- 模块化设计:支持构建复杂科学系统
- 物理约束注入:支持将先验知识融入网络
局限性
- 计算成本:B-样条激活的计算量仍然较大
- 规模化挑战:在大型数据集上的效率不如Transformer
- 激活函数选择:需要领域知识选择合适的符号函数
未来方向
- 与LLM结合:使用语言模型辅助公式解释
- 自动化模块设计:自动发现最优模块结构
- 多物理场耦合:处理多物理场耦合问题
参考
相关阅读
Footnotes
-
Liu, Z., et al. (2024). “KAN 2.0: Kolmogorov-Arnold Networks Meet Science”. arXiv:2408.10205. ↩