SLERP:球面线性插值
1. 几何背景
1.1 问题背景
简单的线性插值(LERP)在参数空间中可能穿过高损失区域。SLERP(Spherical Linear Interpolation)通过在超球面上进行插值来缓解这一问题。
1.2 直观理解
在二维平面上,两点之间的最短路径是直线(测地线)。在球面上,最短路径是大圆弧。SLERP就是在高维参数空间中沿大圆弧进行插值。
2. 数学推导
2.1 基础公式
对于两个归一化向量 和 ,它们之间的夹角为 :
SLERP定义为:
2.2 插值性质
当 时:
当 时:
当 时:得到等距中点
3. 归一化策略
3.1 原始幅度保留
直接应用SLERP会丢失幅度信息。一种方法是保留原始幅度:
3.2 平均幅度缩放
另一种策略是使用平均幅度:
4. 局限性分析
4.1 两模型限制
SLERP原生只支持两模型合并。多模型合并需要链式调用:
但这可能引入顺序依赖性。
4.2 夹角过小问题
当 时:
退化为线性插值。这本身不是问题,但精度损失需要注意。
4.3 夹角接近
当 (向量反向)时,,数值不稳定。处理方法:
def slerp_stable(v0, v1, t, eps=1e-7):
# 检测反向情况
d = (v0 * v1).sum()
if d > 1 - eps: # 夹角接近0
return (1 - t) * v0 + t * v1
if d < eps - 1: # 夹角接近π
# 在子空间内进行线性插值
return torch.lerp(v0, -v1, t)
omega = torch.acos(d.clamp(-1 + eps, 1 - eps))
sin_omega = torch.sin(omega)
return (torch.sin((1 - t) * omega) / sin_omega * v0 +
torch.sin(t * omega) / sin_omega * v1)5. 与其他方法的对比
| 方法 | 几何基础 | 支持模型数 | 计算复杂度 | 幅度处理 |
|---|---|---|---|---|
| LERP | 欧几里得空间 | 任意 | 需额外处理 | |
| SLERP | 超球面 | 2(原生) | 可保留 | |
| TIES | 参数空间 | 任意 | 需额外处理 |
6. 多模型SLERP策略
6.1 层次SLERP
def hierarchical_slerp(models, weights):
"""层次化SLERP合并多模型"""
n = len(models)
if n == 1:
return models[0]
# 递归合并
mid = n // 2
left = hierarchical_slerp(models[:mid], weights[:mid])
right = hierarchical_slerp(models[mid:], weights[mid:])
# 计算权重比例
t = weights[mid] / (weights[mid-1] + weights[mid])
return slerp_stable(left, right, t)6.2 平均方向SLERP
将所有模型投影到单位球面上,取平均方向:
def mean_direction_slerp(models, weights):
"""基于平均方向的SLERP"""
# 归一化
normed = [m / m.norm() for m in models]
# 加权平均方向
mean_dir = sum(w * n for w, n in zip(weights, normed))
mean_dir = mean_dir / mean_dir.norm()
# 找到最近模型作为起点
start_idx = max(range(len(models)),
key=lambda i: (models[i] / models[i].norm()).dot(mean_dir))
return slerp_stable(models[start_idx], mean_dir * models[start_idx].norm(),
sum(weights) / 2)7. 适用场景
SLERP特别适合以下场景:
- 两模型插值:需要精确控制合并比例时
- 轨迹探索:研究两模型之间的能力变化
- 连续混合:创建模型的能力连续体
- 安全检查点:在两版本之间创建平滑过渡
8. PyTorch完整实现
#include <torch/torch.h>
class SLERPMerger {
public:
static torch::Tensor slerp(
torch::Tensor theta1,
torch::Tensor theta2,
double t,
bool preserve_scale = true
) {
// 记录原始尺度
double scale1 = theta1.norm().item<double>();
double scale2 = theta2.norm().item<double>();
// 归一化
theta1 = theta1 / scale1;
theta2 = theta2 / scale2;
// 计算夹角
double cos_omega = (theta1 * theta2).sum().item<double>();
cos_omega = std::max(-1.0, std::min(1.0, cos_omega));
double omega = std::acos(cos_omega);
torch::Tensor result;
if (omega < 1e-7) {
// 接近线性插值
result = (1 - t) * theta1 + t * theta2;
} else {
double sin_omega = std::sin(omega);
result = (std::sin((1 - t) * omega) / sin_omega) * theta1 +
(std::sin(t * omega) / sin_omega) * theta2;
}
// 恢复尺度
if (preserve_scale) {
double avg_scale = (scale1 + scale2) / 2;
result = result * avg_scale;
}
return result;
}
static torch::Tensor merge_models(
std::vector<torch::Tensor> models,
double t = 0.5
) {
if (models.size() == 1) {
return models[0];
}
torch::Tensor merged = models[0];
for (size_t i = 1; i < models.size(); i++) {
merged = slerp(merged, models[i], t);
}
return merged;
}
};9. 参考资料
- Shoemake, K. (1985). Animating rotation with quaternion curves. SIGGRAPH 1985.
- MergeKit: Top 5 Model Merge Methods Compared. https://mergekit.com/blog/top-5-merge-methods-compared