概述
本文提供Kolmogorov-Arnold Networks (KAN)的完整PyTorch实现,包括基础KAN Layer、完整KAN模型、训练工具和最佳实践指南。代码经过模块化设计,便于理解和扩展。
1. 基础组件实现
1.1 B-样条激活函数
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
class BSplineActivation(nn.Module):
    """
    Learnable B-spline activation used by KAN layers.

    Every (output o, input i) edge has its own spline
    ``phi_oi(x) = sum_k coeff[o, i, k] * B_k(x)`` built from
    ``grid_size + spline_order`` B-spline basis functions on a uniform knot
    vector that covers ``grid_range`` and is extended by ``spline_order``
    knot spacings on each side. The output is ``y[b, o] = sum_i phi_oi(x[b, i])``.

    Inputs are clamped to ``grid_range`` (``(0, 1)``) before evaluation.
    """

    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        # Domain of the spline; inputs are clamped to it.
        self.grid_range = (0, 1)
        # Uniform knots extended by `spline_order` spacings on both sides so
        # that the bases form a partition of unity over grid_range.
        h = (self.grid_range[1] - self.grid_range[0]) / grid_size
        self.register_buffer(
            'grid',
            torch.linspace(
                self.grid_range[0] - h * spline_order,
                self.grid_range[1] + h * spline_order,
                grid_size + 2 * spline_order + 1,
            ),
        )
        # Learnable coefficients: one per (output, input, basis function).
        self.coeff = nn.Parameter(
            torch.randn(out_features, in_features, grid_size + spline_order)
        )
        self._init_coeff()

    def _init_coeff(self):
        """Set all spline coefficients to zero, so the spline output starts at 0."""
        nn.init.zeros_(self.coeff)

    def de_boor(self, x, grid, k):
        """
        Evaluate every B-spline basis function with the Cox-de Boor recursion.

        Args:
            x: (batch, in_features) input points; clamped to ``grid_range``.
            grid: 1-D knot vector of length ``n_knots``.
            k: spline order (polynomial degree).

        Returns:
            bases: (batch, in_features, n_knots - k - 1) basis values.
        """
        x = x.clamp(self.grid_range[0], self.grid_range[1]).unsqueeze(-1)
        grid = grid.to(dtype=x.dtype)
        # Order 0: indicator of the knot interval [t_j, t_{j+1}) containing x.
        # All len(grid) - 1 intervals are needed: each recursion step below
        # drops one basis, leaving len(grid) - k - 1 after k steps.
        in_interval = (x >= grid[:-1]) & (x < grid[1:])
        # Close the last interval on the right so x == grid[-1] is not lost
        # (only reachable when k == 0, where the knots end at grid_range[1]).
        in_interval[..., -1] |= x[..., 0] == grid[-1]
        bases = in_interval.to(x.dtype)
        for p in range(1, k + 1):
            # B_{j,p} = (x - t_j) / (t_{j+p} - t_j) * B_{j,p-1}
            #         + (t_{j+p+1} - x) / (t_{j+p+1} - t_{j+1}) * B_{j+1,p-1}
            # Computed out-of-place so autograd can differentiate w.r.t. x.
            left = (x - grid[:-(p + 1)]) / (grid[p:-1] - grid[:-(p + 1)] + 1e-8)
            right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p] + 1e-8)
            bases = left * bases[..., :-1] + right * bases[..., 1:]
        return bases

    def forward(self, x):
        """
        Apply the per-edge splines and sum over inputs.

        Args:
            x: (batch, in_features) input; values outside [0, 1] are clamped.

        Returns:
            y: (batch, out_features)
        """
        bases = self.de_boor(x, self.grid, self.spline_order)
        # y[b, o] = sum_i sum_k bases[b, i, k] * coeff[o, i, k]
        return torch.einsum('bik,oik->bo', bases.to(self.coeff.dtype), self.coeff)
