概述

本文提供 Kolmogorov-Arnold Networks(KAN)的完整 PyTorch 实现,包括基础 KAN Layer、完整 KAN 模型、训练工具和最佳实践指南。代码经过模块化设计,便于理解和扩展。


1. 基础组件实现

1.1 B-样条激活函数

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
 
 
class BSplineActivation(nn.Module):
    """Learnable B-spline activation used on the edges of a KAN layer.

    For every (output, input) pair the module stores
    ``grid_size + spline_order`` coefficients of a B-spline defined on a
    uniform knot vector. The knots cover ``grid_range`` and are extended by
    ``spline_order`` extra knots on each side so that the basis is complete
    on the whole range.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        grid_size: Number of knot intervals inside ``grid_range``.
        spline_order: Polynomial degree ``k`` of the spline.
    """

    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Interval on which the spline is evaluated; inputs are clamped to it.
        self.grid_range = (0, 1)

        # Uniform knot vector padded with `spline_order` knots on both sides.
        h = (self.grid_range[1] - self.grid_range[0]) / grid_size
        self.register_buffer(
            'grid',
            torch.linspace(
                self.grid_range[0] - h * spline_order,
                self.grid_range[1] + h * spline_order,
                grid_size + 2 * spline_order + 1
            )
        )

        # Learnable spline coefficients, one set per (output, input) pair.
        self.coeff = nn.Parameter(
            torch.randn(out_features, in_features, grid_size + spline_order)
        )

        self._init_coeff()

    def _init_coeff(self):
        """Reset all spline coefficients to zero."""
        nn.init.zeros_(self.coeff)

    def de_boor(self, x, grid, k):
        """Evaluate all degree-``k`` B-spline basis functions at ``x``.

        Uses the Cox-de Boor recursion, vectorised over the knot axis. The
        recursion starts from all ``len(grid) - 1`` degree-0 indicator
        functions; each step removes one basis function, so exactly
        ``len(grid) - k - 1`` remain, matching the last axis of ``coeff``.
        No in-place writes are used, so gradients can flow back to ``x``.

        Args:
            x: (batch, in_features) evaluation points.
            grid: 1-D tensor of non-decreasing knots.
            k: Spline degree.

        Returns:
            bases: (batch, in_features, len(grid) - k - 1) basis values.
        """
        # Keep the evaluation points inside the supported range and in the
        # same dtype as the knots.
        x = x.to(grid.dtype).clamp(self.grid_range[0], self.grid_range[1])
        x = x.unsqueeze(-1)  # (batch, in, 1) so it broadcasts against knots

        # Degree 0: indicator of the half-open interval [t_i, t_{i+1}).
        inside = (x >= grid[:-1]) & (x < grid[1:])
        # Close the very last interval on the right so x == grid[-1] is not
        # dropped (only reachable when k == 0, since x is clamped).
        inside[..., -1] |= (x.squeeze(-1) == grid[-1])
        bases = inside.to(x.dtype)

        # Raise the degree one step at a time.
        for order in range(1, k + 1):
            # The epsilon only guards against repeated knots (0/0 -> 0).
            left_w = (x - grid[:-(order + 1)]) / (
                grid[order:-1] - grid[:-(order + 1)] + 1e-8
            )
            right_w = (grid[order + 1:] - x) / (
                grid[order + 1:] - grid[1:-order] + 1e-8
            )
            bases = left_w * bases[..., :-1] + right_w * bases[..., 1:]

        return bases

    def forward(self, x):
        """Apply the spline activations and sum over the input features.

        Args:
            x: (batch, in_features) input; values outside ``grid_range`` are
                clamped.

        Returns:
            y: (batch, out_features)
        """
        bases = self.de_boor(x, self.grid, self.spline_order)

        # bases: (batch, in, n_bases), coeff: (out, in, n_bases)
        # Contract over both the input and the basis axes -> (batch, out).
        return torch.einsum('bik,oik->bo', bases, self.coeff)
 
 
