SLERP:球面线性插值

1. 几何背景

1.1 问题背景

简单的线性插值(LERP)在参数空间中可能穿过高损失区域。SLERP(Spherical Linear Interpolation)通过在超球面上进行插值来缓解这一问题。

1.2 直观理解

在二维平面上,两点之间的最短路径是直线(测地线)。在球面上,最短路径是大圆弧。SLERP就是在高维参数空间中沿大圆弧进行插值。

2. 数学推导

2.1 基础公式

对于两个归一化向量 ,它们之间的夹角为

SLERP定义为:

2.2 插值性质

时:
时:
时:得到等距中点

3. 归一化策略

3.1 原始幅度保留

直接应用SLERP会丢失幅度信息。一种方法是保留原始幅度:

3.2 平均幅度缩放

另一种策略是使用平均幅度:

4. 局限性分析

4.1 两模型限制

SLERP原生只支持两模型合并。多模型合并需要链式调用:

但这可能引入顺序依赖性。

4.2 夹角过小问题

时:

退化为线性插值。这本身不是问题,但精度损失需要注意。

4.3 夹角接近

(向量反向)时,,数值不稳定。处理方法:

def slerp_stable(v0, v1, t, eps=1e-7):
    # 检测反向情况
    d = (v0 * v1).sum()
    if d > 1 - eps:  # 夹角接近0
        return (1 - t) * v0 + t * v1
    if d < eps - 1:  # 夹角接近π
        # 在子空间内进行线性插值
        return torch.lerp(v0, -v1, t)
    
    omega = torch.acos(d.clamp(-1 + eps, 1 - eps))
    sin_omega = torch.sin(omega)
    return (torch.sin((1 - t) * omega) / sin_omega * v0 + 
            torch.sin(t * omega) / sin_omega * v1)

5. 与其他方法的对比

方法几何基础支持模型数计算复杂度幅度处理
LERP欧几里得空间任意需额外处理
SLERP超球面2(原生)可保留
TIES参数空间任意需额外处理

6. 多模型SLERP策略

6.1 层次SLERP

def hierarchical_slerp(models, weights):
    """层次化SLERP合并多模型"""
    n = len(models)
    if n == 1:
        return models[0]
    
    # 递归合并
    mid = n // 2
    left = hierarchical_slerp(models[:mid], weights[:mid])
    right = hierarchical_slerp(models[mid:], weights[mid:])
    
    # 计算权重比例
    t = weights[mid] / (weights[mid-1] + weights[mid])
    
    return slerp_stable(left, right, t)

6.2 平均方向SLERP

将所有模型投影到单位球面上,取平均方向:

def mean_direction_slerp(models, weights):
    """基于平均方向的SLERP"""
    # 归一化
    normed = [m / m.norm() for m in models]
    
    # 加权平均方向
    mean_dir = sum(w * n for w, n in zip(weights, normed))
    mean_dir = mean_dir / mean_dir.norm()
    
    # 找到最近模型作为起点
    start_idx = max(range(len(models)), 
                    key=lambda i: (models[i] / models[i].norm()).dot(mean_dir))
    
    return slerp_stable(models[start_idx], mean_dir * models[start_idx].norm(), 
                        sum(weights) / 2)

7. 适用场景

SLERP特别适合以下场景:

  1. 两模型插值:需要精确控制合并比例时
  2. 轨迹探索:研究两模型之间的能力变化
  3. 连续混合:创建模型的能力连续体
  4. 安全检查点:在两版本之间创建平滑过渡

8. PyTorch完整实现

#include <torch/torch.h>
 
class SLERPMerger {
public:
    static torch::Tensor slerp(
        torch::Tensor theta1,
        torch::Tensor theta2,
        double t,
        bool preserve_scale = true
    ) {
        // 记录原始尺度
        double scale1 = theta1.norm().item<double>();
        double scale2 = theta2.norm().item<double>();
        
        // 归一化
        theta1 = theta1 / scale1;
        theta2 = theta2 / scale2;
        
        // 计算夹角
        double cos_omega = (theta1 * theta2).sum().item<double>();
        cos_omega = std::max(-1.0, std::min(1.0, cos_omega));
        double omega = std::acos(cos_omega);
        
        torch::Tensor result;
        
        if (omega < 1e-7) {
            // 接近线性插值
            result = (1 - t) * theta1 + t * theta2;
        } else {
            double sin_omega = std::sin(omega);
            result = (std::sin((1 - t) * omega) / sin_omega) * theta1 +
                     (std::sin(t * omega) / sin_omega) * theta2;
        }
        
        // 恢复尺度
        if (preserve_scale) {
            double avg_scale = (scale1 + scale2) / 2;
            result = result * avg_scale;
        }
        
        return result;
    }
    
    static torch::Tensor merge_models(
        std::vector<torch::Tensor> models,
        double t = 0.5
    ) {
        if (models.size() == 1) {
            return models[0];
        }
        
        torch::Tensor merged = models[0];
        for (size_t i = 1; i < models.size(); i++) {
            merged = slerp(merged, models[i], t);
        }
        return merged;
    }
};

9. 参考资料