class EfficientBSplineActivation(nn.Module):
    """
    Loop-free B-spline activation.

    Uses the same parameterisation as ``BSplineActivation``: ``grid_size +
    spline_order`` basis functions per edge on a uniform knot vector over
    [0, 1], extended by ``spline_order`` spacings on each side. All basis
    functions are evaluated at once with a vectorised Cox-de Boor recursion
    (one tensor op per spline order, no Python loop over bases).
    """

    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        # Uniform knots on [0, 1], extended by `spline_order` spacings per side.
        h = 1.0 / grid_size
        self.register_buffer(
            'grid',
            torch.linspace(-h * spline_order, 1 + h * spline_order,
                           grid_size + 2 * spline_order + 1)
        )
        # Learnable coefficients: (out_features, in_features, n_bases).
        self.coeff = nn.Parameter(
            torch.randn(out_features, in_features, grid_size + spline_order)
        )
        self._init_coeff()

    def _init_coeff(self):
        """Set all spline coefficients to zero."""
        nn.init.zeros_(self.coeff)

    def _bases(self, x):
        """
        Evaluate all basis functions at ``x``.

        Args:
            x: (batch, in_features), already clamped to [0, 1].

        Returns:
            (batch, in_features, grid_size + spline_order) basis values.
        """
        grid = self.grid.to(dtype=x.dtype)
        x = x.unsqueeze(-1)
        # Order 0: indicator of the knot interval containing x.
        in_interval = (x >= grid[:-1]) & (x < grid[1:])
        # Close the last interval so x == grid[-1] (spline_order == 0) counts.
        in_interval[..., -1] |= x[..., 0] == grid[-1]
        bases = in_interval.to(x.dtype)
        for p in range(1, self.spline_order + 1):
            # Cox-de Boor step for every basis index at once; the uniform grid
            # guarantees non-zero denominators.
            left = (x - grid[:-(p + 1)]) / (grid[p:-1] - grid[:-(p + 1)])
            right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p])
            bases = left * bases[..., :-1] + right * bases[..., 1:]
        return bases

    def forward(self, x):
        """
        Args:
            x: (batch, in_features); values are clamped to [0, 1].

        Returns:
            (batch, out_features): sum over inputs of the per-edge splines.
        """
        bases = self._bases(x.clamp(0, 1))
        return torch.einsum('bik,oik->bo', bases.to(self.coeff.dtype), self.coeff)


# 1.2 KAN Layer
class KANLayer(nn.Module):
    """
    KAN layer: the basic building block of a KAN.

    Computes ``y = spline(x) + base_weight @ act(x)``:
      * ``spline`` is a per-edge learnable B-spline (``BSplineActivation``)
        evaluated on the input clamped to [0, 1];
      * the optional base branch applies a fixed activation to the *raw*
        input, so it keeps a non-zero gradient even where the spline input
        is clamped.
    """

    # Supported `base_activation` names; any other name falls back to SiLU.
    _ACTIVATIONS = {
        'silu': F.silu,
        'gelu': F.gelu,
        'relu': F.relu,
        'tanh': torch.tanh,
    }

    def __init__(self, in_features, out_features,
                 grid_size=5, spline_order=3,
                 base_activation='silu',
                 use_base_activation=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        # Spline branch.
        self.spline = BSplineActivation(
            in_features, out_features,
            grid_size, spline_order
        )
        # Base-branch weights (out_features, in_features), or None if disabled.
        if use_base_activation:
            self.base_weight = nn.Parameter(
                torch.randn(out_features, in_features)
            )
        else:
            self.register_parameter('base_weight', None)
        self.use_base_activation = use_base_activation
        self.activation_fn = self._ACTIVATIONS.get(base_activation, F.silu)
        self._init_weights()

    def _init_weights(self):
        """Initialise the base-branch weights with N(0, 0.1^2)."""
        if self.base_weight is not None:
            nn.init.normal_(self.base_weight, std=0.1)

    def forward(self, x):
        """
        Args:
            x: (batch, in_features) input.

        Returns:
            y: (batch, out_features) output.
        """
        # The spline is only defined on [0, 1].
        spline_out = self.spline(x.clamp(0, 1))
        if self.use_base_activation and self.base_weight is not None:
            # Residual-style base branch on the unclamped input.
            base_out = F.linear(self.activation_fn(x), self.base_weight)
            return spline_out + base_out
        return spline_out
class KANLinear(nn.Module):
    """
    Thin wrapper that exposes a single ``KANLayer`` as a feature-dimension transform.

    Only the grid settings are configurable; the wrapped layer uses its
    default base activation and base branch.
    """

    def __init__(self, in_features, out_features,
                 grid_size=5, spline_order=3):
        super().__init__()
        self.kan = KANLayer(in_features, out_features,
                            grid_size=grid_size, spline_order=spline_order)

    def forward(self, x):
        """Map (batch, in_features) to (batch, out_features) through the wrapped layer."""
        result = self.kan(x)
        return result


# 2. 完整 KAN 模型
2.1 基础 KAN
class KAN(nn.Module):
    """
    Complete Kolmogorov-Arnold Network: a stack of ``KANLayer`` modules.

    Args:
        layer_dims: width of every layer, e.g. [2, 3, 5, 1].
        grid_size: number of B-spline grid intervals.
        spline_order: B-spline order.
        base_activation: activation name ('silu', 'gelu', 'relu', 'tanh')
            used by each layer's base branch and between layers.
        use_base_activation: whether the layers include the base branch.
    """

    # Module applied between consecutive layers, keyed by base_activation so
    # it matches the layers' own base activation; unknown names use GELU.
    _INTER_LAYER_ACTIVATIONS = {
        'silu': nn.SiLU,
        'gelu': nn.GELU,
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
    }

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 base_activation='silu', use_base_activation=True):
        super().__init__()
        self.layer_dims = layer_dims
        self.num_layers = len(layer_dims) - 1
        self.layers = nn.ModuleList(
            KANLayer(
                in_features=d_in,
                out_features=d_out,
                grid_size=grid_size,
                spline_order=spline_order,
                base_activation=base_activation,
                use_base_activation=use_base_activation,
            )
            for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
        )
        activation_cls = self._INTER_LAYER_ACTIVATIONS.get(base_activation, nn.GELU)
        self.activation = activation_cls()

    def forward(self, x):
        """
        Args:
            x: (..., in_features) with any number of leading dimensions,
               e.g. (batch, in_features) or (batch, seq_len, in_features).

        Returns:
            (..., out_features) with the same leading dimensions.
        """
        leading_shape = x.shape[:-1]
        # The layers operate on 2-D (rows, features) input.
        x = x.reshape(-1, x.shape[-1])
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < self.num_layers - 1:
                x = self.activation(x)
        return x.reshape(*leading_shape, x.shape[-1])

    def get_regularization_loss(self, lambda_l1=0.01):
        """
        L1 penalty on the spline coefficients (encourages sparsity).

        Returns:
            ``lambda_l1 * sum over layers of mean(|coeff|)``; the integer 0
            when there are no layers.
        """
        return sum(
            lambda_l1 * torch.abs(layer.spline.coeff).mean()
            for layer in self.layers
        )

    def get_num_parameters(self):
        """Return the total number of parameter elements."""
        return sum(p.numel() for p in self.parameters())