class EfficientBSplineActivation(nn.Module):
    """Vectorised B-spline activation.

    Evaluates the full degree-``spline_order`` B-spline basis with tensor
    operations over the knot axis (no Python loop over basis functions) and
    contracts it with the learnable coefficients.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        grid_size: Number of knot intervals inside [0, 1].
        spline_order: Polynomial degree of the spline.
    """

    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knots on [0, 1], padded by `spline_order` knots per side.
        h = 1.0 / grid_size
        self.register_buffer(
            'grid',
            torch.linspace(-h * spline_order, 1 + h * spline_order,
                          grid_size + 2 * spline_order + 1)
        )

        # Learnable coefficients, one set per (output, input) pair.
        self.coeff = nn.Parameter(
            torch.randn(out_features, in_features, grid_size + spline_order)
        )

        self._init_coeff()

    def _init_coeff(self):
        """Reset all spline coefficients to zero."""
        nn.init.zeros_(self.coeff)

    def _bases(self, x):
        """Return the (batch, in, grid_size + spline_order) basis values."""
        grid = self.grid
        x = x.unsqueeze(-1)  # (batch, in, 1)

        # Degree 0: indicator of [t_i, t_{i+1}); close the last interval.
        inside = (x >= grid[:-1]) & (x < grid[1:])
        inside[..., -1] |= (x.squeeze(-1) == grid[-1])
        bases = inside.to(x.dtype)

        # Cox-de Boor recursion; each step drops one basis function.
        for order in range(1, self.spline_order + 1):
            left_w = (x - grid[:-(order + 1)]) / (
                grid[order:-1] - grid[:-(order + 1)] + 1e-8
            )
            right_w = (grid[order + 1:] - x) / (
                grid[order + 1:] - grid[1:-order] + 1e-8
            )
            bases = left_w * bases[..., :-1] + right_w * bases[..., 1:]

        return bases

    def forward(self, x):
        """Apply the spline activations and sum over the input features.

        Args:
            x: (batch, in_features) input; values are clamped to [0, 1].

        Returns:
            (batch, out_features) tensor.
        """
        # Match the knot dtype and keep the points inside the knot range.
        x = x.to(self.grid.dtype).clamp(0, 1)

        bases = self._bases(x)

        # Contract over the input and basis axes -> (batch, out).
        return torch.einsum('bik,oik->bo', bases, self.coeff)

1.2 KAN Layer

class KANLayer(nn.Module):
    """Single KAN layer: learnable spline branch plus an optional base branch.

    The output is ``spline(x)`` and, when the base branch is enabled,
    additionally ``activation(x) @ base_weight.T``.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        grid_size: Grid size forwarded to the spline activation.
        spline_order: Spline degree forwarded to the spline activation.
        base_activation: One of 'silu', 'gelu', 'relu', 'tanh'; any other
            value falls back to SiLU.
        use_base_activation: Whether to add the base branch.
    """

    def __init__(self, in_features, out_features,
                 grid_size=5, spline_order=3,
                 base_activation='silu',
                 use_base_activation=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Learnable spline branch.
        self.spline = BSplineActivation(
            in_features, out_features,
            grid_size, spline_order
        )

        # Weight of the base branch; registered as None when it is disabled
        # so the attribute always exists.
        if not use_base_activation:
            self.register_parameter('base_weight', None)
        else:
            self.base_weight = nn.Parameter(
                torch.randn(out_features, in_features)
            )

        self.use_base_activation = use_base_activation

        # Name -> function lookup; unknown names fall back to SiLU.
        known_activations = {
            'silu': nn.functional.silu,
            'gelu': nn.functional.gelu,
            'relu': nn.functional.relu,
            'tanh': torch.tanh,
        }
        self.activation_fn = known_activations.get(
            base_activation, nn.functional.silu
        )

        self._init_weights()

    def _init_weights(self):
        """Draw the base weight from N(0, 0.1^2) when the base branch exists."""
        if self.base_weight is None:
            return
        nn.init.normal_(self.base_weight, std=0.1)

    def forward(self, x):
        """Run the layer.

        Args:
            x: (batch, in_features) input; clamped to [0, 1] before use.

        Returns:
            (batch, out_features) output.
        """
        clipped = x.clamp(0, 1)

        result = self.spline(clipped)

        # Without a base branch the spline output is the layer output.
        if not self.use_base_activation or self.base_weight is None:
            return result

        base_term = torch.einsum('bi,oi->bo',
                                 self.activation_fn(clipped),
                                 self.base_weight)
        return result + base_term
 
 
class KANLinear(nn.Module):
    """Thin wrapper exposing a ``KANLayer`` under a linear-layer-like name.

    Intended for changing the feature dimension; all work is delegated to
    the wrapped ``KANLayer``.
    """

    def __init__(self, in_features, out_features,
                 grid_size=5, spline_order=3):
        super().__init__()

        wrapped = KANLayer(in_features, out_features, grid_size, spline_order)
        self.kan = wrapped

    def forward(self, x):
        """Forward ``x`` through the wrapped KAN layer."""
        out = self.kan(x)
        return out

2. 完整 KAN 模型

2.1 基础 KAN

class KAN(nn.Module):
    """Complete Kolmogorov-Arnold Network: a stack of ``KANLayer`` modules.

    Args:
        layer_dims: Width of every layer, e.g. ``[2, 3, 5, 1]``.
        grid_size: B-spline grid size used by every layer.
        spline_order: B-spline degree used by every layer.
        base_activation: Name of the base activation ('silu', 'gelu',
            'relu' or 'tanh'). The same function is used between layers;
            unknown names fall back to SiLU, as in ``KANLayer``.
        use_base_activation: Whether the layers add their base branch.
    """

    # Module counterparts of the functional activations used by KANLayer, so
    # the inter-layer activation always agrees with ``base_activation``.
    _INTER_LAYER_ACTIVATIONS = {
        'silu': nn.SiLU,
        'gelu': nn.GELU,
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
    }

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 base_activation='silu', use_base_activation=True):
        super().__init__()
        self.layer_dims = layer_dims
        self.num_layers = len(layer_dims) - 1

        # One KANLayer per consecutive pair of widths.
        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                KANLayer(
                    in_features=layer_dims[i],
                    out_features=layer_dims[i + 1],
                    grid_size=grid_size,
                    spline_order=spline_order,
                    base_activation=base_activation,
                    use_base_activation=use_base_activation
                )
            )

        # Activation applied between layers (not after the last one).
        activation_cls = self._INTER_LAYER_ACTIVATIONS.get(
            base_activation, nn.SiLU
        )
        self.activation = activation_cls()

    def forward(self, x):
        """Run the network.

        Args:
            x: (..., in_features) tensor. Any number of leading dimensions
                is accepted, e.g. (batch, in_features) or
                (batch, seq_len, in_features).

        Returns:
            (..., out_features) tensor with the same leading dimensions.
        """
        # The layers expect 2-D input, so fold all leading dims into one.
        lead_shape = tuple(x.shape[:-1])
        needs_reshape = x.dim() != 2
        if needs_reshape:
            x = x.reshape(-1, x.shape[-1])

        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < self.num_layers - 1:
                x = self.activation(x)

        # Restore the caller's leading dimensions.
        if needs_reshape:
            x = x.reshape(*lead_shape, -1)

        return x

    def get_regularization_loss(self, lambda_l1=0.01):
        """Return the L1 penalty on the spline coefficients.

        The penalty is ``lambda_l1`` times the mean absolute coefficient,
        summed over all layers; it encourages sparse activations.
        """
        loss = 0
        for layer in self.layers:
            loss += lambda_l1 * torch.abs(layer.spline.coeff).mean()
        return loss

    def get_num_parameters(self):
        """Return the total number of parameter elements in the model."""
        return sum(p.numel() for p in self.parameters())

2.2 带跳跃连接的 KAN

class ResidualKAN(nn.Module):
    """KAN with a gated skip connection around the whole network.

    The output is ``gate * kan(x) + (1 - gate) * skip(x)`` where ``gate`` is
    a learnable scalar initialised to 0.5 and ``skip`` is a linear
    projection when input and output widths differ, otherwise the identity.
    """

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 base_activation='silu', dropout=0.0):
        super().__init__()

        self.kan = KAN(
            layer_dims, grid_size, spline_order,
            base_activation, use_base_activation=True
        )

        # Dropout on the KAN branch; None when the rate is not positive.
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        # Project the input only when its width differs from the output's.
        in_width, out_width = layer_dims[0], layer_dims[-1]
        if in_width == out_width:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Linear(in_width, out_width)

        # Learnable mixing weight between the KAN branch and the shortcut.
        self.gate = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        """Return the gated mix of the KAN branch and the shortcut."""
        shortcut = self.skip(x)

        branch = self.kan(x)
        if self.dropout is not None:
            branch = self.dropout(branch)

        return self.gate * branch + (1 - self.gate) * shortcut
 
 
class DeepKAN(nn.Module):
    """Deep KAN in which every layer is followed by LayerNorm.

    Each stage applies ``KANLayer -> LayerNorm -> SiLU -> Dropout``; this
    includes the final stage.
    """

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 base_activation='silu', dropout=0.1):
        super().__init__()

        kan_layers = []
        layer_norms = []

        # Build one (KANLayer, LayerNorm) pair per consecutive width pair.
        for width_in, width_out in zip(layer_dims[:-1], layer_dims[1:]):
            kan_layers.append(
                KANLayer(
                    width_in, width_out,
                    grid_size, spline_order, base_activation
                )
            )
            layer_norms.append(nn.LayerNorm(width_out))

        self.layers = nn.ModuleList(kan_layers)
        self.norms = nn.ModuleList(layer_norms)
        self.activation = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Pass ``x`` through every normalised stage in order."""
        for stage, norm in zip(self.layers, self.norms):
            x = self.dropout(self.activation(norm(stage(x))))

        return x

2.3 可解释性增强的 KAN

class SparseKAN(nn.Module):
    """Sparse KAN aimed at interpretability.

    Wraps a ``KAN`` and adds hard thresholding of the spline coefficients
    plus a per-input importance score.

    Args:
        layer_dims: Layer widths forwarded to ``KAN``.
        grid_size: Grid size forwarded to ``KAN``.
        spline_order: Spline degree forwarded to ``KAN``.
        sparsity_threshold: Coefficients whose absolute value is below this
            threshold are zeroed by ``apply_sparsity``.
    """

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 sparsity_threshold=0.01):
        super().__init__()

        self.kan = KAN(layer_dims, grid_size, spline_order)
        self.sparsity_threshold = sparsity_threshold

    def forward(self, x):
        """Forward ``x`` through the wrapped KAN."""
        return self.kan(x)

    def apply_sparsity(self):
        """Zero every spline coefficient below the threshold, in place."""
        # no_grad: this is a direct edit of the parameters, not a training op.
        with torch.no_grad():
            for layer in self.kan.layers:
                coeff = layer.spline.coeff
                coeff[torch.abs(coeff) < self.sparsity_threshold] = 0

    def get_activation_importance(self):
        """Return per-input importance scores for every layer.

        The score of an input feature is the mean absolute spline
        coefficient over all outputs and basis functions.

        Returns:
            dict mapping ``'layer_<i>'`` to a NumPy array of shape
            (in_features,).
        """
        importance = {}
        for i, layer in enumerate(self.kan.layers):
            # detach(): the coefficients require grad, and NumPy conversion
            # of a tensor that requires grad raises a RuntimeError.
            coeff = layer.spline.coeff.detach()
            coeff_abs_mean = torch.abs(coeff).mean(dim=(0, 2))
            importance[f'layer_{i}'] = coeff_abs_mean.cpu().numpy()
        return importance
 
 
class ModularKAN(nn.Module):
    """Collection of independent, named KAN layers.

    Each named sub-module processes its own input, which keeps groups of
    features separate for interpretability.
    """

    def __init__(self, module_configs, grid_size=5, spline_order=3):
        """
        Args:
            module_configs: Iterable of ``(in_features, out_features, name)``
                tuples, one per sub-module.
            grid_size: Grid size used by every sub-module.
            spline_order: Spline degree used by every sub-module.
        """
        super().__init__()

        # NOTE: the container must not be called ``modules``. That name is
        # the ``nn.Module.modules()`` method, so ``self.modules`` would
        # resolve to the bound method instead of the ModuleDict.
        self.kan_modules = nn.ModuleDict()
        for in_f, out_f, name in module_configs:
            self.kan_modules[name] = KANLayer(in_f, out_f, grid_size, spline_order)

    def forward(self, x_dict):
        """Run every sub-module that has an input.

        Args:
            x_dict: Mapping from module name to its input tensor. Modules
                without an entry are skipped.

        Returns:
            dict mapping module name to that module's output.
        """
        outputs = {}
        for name, module in self.kan_modules.items():
            if name in x_dict:
                outputs[name] = module(x_dict[name])
        return outputs

    def visualize_activations(self):
        """Print the shape and value range of every module's coefficients."""
        for name, module in self.kan_modules.items():
            print(f"\nModule: {name}")
            print(f"  Coefficient shape: {module.spline.coeff.shape}")
            print(f"  Coefficient range: [{module.spline.coeff.min():.4f}, "
                  f"{module.spline.coeff.max():.4f}]")