# 2.2 带跳跃连接的 KAN
class ResidualKAN(nn.Module):
    """
    KAN with a gated skip connection (ResNet-style) for more stable training.

    ``out = gate * dropout(kan(x)) + (1 - gate) * skip(x)``, where ``skip``
    is a linear projection when the input and output widths differ and the
    identity otherwise. ``gate`` is a learnable scalar initialised to 0.5
    and is not constrained to [0, 1].
    """

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 base_activation='silu', dropout=0.0):
        super().__init__()
        self.kan = KAN(
            layer_dims, grid_size, spline_order,
            base_activation, use_base_activation=True
        )
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        in_dim, out_dim = layer_dims[0], layer_dims[-1]
        self.skip = nn.Identity() if in_dim == out_dim else nn.Linear(in_dim, out_dim)
        self.gate = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        """Blend the KAN branch and the skip branch with the learnable gate."""
        residual = self.skip(x)
        transformed = self.kan(x)
        if self.dropout is not None:
            transformed = self.dropout(transformed)
        gate = self.gate
        return gate * transformed + (1 - gate) * residual
class DeepKAN(nn.Module):
    """
    Deep KAN with LayerNorm, SiLU and dropout between layers.

    Normalisation, activation and dropout are applied to hidden layers only.
    The last layer's output is returned unchanged: normalising it would
    collapse a width-1 output to the constant LayerNorm bias, and the
    activation/dropout would distort regression targets.
    """

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 base_activation='silu', dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            KANLayer(d_in, d_out, grid_size, spline_order, base_activation)
            for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
        )
        # One LayerNorm per hidden layer (every layer except the last).
        self.norms = nn.ModuleList(nn.LayerNorm(d) for d in layer_dims[1:-1])
        self.activation = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map (batch, layer_dims[0]) to (batch, layer_dims[-1])."""
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = self.dropout(self.activation(self.norms[i](x)))
        return x


# 2.3 可解释性增强的 KAN
class SparseKAN(nn.Module):
    """
    Sparse KAN for interpretability.

    Wraps a ``KAN``. Sparsity is encouraged by L1 regularisation during
    training and enforced afterwards by zeroing small spline coefficients
    with ``apply_sparsity``.
    """

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 sparsity_threshold=0.01):
        super().__init__()
        self.kan = KAN(layer_dims, grid_size, spline_order)
        # Coefficients with |c| below this value are zeroed by apply_sparsity().
        self.sparsity_threshold = sparsity_threshold

    def forward(self, x):
        return self.kan(x)

    def apply_sparsity(self):
        """Zero, in place, every spline coefficient with |c| < sparsity_threshold."""
        with torch.no_grad():
            for layer in self.kan.layers:
                coeff = layer.spline.coeff
                coeff.masked_fill_(coeff.abs() < self.sparsity_threshold, 0.0)

    def get_activation_importance(self):
        """
        Mean absolute spline coefficient per input feature, for every layer.

        Returns:
            dict mapping ``'layer_{i}'`` to a NumPy array of shape (in_features,).
        """
        # detach(): the coefficients require grad, and .numpy() refuses such tensors.
        return {
            f'layer_{i}': layer.spline.coeff.detach().abs().mean(dim=(0, 2)).cpu().numpy()
            for i, layer in enumerate(self.kan.layers)
        }
class ModularKAN(nn.Module):
    """
    Modular KAN: independent KAN layers keyed by name, for interpretable grouping.

    The sub-layers live in ``self.kan_modules``. The attribute cannot be
    called ``modules``: that name is the ``nn.Module.modules()`` method, which
    would shadow the dict and make item assignment fail.
    """

    def __init__(self, module_configs, grid_size=5, spline_order=3):
        """
        Args:
            module_configs: iterable of (in_features, out_features, name) tuples.
        """
        super().__init__()
        self.kan_modules = nn.ModuleDict()
        for in_f, out_f, name in module_configs:
            self.kan_modules[name] = KANLayer(in_f, out_f, grid_size, spline_order)

    def forward(self, x_dict):
        """
        Args:
            x_dict: dict mapping module name to that module's input tensor.

        Returns:
            dict mapping each module name present in ``x_dict`` to its output;
            modules without an input are skipped.
        """
        return {
            name: module(x_dict[name])
            for name, module in self.kan_modules.items()
            if name in x_dict
        }

    def visualize_activations(self):
        """Print the shape and value range of every module's spline coefficients."""
        for name, module in self.kan_modules.items():
            print(f"\nModule: {name}")
            print(f"  Coefficient shape: {module.spline.coeff.shape}")
            print(f"  Coefficient range: [{module.spline.coeff.min():.4f}, "
                  f"{module.spline.coeff.max():.4f}]")


# 3. 训练工具
3.1 数据集和标准化
class Standardizer:
    """
    Per-column feature scaler for KAN inputs.

    mode='minmax': rescale each column to [0, 1] (results are clamped), the
        input range the KAN layers in this file expect.
    mode='standard': z-score each column, (x - mean) / std. The result is
        NOT restricted to [0, 1].

    Statistics may be passed to the constructor or learned with ``fit``.
    """

    _MODES = ('minmax', 'standard')
    # Added to denominators to avoid division by zero on constant columns;
    # the inverse uses the same value so a round trip is exact.
    _EPS = 1e-8

    def __init__(self, x_mean=None, x_std=None,
                 x_min=None, x_max=None,
                 mode='minmax'):
        self.mode = mode
        self.x_mean = x_mean
        self.x_std = x_std
        self.x_min = x_min
        self.x_max = x_max

    def _check_mode(self):
        """Raise ValueError if ``mode`` is not supported."""
        if self.mode not in self._MODES:
            raise ValueError(
                f"Unknown mode {self.mode!r}; expected one of {self._MODES}"
            )

    def _check_fitted(self, *names):
        """Raise RuntimeError if any of the named statistics is still None."""
        missing = [name for name in names if getattr(self, name) is None]
        if missing:
            raise RuntimeError(
                f"Standardizer is not fitted (missing {missing}); call fit() first"
            )

    def fit(self, x):
        """Learn per-column statistics from ``x`` of shape (n_samples, n_features)."""
        self._check_mode()
        if self.mode == 'minmax':
            self.x_min = x.min(dim=0, keepdim=True)[0]
            self.x_max = x.max(dim=0, keepdim=True)[0]
        else:
            self.x_mean = x.mean(dim=0, keepdim=True)
            self.x_std = x.std(dim=0, keepdim=True)

    def transform(self, x):
        """Scale ``x`` with the fitted statistics."""
        self._check_mode()
        if self.mode == 'minmax':
            self._check_fitted('x_min', 'x_max')
            scaled = (x - self.x_min) / (self.x_max - self.x_min + self._EPS)
            return scaled.clamp(0, 1)
        self._check_fitted('x_mean', 'x_std')
        return (x - self.x_mean) / (self.x_std + self._EPS)

    def inverse_transform(self, x):
        """Undo ``transform``. Exact in minmax mode only for values that were not clamped."""
        self._check_mode()
        if self.mode == 'minmax':
            self._check_fitted('x_min', 'x_max')
            return x * (self.x_max - self.x_min + self._EPS) + self.x_min
        self._check_fitted('x_mean', 'x_std')
        return x * (self.x_std + self._EPS) + self.x_mean
def create_synthetic_dataset(func, n_samples=1000, noise=0.01,
                             input_range=(0, 1), seed=42, n_features=1):
    """
    Create a synthetic regression dataset for testing KAN.

    Args:
        func: target function mapping x of shape (n_samples, n_features)
            to a tensor of shape (n_samples, 1).
        n_samples: number of samples.
        noise: standard deviation of the additive Gaussian noise.
        input_range: (low, high) range of the uniformly sampled inputs.
        seed: seed of a private generator; the global torch RNG is not touched.
        n_features: number of input columns (default 1).

    Returns:
        (x, y): x of shape (n_samples, n_features); y = func(x) + noise term
        of shape (n_samples, 1).
    """
    # A local generator keeps results reproducible without reseeding the
    # global RNG that callers may rely on.
    generator = torch.Generator().manual_seed(seed)
    low, high = input_range
    x = torch.rand(n_samples, n_features, generator=generator) * (high - low) + low
    y = func(x) + torch.randn(n_samples, 1, generator=generator) * noise
    return x, y