3. 训练工具

3.1 数据集和标准化

class Standardizer:
    """Feature-wise data scaler.

    Two modes are supported:

    * ``'minmax'``: scale each feature to [0, 1] (the range the KAN layers
      in this file clamp their inputs to).
    * ``'standard'``: subtract the mean and divide by the standard deviation.

    Args:
        x_mean, x_std: Optional precomputed statistics for 'standard' mode.
        x_min, x_max: Optional precomputed statistics for 'minmax' mode.
        mode: ``'minmax'`` or ``'standard'``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """

    _MODES = ('minmax', 'standard')
    # Added to every denominator so constant features do not divide by zero.
    _EPS = 1e-8

    def __init__(self, x_mean=None, x_std=None,
                 x_min=None, x_max=None,
                 mode='minmax'):
        if mode not in self._MODES:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {self._MODES}"
            )
        self.mode = mode
        self.x_mean = x_mean
        self.x_std = x_std
        self.x_min = x_min
        self.x_max = x_max

    def _check_mode(self):
        """Raise ValueError if ``mode`` was changed to an unsupported value."""
        if self.mode not in self._MODES:
            raise ValueError(
                f"Unknown mode {self.mode!r}; expected one of {self._MODES}"
            )

    def _require(self, *names):
        """Raise RuntimeError if any of the named statistics is missing."""
        missing = [name for name in names if getattr(self, name) is None]
        if missing:
            raise RuntimeError(
                "Standardizer is not fitted: missing "
                + ", ".join(missing)
                + ". Call fit() first or pass the statistics to __init__."
            )

    def fit(self, x):
        """Compute the scaling statistics along dim 0 of ``x``."""
        self._check_mode()
        if self.mode == 'minmax':
            self.x_min = x.min(dim=0, keepdim=True)[0]
            self.x_max = x.max(dim=0, keepdim=True)[0]
        else:
            self.x_mean = x.mean(dim=0, keepdim=True)
            self.x_std = x.std(dim=0, keepdim=True)

    def transform(self, x):
        """Scale ``x`` with the fitted statistics.

        In 'minmax' mode the result is additionally clamped to [0, 1], so
        values outside the fitted range saturate.
        """
        self._check_mode()
        if self.mode == 'minmax':
            self._require('x_min', 'x_max')
            scaled = (x - self.x_min) / (self.x_max - self.x_min + self._EPS)
            return scaled.clamp(0, 1)
        self._require('x_mean', 'x_std')
        return (x - self.x_mean) / (self.x_std + self._EPS)

    def inverse_transform(self, x):
        """Map scaled values back to the original units.

        Uses the same epsilon-padded scale as ``transform`` so the two are
        exact inverses (up to the clamping done in 'minmax' mode).
        """
        self._check_mode()
        if self.mode == 'minmax':
            self._require('x_min', 'x_max')
            return x * (self.x_max - self.x_min + self._EPS) + self.x_min
        self._require('x_mean', 'x_std')
        return x * (self.x_std + self._EPS) + self.x_mean
 
 