# 3.2 训练器
class KANTrainer:
    """
    Minimal MSE training loop for KAN-style models.

    Args:
        model: nn.Module mapping x to a prediction. If it defines
            ``get_regularization_loss(lambda_l1)`` (as ``KAN`` does), that term
            is added to the training loss; otherwise no L1 term is used.
        optimizer: optimizer to use. Defaults to AdamW(lr=1e-3,
            weight_decay=lambda_l2); ``lambda_l2`` is ignored when an
            optimizer is supplied.
        scheduler: optional LR scheduler, stepped once per epoch.
        device: device the model and every batch are moved to.
        lambda_l1: weight passed to the model's regularisation hook.
        lambda_l2: AdamW weight decay (default optimizer only).
        max_grad_norm: gradient-norm clipping threshold; None disables clipping.
    """

    def __init__(self, model, optimizer=None, scheduler=None,
                 device='cpu', lambda_l1=0.0, lambda_l2=0.0, max_grad_norm=1.0):
        self.model = model
        self.device = device
        self.model.to(device)
        if optimizer is None:
            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=1e-3,
                weight_decay=lambda_l2
            )
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.lambda_l1 = lambda_l1
        self.max_grad_norm = max_grad_norm

    def _regularization_loss(self):
        """Return the model's L1 term, or 0.0 if the model has no such hook."""
        reg_fn = getattr(self.model, 'get_regularization_loss', None)
        if reg_fn is None:
            return 0.0
        return reg_fn(self.lambda_l1)

    def train_epoch(self, train_loader, verbose=True):
        """
        Train for one epoch.

        ``verbose`` is accepted for API compatibility and is unused.

        Returns:
            Mean per-batch MSE (without the regulariser); NaN if the loader
            yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        n_batches = 0
        for x, y in train_loader:
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            loss = F.mse_loss(self.model(x), y)
            (loss + self._regularization_loss()).backward()
            if self.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # Step once per epoch (the examples size T_max in epochs). Skip empty
        # epochs so the scheduler never runs ahead of optimizer.step().
        if self.scheduler is not None and n_batches:
            self.scheduler.step()
        return total_loss / n_batches if n_batches else float('nan')

    def evaluate(self, val_loader):
        """Return the mean per-batch validation MSE; NaN if the loader is empty."""
        self.model.eval()
        total_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(self.device), y.to(self.device)
                total_loss += F.mse_loss(self.model(x), y).item()
                n_batches += 1
        return total_loss / n_batches if n_batches else float('nan')

    def fit(self, train_loader, val_loader=None, epochs=200,
            early_stopping_patience=50, verbose=True):
        """
        Train for up to ``epochs`` epochs.

        With a validation loader, training stops early once the validation
        loss has not improved for ``early_stopping_patience`` epochs.

        Returns:
            dict with per-epoch 'train_loss' and 'val_loss' lists ('val_loss'
            stays empty without a validation loader).
        """
        best_val_loss = float('inf')
        patience_counter = 0
        history = {'train_loss': [], 'val_loss': []}
        for epoch in range(epochs):
            train_loss = self.train_epoch(train_loader, verbose=False)
            history['train_loss'].append(train_loss)

            if val_loader is None:
                if verbose and epoch % 10 == 0:
                    print(f"Epoch {epoch}: train_loss={train_loss:.6f}")
                continue

            val_loss = self.evaluate(val_loader)
            history['val_loss'].append(val_loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
            if verbose and epoch % 10 == 0:
                print(f"Epoch {epoch}: train_loss={train_loss:.6f}, "
                      f"val_loss={val_loss:.6f}")
            if patience_counter >= early_stopping_patience:
                if verbose:
                    print(f"Early stopping at epoch {epoch}")
                break
        return history


# 4. 使用示例
4.1 基本示例
def basic_example():
    """
    Basic KAN usage: fit f(x) = sin(pi * x) * exp(-x / 2) on [0, 1].

    Returns:
        (model, standardizer): the trained KAN and the fitted input scaler.
    """
    from torch.utils.data import TensorDataset, DataLoader

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def target_function(x):
        """Target: f(x) = sin(pi * x) * exp(-x / 2)."""
        return torch.sin(torch.pi * x) * torch.exp(-x / 2)

    x_train, y_train = create_synthetic_dataset(
        target_function, n_samples=1000, noise=0.01
    )
    x_val, y_val = create_synthetic_dataset(
        target_function, n_samples=200, noise=0.01, seed=123
    )

    # Scale inputs to [0, 1], the range the KAN splines are defined on.
    standardizer = Standardizer(mode='minmax')
    standardizer.fit(x_train)
    x_train = standardizer.transform(x_train)
    x_val = standardizer.transform(x_val)

    train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=64, shuffle=True)
    val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=64)

    model = KAN(
        layer_dims=[1, 8, 8, 1],  # input -> hidden -> hidden -> output
        grid_size=5,
        spline_order=3
    )
    print(f"模型参数数量: {model.get_num_parameters()}")

    epochs = 200
    # Create the optimizer first so the scheduler can wrap it; the trainer
    # cannot be referenced inside its own constructor call.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    trainer = KANTrainer(
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        lambda_l1=0.001
    )
    trainer.fit(train_loader, val_loader, epochs=epochs)

    # Evaluate on a fresh test set, on the same device as the model.
    model.eval()
    with torch.no_grad():
        x_test, y_test = create_synthetic_dataset(target_function, n_samples=500, seed=456)
        x_test = standardizer.transform(x_test).to(device)
        pred = model(x_test)
        test_mse = F.mse_loss(pred, y_test.to(device)).item()
    print(f"测试 MSE: {test_mse:.6f}")
    return model, standardizer
def multi_dimensional_example():
    """
    Fit the 2-D function f(x1, x2) = sin(x1) * cos(x2) with a [2, 16, 16, 1] KAN.

    Returns:
        (model, standardizer): the trained KAN and the fitted input scaler.
    """
    from torch.utils.data import TensorDataset, DataLoader

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def target_function_2d(x):
        """Target: f(x1, x2) = sin(x1) * cos(x2)."""
        return torch.sin(x[:, 0:1]) * torch.cos(x[:, 1:2])

    torch.manual_seed(42)
    x_train = torch.rand(1000, 2)
    y_train = target_function_2d(x_train) + torch.randn(1000, 1) * 0.01
    x_val = torch.rand(200, 2)
    y_val = target_function_2d(x_val) + torch.randn(200, 1) * 0.01

    standardizer = Standardizer(mode='minmax')
    standardizer.fit(x_train)
    x_train = standardizer.transform(x_train)
    x_val = standardizer.transform(x_val)

    train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=64, shuffle=True)
    # Use the validation split for monitoring and early stopping.
    val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=64)

    model = KAN(
        layer_dims=[2, 16, 16, 1],
        grid_size=5,
        spline_order=3
    )

    trainer = KANTrainer(model, device=device, lambda_l1=0.001)
    trainer.fit(train_loader, val_loader, epochs=300)
    return model, standardizer


# 4.2 可解释性示例
def interpretability_example():
    """
    Train the basic example and plot the learned spline functions of layer 0.

    For every output unit i (at most 8) and input j, the plotted curve is
    ``phi_ij(x) = sum_k coeff[i, j, k] * B_k(x)`` over x in [0, 1]. Only the
    spline branch is drawn; the layer's base-activation branch is not
    included. The figure is saved to ``kan_activations.png``.
    """
    import matplotlib.pyplot as plt

    model, standardizer = basic_example()

    # The model's L1 regularisation term (not a per-feature importance score).
    reg_loss = model.get_regularization_loss()
    print(f"正则化损失: {float(reg_loss):.6f}")

    layer = model.layers[0]
    spline = layer.spline
    coeff = spline.coeff.detach().cpu().numpy()
    print(f"\n第一层激活函数系数形状: {coeff.shape}")
    print(f"系数范围: [{coeff.min():.4f}, {coeff.max():.4f}]")

    # Evaluate every edge function on a dense grid of inputs.
    n_points = 100
    x = np.linspace(0, 1, n_points)
    with torch.no_grad():
        xs = torch.linspace(0, 1, n_points, device=spline.coeff.device)
        xs = xs.unsqueeze(1).expand(n_points, spline.in_features)
        bases = spline.de_boor(xs, spline.grid, spline.spline_order)
        # phi[b, o, i] = sum_k bases[b, i, k] * coeff[o, i, k]
        phi = torch.einsum('bik,oik->boi', bases.to(spline.coeff.dtype), spline.coeff)
    phi = phi.cpu().numpy()

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    for i in range(min(8, coeff.shape[0])):  # output units
        ax = axes[i // 4, i % 4]
        for j in range(coeff.shape[1]):  # input features
            ax.plot(x, phi[:, i, j], alpha=0.7, label=f'input {j}')
        ax.set_title(f'Output {i}')
        ax.legend()
    plt.tight_layout()
    plt.savefig('kan_activations.png')
    plt.close(fig)
    print("\n激活函数图已保存到 kan_activations.png")
def sparsity_example():
    """
    Demonstrate SparseKAN thresholding and report the resulting sparsity.

    Training is omitted. Spline coefficients are initialised to zero, so
    thresholding an untrained model would trivially report 100% sparsity.
    The coefficients are therefore first filled with small random values
    from a fixed-seed generator, as a stand-in for trained weights.

    Returns:
        float: fraction of spline coefficients that are exactly zero.
    """
    model = SparseKAN(
        layer_dims=[2, 8, 1],
        grid_size=5,
        spline_order=3,
        sparsity_threshold=0.01
    )

    # Training would go here (omitted); simulate trained coefficients instead.
    generator = torch.Generator().manual_seed(0)
    with torch.no_grad():
        for layer in model.kan.layers:
            coeff = layer.spline.coeff
            coeff.copy_(torch.randn(coeff.shape, generator=generator) * 0.02)

    model.apply_sparsity()

    total_params = 0
    zero_params = 0
    for layer in model.kan.layers:
        coeff = layer.spline.coeff
        total_params += coeff.numel()
        zero_params += (coeff == 0).sum().item()
    sparsity = zero_params / total_params
    print(f"稀疏性: {sparsity:.2%}")
    return sparsity


# 5. 效率优化技巧
5.1 计算优化
class OptimizedKAN:
    """
    Namespace for vectorised KAN evaluation helpers.
    """

    @staticmethod
    def vectorized_bspline(x, coeff, grid, k):
        """
        Vectorised spline evaluation by piecewise-linear interpolation.

        This is a simplified stand-in for a true B-spline. The ``n_bases``
        coefficients of each edge are treated as values at ``n_bases``
        equally spaced points on [0, 1] and linearly interpolated. ``grid``
        and ``k`` are accepted for API compatibility and are unused.

        Args:
            x: (batch, in_features) inputs; clamped to [0, 1].
            coeff: (out_features, in_features, n_bases) with n_bases >= 2.

        Returns:
            (batch, out_features): sum over inputs of the interpolated values.

        Raises:
            ValueError: if n_bases < 2 (no interval to interpolate over).
        """
        n_bases = coeff.shape[-1]
        if n_bases < 2:
            raise ValueError(f"need at least 2 coefficients per edge, got {n_bases}")
        in_features = x.shape[1]

        # Position of each input on the coefficient index axis.
        pos = x.clamp(0, 1) * (n_bases - 1)
        idx = pos.long().clamp(0, n_bases - 2)            # (batch, in)
        t = (pos - idx.to(pos.dtype)).unsqueeze(1)        # (batch, 1, in)

        # Pair each input column with its own index: coeff[o, i, idx[b, i]].
        # The two adjacent advanced indices broadcast to (batch, in), giving
        # a result of shape (out, batch, in).
        input_idx = torch.arange(in_features, device=x.device)
        low = coeff[:, input_idx, idx].permute(1, 0, 2)       # (batch, out, in)
        high = coeff[:, input_idx, idx + 1].permute(1, 0, 2)

        return ((1 - t) * low + t * high).sum(dim=-1)         # (batch, out)
class BatchedKANInference:
    """
    Run a model over a large input in fixed-size chunks without tracking gradients.
    """

    def __init__(self, model, batch_size=32):
        self.model = model
        self.batch_size = batch_size

    def predict(self, x):
        """
        Predict on ``x`` in chunks of ``batch_size`` rows and concatenate the results.

        The model is in eval mode during the call, and its previous train/eval
        mode is restored afterwards. An empty ``x`` is passed to the model once,
        so the result keeps the model's output shape instead of failing in
        ``torch.cat``.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                if len(x) == 0:
                    return self.model(x)
                outputs = [
                    self.model(x[start:start + self.batch_size])
                    for start in range(0, len(x), self.batch_size)
                ]
                return torch.cat(outputs, dim=0)
        finally:
            self.model.train(was_training)