def create_synthetic_dataset(func, n_samples=1000, noise=0.01,
                           input_range=(0, 1), seed=42, n_features=1):
    """Create a noisy synthetic regression dataset for testing a KAN.

    Inputs are drawn uniformly from ``input_range`` and targets are
    ``func(x)`` plus Gaussian noise.

    Note: this seeds the *global* torch RNG with ``seed``, so random
    operations executed afterwards are also made deterministic.

    Args:
        func: Target function mapping an (n_samples, n_features) tensor to
            the target tensor.
        n_samples: Number of samples.
        noise: Standard deviation of the additive Gaussian noise.
        input_range: ``(low, high)`` bounds of the uniform inputs.
        seed: Seed for the global torch RNG.
        n_features: Number of input features (default 1).

    Returns:
        Tuple ``(x, y)`` of input and noisy target tensors.
    """
    torch.manual_seed(seed)

    low, high = input_range[0], input_range[1]
    x = torch.rand(n_samples, n_features) * (high - low) + low

    # Draw the noise in the target's own shape so a target of shape
    # (n_samples,) is not silently broadcast to (n_samples, n_samples).
    target = func(x)
    y = target + torch.randn(target.shape) * noise

    return x, y

3.2 训练器

class KANTrainer:
    """Minimal training loop for KAN-style regression models (MSE loss).

    Args:
        model: The ``nn.Module`` to train; it is moved to ``device``.
        optimizer: Optional optimizer. When omitted, AdamW with ``lr=1e-3``
            and ``weight_decay=lambda_l2`` is created.
        scheduler: Optional LR scheduler, stepped once per epoch.
        device: Device on which model and batches are placed.
        lambda_l1: Weight passed to ``model.get_regularization_loss`` when
            the model provides that method.
        lambda_l2: Weight decay of the default optimizer; ignored when an
            optimizer is supplied.
    """

    def __init__(self, model, optimizer=None, scheduler=None,
                 device='cpu', lambda_l1=0.0, lambda_l2=0.0):
        self.model = model
        self.device = device
        self.model.to(device)

        if optimizer is None:
            self.optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=1e-3,
                weight_decay=lambda_l2
            )
        else:
            self.optimizer = optimizer

        self.scheduler = scheduler

        self.lambda_l1 = lambda_l1

    def _regularization_loss(self):
        """Return the model's L1 penalty, or None when it does not apply.

        Not every model trained with this class defines
        ``get_regularization_loss``, so the penalty is optional.
        """
        reg_fn = getattr(self.model, 'get_regularization_loss', None)
        if reg_fn is None or not self.lambda_l1:
            return None
        return reg_fn(self.lambda_l1)

    @staticmethod
    def _mean_loss(total_loss, n_batches, what):
        """Average the accumulated loss, rejecting an empty loader."""
        if n_batches == 0:
            raise ValueError(f"{what} yielded no batches")
        return total_loss / n_batches

    def train_epoch(self, train_loader, verbose=True):
        """Train for one epoch and return the mean MSE over its batches.

        Args:
            train_loader: Iterable of ``(x, y)`` batches.
            verbose: Currently unused; kept for interface compatibility.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        self.model.train()
        total_loss = 0
        n_batches = 0

        for x, y in train_loader:
            x, y = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            pred = self.model(x)
            loss = F.mse_loss(pred, y)

            # Optimise MSE plus the optional sparsity penalty, but report
            # only the MSE so train and validation losses are comparable.
            reg_loss = self._regularization_loss()
            objective = loss if reg_loss is None else loss + reg_loss

            objective.backward()

            # Clip the global gradient norm to 1.0.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        mean_loss = self._mean_loss(total_loss, n_batches, 'train_loader')

        if self.scheduler is not None:
            self.scheduler.step()

        return mean_loss

    def evaluate(self, val_loader):
        """Return the mean MSE over ``val_loader`` without updating weights.

        Raises:
            ValueError: If ``val_loader`` yields no batches.
        """
        self.model.eval()
        total_loss = 0
        n_batches = 0

        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                total_loss += F.mse_loss(pred, y).item()
                n_batches += 1

        return self._mean_loss(total_loss, n_batches, 'val_loader')

    def fit(self, train_loader, val_loader=None, epochs=200,
            early_stopping_patience=50, verbose=True):
        """Run the full training loop.

        Args:
            train_loader: Iterable of training batches.
            val_loader: Optional iterable of validation batches; enables
                early stopping.
            epochs: Maximum number of epochs.
            early_stopping_patience: Stop after this many consecutive epochs
                without a new best validation loss.
            verbose: Print progress every 10 epochs.

        Returns:
            dict with the lists ``'train_loss'`` and ``'val_loss'``.
        """
        best_val_loss = float('inf')
        patience_counter = 0
        history = {'train_loss': [], 'val_loss': []}

        for epoch in range(epochs):
            train_loss = self.train_epoch(train_loader, verbose=False)
            history['train_loss'].append(train_loss)

            if val_loader is None:
                if verbose and epoch % 10 == 0:
                    print(f"Epoch {epoch}: train_loss={train_loss:.6f}")
                continue

            val_loss = self.evaluate(val_loader)
            history['val_loss'].append(val_loss)

            # Track the best validation loss for early stopping.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if verbose and epoch % 10 == 0:
                print(f"Epoch {epoch}: train_loss={train_loss:.6f}, "
                      f"val_loss={val_loss:.6f}")

            if patience_counter >= early_stopping_patience:
                print(f"Early stopping at epoch {epoch}")
                break

        return history

4. 使用示例

4.1 基本示例

def basic_example():
    """Train a small KAN on a noisy 1-D function and report the test MSE.

    Returns:
        Tuple ``(model, standardizer)`` with the trained model and the
        input scaler fitted on the training inputs.
    """
    import torch
    from torch.utils.data import TensorDataset, DataLoader

    # Use the GPU when one is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def target_function(x):
        """Target function: f(x) = sin(pi * x) * exp(-x / 2)."""
        return torch.sin(torch.pi * x) * torch.exp(-x / 2)

    x_train, y_train = create_synthetic_dataset(
        target_function, n_samples=1000, noise=0.01
    )
    x_val, y_val = create_synthetic_dataset(
        target_function, n_samples=200, noise=0.01, seed=123
    )

    # Fit the scaler on the training inputs only.
    standardizer = Standardizer(mode='minmax')
    standardizer.fit(x_train)
    x_train = standardizer.transform(x_train)
    x_val = standardizer.transform(x_val)

    train_dataset = TensorDataset(x_train, y_train)
    val_dataset = TensorDataset(x_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64)

    model = KAN(
        layer_dims=[1, 8, 8, 1],  # input -> hidden -> hidden -> output
        grid_size=5,
        spline_order=3
    ).to(device)
    print(f"模型参数数量: {model.get_num_parameters()}")

    # The scheduler needs the optimizer, so both are built before the
    # trainer instead of referring to the trainer inside its own
    # constructor call.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    trainer = KANTrainer(
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        lambda_l1=0.001
    )

    history = trainer.fit(train_loader, val_loader, epochs=200)

    # Final evaluation on a fresh test set, on the model's device.
    model.eval()
    with torch.no_grad():
        x_test, y_test = create_synthetic_dataset(target_function, n_samples=500, seed=456)
        x_test = standardizer.transform(x_test).to(device)
        y_test = y_test.to(device)
        pred = model(x_test)
        test_mse = F.mse_loss(pred, y_test).item()
        print(f"测试 MSE: {test_mse:.6f}")

    return model, standardizer
 
 
def multi_dimensional_example():
    """Train a KAN on a noisy 2-D function.

    Returns:
        Tuple ``(model, standardizer)`` with the trained model and the
        input scaler fitted on the training inputs.
    """
    import torch
    from torch.utils.data import TensorDataset, DataLoader

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def target_function_2d(x):
        """2-D target function: f(x1, x2) = sin(x1) * cos(x2)."""
        return torch.sin(x[:, 0:1]) * torch.cos(x[:, 1:2])

    torch.manual_seed(42)
    x_train = torch.rand(1000, 2)
    y_train = target_function_2d(x_train) + torch.randn(1000, 1) * 0.01

    x_val = torch.rand(200, 2)
    y_val = target_function_2d(x_val) + torch.randn(200, 1) * 0.01

    # Fit the scaler on the training inputs only.
    standardizer = Standardizer(mode='minmax')
    standardizer.fit(x_train)
    x_train = standardizer.transform(x_train)
    x_val = standardizer.transform(x_val)

    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    # The validation split is handed to the trainer so it is monitored and
    # can trigger early stopping.
    val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=64)

    model = KAN(
        layer_dims=[2, 16, 16, 1],
        grid_size=5,
        spline_order=3
    )

    # The trainer moves the model and every batch to `device`.
    trainer = KANTrainer(model, device=device, lambda_l1=0.001)
    history = trainer.fit(train_loader, val_loader, epochs=300)

    return model, standardizer

4.2 可解释性示例

def interpretability_example():
    """Train the basic example and plot the first layer's coefficients.

    Prints the model's regularisation loss and the shape and range of the
    first layer's spline coefficients, then saves one subplot per output
    unit (at most 8) to ``kan_activations.png``. Each curve shows the raw
    coefficient vector of one (output, input) pair, plotted against the
    coefficient index.
    """
    import matplotlib.pyplot as plt

    model, standardizer = basic_example()

    # L1 penalty on the spline coefficients, as a plain float for printing.
    reg_loss = model.get_regularization_loss().item()
    print(f"正则化损失: {reg_loss:.6f}")

    layer = model.layers[0]
    coeff = layer.spline.coeff.detach().cpu().numpy()

    print(f"\n第一层激活函数系数形状: {coeff.shape}")
    print(f"系数范围: [{coeff.min():.4f}, {coeff.max():.4f}]")

    # 2 x 4 grid: one panel per output unit, capped at 8 panels.
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))

    for i in range(min(8, coeff.shape[0])):
        ax = axes[i // 4, i % 4]

        # One curve per input feature.
        for j in range(coeff.shape[1]):
            ax.plot(coeff[i, j], alpha=0.7, label=f'input {j}')

        ax.set_title(f'Output {i}')
        ax.legend()

    plt.tight_layout()
    plt.savefig('kan_activations.png')
    plt.close()

    print("\n激活函数图已保存到 kan_activations.png")
 
 
def sparsity_example():
    """Build a ``SparseKAN``, threshold it and print its sparsity ratio.

    Training is omitted; the example only shows how the thresholding and
    the sparsity measurement are used.
    """
    model = SparseKAN(
        layer_dims=[2, 8, 1],
        grid_size=5,
        spline_order=3,
        sparsity_threshold=0.01
    )

    # (Training would go here.)

    # Zero out the coefficients below the threshold.
    model.apply_sparsity()

    # Fraction of spline coefficients that are exactly zero.
    coeffs = [layer.spline.coeff for layer in model.kan.layers]
    total_params = sum(c.numel() for c in coeffs)
    zero_params = sum((c == 0).sum().item() for c in coeffs)

    sparsity = zero_params / total_params
    print(f"稀疏性: {sparsity:.2%}")

5. 效率优化技巧

5.1 计算优化

class OptimizedKAN:
    """Namespace for optimised KAN kernels."""

    @staticmethod
    def vectorized_bspline(x, coeff, grid, k):
        """Evaluate per-edge activations by linear interpolation.

        This is a simplified, fully vectorised stand-in for B-spline
        evaluation: the coefficients of every (output, input) pair are
        treated as values at equally spaced nodes on [0, 1] and linearly
        interpolated at ``x``.

        Args:
            x: (batch, in_features) inputs; clamped to [0, 1].
            coeff: (out_features, in_features, n_nodes) node values, with
                ``n_nodes >= 2``.
            grid: Unused; kept for interface compatibility.
            k: Unused; kept for interface compatibility.

        Returns:
            (batch, out_features) tensor: interpolated activations summed
            over the input features.

        Raises:
            ValueError: If the feature dimensions disagree or fewer than
                two nodes are given.
        """
        batch_size, in_features = x.shape
        out_features, coeff_in, n_nodes = coeff.shape
        if coeff_in != in_features:
            raise ValueError(
                f"coeff expects {coeff_in} input features, got {in_features}"
            )
        if n_nodes < 2:
            raise ValueError("linear interpolation needs at least 2 nodes")

        x = x.clamp(0, 1)

        # Position of every input in node units and its left node index.
        pos = x * (n_nodes - 1)
        left = pos.long().clamp(0, n_nodes - 2)  # (batch, in)

        # Fractional offset inside the node interval.
        t = (pos - left.to(pos.dtype)).unsqueeze(1)  # (batch, 1, in)

        # Pair every feature with its own node index: index the feature and
        # node axes together so the result is (out, batch, in) rather than
        # the outer product of the two axes.
        feature = torch.arange(in_features, device=coeff.device)
        coeff_low = coeff[:, feature, left].permute(1, 0, 2)       # (batch, out, in)
        coeff_high = coeff[:, feature, left + 1].permute(1, 0, 2)  # (batch, out, in)

        output = (1 - t) * coeff_low + t * coeff_high
        return output.sum(dim=-1)  # (batch, out)
 
 
class BatchedKANInference:
    """Run a model over a large input in fixed-size chunks.

    Args:
        model: Module used for prediction.
        batch_size: Number of rows fed to the model per call.
    """

    def __init__(self, model, batch_size=32):
        self.model = model
        self.batch_size = batch_size

    def predict(self, x):
        """Return the model's predictions for ``x``, computed chunk by chunk.

        The model is switched to eval mode and gradients are disabled; the
        per-chunk outputs are concatenated along dim 0.
        """
        self.model.eval()
        step = self.batch_size

        with torch.no_grad():
            chunks = [
                self.model(x[start:start + step])
                for start in range(0, len(x), step)
            ]

        return torch.cat(chunks, dim=0)

5.2 内存优化

class MemoryEfficientKAN(KAN):
    """KAN that trades compute for memory with gradient checkpointing.

    Checkpointed layers do not store their intermediate activations; they
    are recomputed during the backward pass.

    Args:
        layer_dims: Layer widths forwarded to ``KAN``.
        grid_size: Grid size forwarded to ``KAN``.
        spline_order: Spline degree forwarded to ``KAN``.
        checkpoint_every: Every ``checkpoint_every``-th layer (except the
            first) is checkpointed. Must be at least 1.

    Raises:
        ValueError: If ``checkpoint_every`` is smaller than 1.
    """

    def __init__(self, layer_dims, grid_size=5, spline_order=3,
                 checkpoint_every=2):
        super().__init__(layer_dims, grid_size, spline_order)
        if checkpoint_every < 1:
            raise ValueError("checkpoint_every must be >= 1")
        self.checkpoint_every = checkpoint_every

    def forward(self, x):
        """Forward pass with gradient checkpointing.

        Accepts the same input shapes as ``KAN.forward``: any number of
        leading dimensions followed by the feature dimension.
        """
        # Imported explicitly: `import torch` alone is not documented to
        # make the `torch.utils.checkpoint` submodule available.
        from torch.utils.checkpoint import checkpoint

        # The layers expect 2-D input, so fold all leading dims into one.
        lead_shape = tuple(x.shape[:-1])
        needs_reshape = x.dim() != 2
        if needs_reshape:
            x = x.reshape(-1, x.shape[-1])

        # Checkpointing only pays off when a backward pass will follow.
        grad_enabled = torch.is_grad_enabled()

        for i, layer in enumerate(self.layers):
            if grad_enabled and i > 0 and i % self.checkpoint_every == 0:
                # Non-reentrant mode is the variant recommended by the
                # PyTorch documentation.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)

            if i < self.num_layers - 1:
                x = self.activation(x)

        if needs_reshape:
            x = x.reshape(*lead_shape, -1)

        return x

6. 完整训练脚本

#!/usr/bin/env python3
"""
KAN 完整训练脚本
 