# 5.2 内存优化
class MemoryEfficientKAN(KAN):
    """
    KAN that recomputes some layers during backward (gradient checkpointing).

    Every ``checkpoint_every``-th layer, skipping layer 0, is wrapped in
    ``torch.utils.checkpoint.checkpoint`` while training with gradients
    enabled. This trades extra compute for lower activation memory.
    ``checkpoint_every <= 0`` disables checkpointing.
    """

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 checkpoint_every=2):
        super().__init__(layer_dims, grid_size, spline_order)
        self.checkpoint_every = checkpoint_every

    def _use_checkpoint(self, i):
        """Return True if layer ``i`` should be checkpointed in this call."""
        return (
            self.checkpoint_every > 0
            and i > 0
            and i % self.checkpoint_every == 0
            # Checkpointing only saves memory when a backward pass follows.
            and self.training
            and torch.is_grad_enabled()
        )

    def forward(self, x):
        """
        Args:
            x: (..., in_features) with any number of leading dimensions.

        Returns:
            (..., out_features) with the same leading dimensions.
        """
        # `import torch` alone may not load the checkpoint submodule.
        from torch.utils.checkpoint import checkpoint

        leading_shape = x.shape[:-1]
        x = x.reshape(-1, x.shape[-1])
        for i, layer in enumerate(self.layers):
            if self._use_checkpoint(i):
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
            if i < self.num_layers - 1:
                x = self.activation(x)
        return x.reshape(*leading_shape, x.shape[-1])


# 6. 完整训练脚本
#!/usr/bin/env python3
"""
KAN 完整训练脚本
Usage:
python kan_training_script.py --task regression --epochs 200
"""
import argparse
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import json
from pathlib import Path
def parse_args():
    """
    Build the training script's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with task, epochs, lr, grid_size, spline_order,
        batch_size, device and output_dir.
    """
    cli = argparse.ArgumentParser(description='KAN Training Script')
    cli.add_argument('--task', type=str, default='regression',
                     choices=['regression', 'classification'])
    # (flag, type, default) for the remaining plain options, in help order.
    simple_options = (
        ('--epochs', int, 200),
        ('--lr', float, 1e-3),
        ('--grid_size', int, 5),
        ('--spline_order', int, 3),
        ('--batch_size', int, 64),
        ('--device', str, 'auto'),
        ('--output_dir', str, './output'),
    )
    for flag, value_type, default in simple_options:
        cli.add_argument(flag, type=value_type, default=default)
    namespace = cli.parse_args()
    return namespace
def main():
    """
    Train a [1, 8, 8, 1] KAN on a synthetic 1-D regression task.

    Writes ``results.json`` (config and loss history) and ``model.pt``
    (state dict, fitted Standardizer, config) to ``--output_dir``.
    Only ``--task regression`` is implemented.
    """
    args = parse_args()
    if args.task != 'regression':
        # Fail loudly instead of silently training a regression model.
        raise SystemExit(
            f"--task {args.task!r} is not implemented; only 'regression' is supported."
        )

    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    print(f"Using device: {device}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Synthetic data: f(x) = sin(pi x) * exp(-x / 2) plus Gaussian noise.
    torch.manual_seed(42)

    def target(x):
        return torch.sin(torch.pi * x) * torch.exp(-x / 2)

    x_train = torch.rand(1000, 1)
    y_train = target(x_train) + torch.randn(1000, 1) * 0.01
    x_val = torch.rand(200, 1)
    y_val = target(x_val) + torch.randn(200, 1) * 0.01

    standardizer = Standardizer(mode='minmax')
    standardizer.fit(x_train)
    x_train = standardizer.transform(x_train)
    x_val = standardizer.transform(x_val)

    train_loader = DataLoader(
        TensorDataset(x_train, y_train),
        batch_size=args.batch_size,
        shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(x_val, y_val),
        batch_size=args.batch_size
    )

    model = KAN(
        layer_dims=[1, 8, 8, 1],
        grid_size=args.grid_size,
        spline_order=args.spline_order
    )
    print(f"模型参数数量: {model.get_num_parameters()}")

    # Build optimizer and scheduler first; the trainer cannot be referenced
    # inside its own constructor call.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # Pass `device` explicitly: the trainer moves the model to it (its
    # default would be 'cpu').
    trainer = KANTrainer(
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        lambda_l1=0.001
    )
    history = trainer.fit(train_loader, val_loader, epochs=args.epochs)

    # Loss summaries are None when no epoch ran (e.g. --epochs 0).
    train_losses, val_losses = history['train_loss'], history['val_loss']
    results = {
        'config': vars(args),
        'final_train_loss': train_losses[-1] if train_losses else None,
        'final_val_loss': val_losses[-1] if val_losses else None,
        'best_val_loss': min(val_losses) if val_losses else None,
        'history': history
    }
    with open(output_dir / 'results.json', 'w') as f:
        json.dump(results, f, indent=2)

    # NOTE: the Standardizer is pickled as an object, so loading this file
    # may require torch.load(..., weights_only=False) on newer PyTorch.
    torch.save({
        'model_state_dict': model.state_dict(),
        'standardizer': standardizer,
        'config': vars(args)
    }, output_dir / 'model.pt')

    print(f"\n结果已保存到 {output_dir}")
    if results['final_val_loss'] is not None:
        print(f"最终验证损失: {results['final_val_loss']:.6f}")
# Script entry point: run training only when executed directly.
if __name__ == '__main__':
    main()

# 7. 总结
本文提供了 KAN 的完整 PyTorch 实现,包括:
- 基础组件:B-样条激活函数、KAN Layer
- 完整模型:基础 KAN、带跳跃连接的 KAN(ResidualKAN)、深度 KAN(DeepKAN)、稀疏 KAN(SparseKAN)和模块化 KAN(ModularKAN)
- 训练工具:标准化器、训练器
- 使用示例:基本回归、多维函数、可解释性
- 优化技巧:计算优化、内存优化
这些实现可以作为进一步研究和应用的基础。
参考
- Liu, Z., et al. (2024). “KAN: Kolmogorov-Arnold Networks”. arXiv:2404.19756.
相关阅读
- 现代MLP架构 — KAN基础介绍
- KAN 2.0科学发现 — KAN 2.0深度解析
- KAN变体综述 — 各种KAN变体对比
- KAN批判性评估 — 理论与实践的差距
- KAN理论分析 — 表达能力、训练动力学