Usage:
    python kan_training_script.py --task regression --epochs 200
"""
 
import argparse
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import json
from pathlib import Path
 
 
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        ``argparse.Namespace`` with task, epochs, lr, grid_size,
        spline_order, batch_size, device and output_dir.
    """
    parser = argparse.ArgumentParser(description='KAN Training Script')

    # The task is the only option restricted to a fixed set of choices.
    parser.add_argument('--task', type=str, default='regression',
                       choices=['regression', 'classification'])

    # Remaining options as (flag, type, default), in help-output order.
    simple_options = (
        ('--epochs', int, 200),
        ('--lr', float, 1e-3),
        ('--grid_size', int, 5),
        ('--spline_order', int, 3),
        ('--batch_size', int, 64),
        ('--device', str, 'auto'),
        ('--output_dir', str, './output'),
    )
    for flag, value_type, default in simple_options:
        parser.add_argument(flag, type=value_type, default=default)

    return parser.parse_args()
 
 
def main():
    """Train a KAN on a synthetic 1-D regression task and save the results.

    Writes ``results.json`` (configuration, final and best losses, loss
    history) and ``model.pt`` (model weights, scaler and configuration)
    into the output directory.
    """
    args = parse_args()

    # Resolve the device: 'auto' picks CUDA when it is available.
    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    print(f"Using device: {device}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.manual_seed(42)

    def target(x):
        return torch.sin(torch.pi * x) * torch.exp(-x / 2)

    x_train = torch.rand(1000, 1)
    y_train = target(x_train) + torch.randn(1000, 1) * 0.01

    x_val = torch.rand(200, 1)
    y_val = target(x_val) + torch.randn(200, 1) * 0.01

    # Fit the scaler on the training inputs only.
    standardizer = Standardizer(mode='minmax')
    standardizer.fit(x_train)
    x_train = standardizer.transform(x_train)
    x_val = standardizer.transform(x_val)

    train_loader = DataLoader(
        TensorDataset(x_train, y_train),
        batch_size=args.batch_size,
        shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(x_val, y_val),
        batch_size=args.batch_size
    )

    model = KAN(
        layer_dims=[1, 8, 8, 1],
        grid_size=args.grid_size,
        spline_order=args.spline_order
    ).to(device)

    print(f"模型参数数量: {model.get_num_parameters()}")

    # The scheduler needs the optimizer, so both are built before the
    # trainer instead of referring to the trainer inside its own
    # constructor call.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs
    )
    # Pass the device explicitly: the trainer defaults to 'cpu' and would
    # otherwise move the model back off the selected device.
    trainer = KANTrainer(
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        lambda_l1=0.001
    )

    history = trainer.fit(train_loader, val_loader, epochs=args.epochs)

    results = {
        'config': vars(args),
        'final_train_loss': history['train_loss'][-1],
        'final_val_loss': history['val_loss'][-1],
        'best_val_loss': min(history['val_loss']),
        'history': history
    }

    with open(output_dir / 'results.json', 'w') as f:
        json.dump(results, f, indent=2)

    # NOTE(review): the Standardizer instance is pickled into the
    # checkpoint, so loading it requires the class to be importable and
    # full unpickling to be allowed — confirm this is acceptable.
    torch.save({
        'model_state_dict': model.state_dict(),
        'standardizer': standardizer,
        'config': vars(args)
    }, output_dir / 'model.pt')

    print(f"\n结果已保存到 {output_dir}")
    print(f"最终验证损失: {results['final_val_loss']:.6f}")
 
 
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

7. 总结

本文提供了 KAN 的完整 PyTorch 实现,包括:

  1. 基础组件:B-样条激活函数、KAN Layer
  2. 完整模型:基础 KAN、带跳跃连接的 KAN、深度 KAN、稀疏 KAN、模块化 KAN
  3. 训练工具:标准化器、训练器
  4. 使用示例:基本回归、多维函数、可解释性
  5. 优化技巧:计算优化、内存优化

这些实现可以作为进一步研究和应用的基础。


参考

  • Liu, Z., et al. (2024). “KAN: Kolmogorov-Arnold Networks”. arXiv:2404.19756.

相关阅读