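SAE-Based Feature Steering

Overview

Feature steering uses the interpretable features learned by a sparse autoencoder (SAE) to control model behavior in a targeted way. Unlike prompt engineering or fine-tuning, it edits the model's internal activations directly, which allows fine-grained control over behavior.

Core Ideas

  • An SAE decomposes superposed representations into sparse, interpretable features
  • Changing a feature's activation directly affects the model's computation and behavior
  • Steering specific features enables sentiment control, knowledge editing, safety interventions, and more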

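1. Feature Steering Basics

1.1 How Steering Works

Original forward pass:
  input → activation x → [rest of model] → output

With SAE steering:
  input → activation x → [SAE encoder] → features f(x)
                                              ↓
                                   steered features f'(x)
                                              ↓
             [SAE decoder] → modified activation x' → [rest of model] → output

Effect: the model's behavior changes in a targeted way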

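1.2 Types of Steering

Type      Description                                Typical use
--------  -----------------------------------------  ---------------------------------
Enhance   Increase a feature's activation            Strengthen a capability
Suppress  Decrease a feature's activation            Remove harmful content
Replace   Overwrite the activation with a set value  Change what the feature expresses
Combine   Steer several features at once             Complex behavior control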

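1.3 Mathematical Form

Let $f_i$ be the original activation of feature $i$ and $f_i'$ its activation after steering.

Enhance, with offset $\delta$:

$$f_i' = f_i + \delta$$

Suppress, with coefficient $\alpha \in [0, 1]$ (0 removes the feature, 1 leaves it unchanged):

$$f_i' = \alpha f_i$$

Replace, with a fixed value $v$:

$$f_i' = v$$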
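
Because an SAE decoder is linear, the enhance operation has a simple reading in activation space. The short derivation below assumes a decoder of the form $\hat{x} = W_{\mathrm{dec}} f + b_{\mathrm{dec}}$; the symbols $W_{\mathrm{dec}}$, $b_{\mathrm{dec}}$, $d_i$ (column $i$ of $W_{\mathrm{dec}}$) and $e_i$ (the $i$-th standard basis vector) are notation introduced here for illustration.

```latex
\hat{x}' = W_{\mathrm{dec}} f' + b_{\mathrm{dec}}
         = W_{\mathrm{dec}} (f + \delta\, e_i) + b_{\mathrm{dec}}
         = \hat{x} + \delta\, d_i
```

Under that assumption, adding $\delta$ to feature $i$ shifts the reconstructed activation by $\delta$ along the feature's decoder direction $d_i$, which is the idea behind the directional steering in section 2.2.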


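2. Steering Methods

2.1 Basic Steering

import torch
import torch.nn.functional as F
from typing import Optional


class SAEFeatureSteering:
    """Steers a model by editing SAE feature activations.

    `sae` must provide encode() and decode(). `model` is the part of the
    network downstream of the activation the SAE was trained on, so that
    model(x) maps an activation to the output.
    """

    def __init__(self, sae, model):
        self.sae = sae
        self.model = model
        self.device = next(sae.parameters()).device

    def enhance_feature(
        self,
        x: torch.Tensor,
        feature_idx: int,
        delta: float,
    ) -> torch.Tensor:
        """
        Increase the activation of one feature.

        Args:
            x: input activation [batch, seq_len, d_model]
            feature_idx: index of the feature to steer
            delta: amount added to the feature activation

        Returns:
            Model output computed from the modified activation
        """
        # Encode; clone so the edit never touches the encoder's own tensor
        features = self.sae.encode(x).clone()

        # Enhance
        features[:, :, feature_idx] += delta

        # Decode back into activation space
        modified_x = self.sae.decode(features)

        # Continue the forward pass
        output = self.model(modified_x)

        return output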
    
    def suppress_feature(
        self,
        x: torch.Tensor,
        feature_idx: int,
        alpha: float = 0.0,
    ) -> torch.Tensor:
        """
        Suppress the activation of one feature.

        Args:
            x: input activation
            feature_idx: index of the feature to steer
            alpha: scale factor (0 = fully suppressed, 1 = unchanged)

        Returns:
            Model output computed from the modified activation
        """
        features = self.sae.encode(x).clone()
        features[:, :, feature_idx] *= alpha
        modified_x = self.sae.decode(features)
        return self.model(modified_x)

    def set_feature(
        self,
        x: torch.Tensor,
        feature_idx: int,
        value: float,
    ) -> torch.Tensor:
        """Set one feature to a fixed value."""
        features = self.sae.encode(x).clone()
        features[:, :, feature_idx] = value
        modified_x = self.sae.decode(features)
        return self.model(modified_x)
    
    def multi_feature_steering(
        self,
        x: torch.Tensor,
        modifications: dict[int, float],  # feature_idx -> amount
        steering_type: str = "enhance",
    ) -> torch.Tensor:
        """
        Steer several features at once.

        Args:
            x: input activation
            modifications: {feature index: amount}
            steering_type: "enhance" (add the amount), "suppress" (scale by
                1 - amount) or "set" (overwrite with the amount)

        Returns:
            Model output computed from the modified activation
        """
        # Fail loudly instead of silently ignoring an unknown mode
        if steering_type not in ("enhance", "suppress", "set"):
            raise ValueError(f"unknown steering_type: {steering_type!r}")

        features = self.sae.encode(x).clone()

        for feat_idx, amount in modifications.items():
            if steering_type == "enhance":
                features[:, :, feat_idx] = features[:, :, feat_idx] + amount
            elif steering_type == "suppress":
                features[:, :, feat_idx] = features[:, :, feat_idx] * (1 - amount)
            else:  # "set"
                features[:, :, feat_idx] = amount

        modified_x = self.sae.decode(features)
        return self.model(modified_x)
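
The class above assumes that `model` can be called on an activation directly. With a full network, one way to get the same effect is a PyTorch forward hook on the layer whose output the SAE reads. The sketch below is a minimal illustration under stated assumptions: `ToySAE`, `make_steering_hook` and the two-layer `nn.Sequential` stand-in are hypothetical names invented here, not part of any SAE library.

```python
import torch
import torch.nn as nn


class ToySAE(nn.Module):
    """Hypothetical minimal SAE with the encode/decode interface used above."""

    def __init__(self, d_model: int, n_features: int):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_features)
        self.decoder = nn.Linear(n_features, d_model)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.encoder(x))

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        return self.decoder(f)


def make_steering_hook(sae: ToySAE, feature_idx: int, delta: float):
    """Build a forward hook that adds `delta` to one SAE feature."""

    def hook(module, inputs, output):
        features = sae.encode(output).clone()
        features[..., feature_idx] += delta
        # Returning a tensor from a forward hook replaces the layer's output
        return sae.decode(features)

    return hook


torch.manual_seed(0)
d_model, n_features = 8, 32
# Stand-in network: layer 0 is the hook point, layer 1 is "the rest of the model"
model = nn.Sequential(nn.Linear(d_model, d_model), nn.Linear(d_model, 4))
sae = ToySAE(d_model, n_features)
x = torch.randn(2, 5, d_model)

with torch.no_grad():
    plain = model(x)  # no hook at all

    handle = model[0].register_forward_hook(make_steering_hook(sae, 3, 0.0))
    baseline = model(x)  # SAE round trip only, no steering
    handle.remove()

    handle = model[0].register_forward_hook(make_steering_hook(sae, 3, 2.0))
    steered = model(x)  # feature 3 raised by 2.0
    handle.remove()  # remove the hook so later calls are unaffected

    restored = model(x)
```

Comparing `steered` with `baseline` rather than with `plain` isolates the effect of the feature edit from the SAE's own reconstruction error.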

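2.2 Directional Steering

class DirectionalSteering:
    """Steering along SAE decoder directions."""

    def __init__(self, sae, model):
        self.sae = sae
        self.model = model
        self.device = next(sae.parameters()).device

    def compute_feature_direction(
        self,
        feature_idx: int,
        n_samples: int = 100,  # unused; kept so existing callers still work
    ) -> torch.Tensor:
        """
        Return the unit decoder direction of a feature.

        This is the direction in which the activation moves when the
        feature becomes more active.
        """
        # Decoder weight column for this feature. Assumes sae.decoder is an
        # nn.Linear(n_features, d_model), whose weight is [d_model, n_features]
        decoder_weights = self.sae.decoder.weight.data
        direction = decoder_weights[:, feature_idx]

        # Normalize to unit length
        direction = direction / (direction.norm() + 1e-8)

        return direction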
    
    def steering_by_direction(
        self,
        x: torch.Tensor,
        feature_indices: list[int],
        strengths: list[float],
        alpha: float = 0.5,
    ) -> torch.Tensor:
        """
        Steer by adding feature directions to the activation.

        Args:
            x: input activation
            feature_indices: features whose directions are used
            strengths: signed strength for each feature
            alpha: overall scale (0 = leave x unchanged,
                1 = add the full steering vector)

        Returns:
            Modified activation
        """
        if len(feature_indices) != len(strengths):
            raise ValueError("feature_indices and strengths must have equal length")

        # Strength-weighted sum of unit directions. The sum is not
        # renormalized and not rescaled by sum(strengths): doing so flips the
        # sign for a single negative strength and cancels to zero for
        # strengths such as [1.0, -1.0].
        steering_vector = torch.zeros(x.shape[-1], device=x.device, dtype=x.dtype)

        for feat_idx, strength in zip(feature_indices, strengths):
            direction = self.compute_feature_direction(feat_idx)
            steering_vector = steering_vector + strength * direction.to(x)

        # (1 - alpha) * x + alpha * (x + v) simplifies to x + alpha * v
        modified_x = x + alpha * steering_vector

        return modified_x
    
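    def directional_suppress(
        self,
        x: torch.Tensor,
        feature_idx: int,
        strength: float = 1.0,
    ) -> torch.Tensor:
        """
        Suppress the activation along a feature direction.

        Removes `strength` times the component of x that lies along the
        direction: 1.0 projects it out entirely, larger values push x to
        the opposite side.
        """
        direction = self.compute_feature_direction(feature_idx)

        # Projection of the activation onto the direction
        current_projection = (x * direction).sum(dim=-1, keepdim=True)

        # Move against the direction, in proportion to the projection
        suppression_vector = -strength * direction * current_projection

        modified_x = x + suppression_vector

        return modified_x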
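
To make the role of `strength` in directional suppression concrete, here is a small worked example. It is only a sketch: it uses plain Python lists instead of tensors so the arithmetic is easy to check by hand, and the helper names `dot`, `unit` and `suppress_along` are hypothetical.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(ai * bi for ai, bi in zip(a, b))


def unit(v):
    """Scale v to unit length."""
    norm = math.sqrt(dot(v, v))
    return [vi / norm for vi in v]


def suppress_along(x, direction, strength=1.0):
    """Remove `strength` times the component of x along `direction`."""
    d = unit(direction)
    proj = dot(x, d)
    return [xi - strength * proj * di for xi, di in zip(x, d)]


x = [3.0, 4.0, 0.0]
d = [1.0, 0.0, 0.0]

half = suppress_along(x, d, strength=0.5)     # [1.5, 4.0, 0.0]
full = suppress_along(x, d, strength=1.0)     # [0.0, 4.0, 0.0]
flipped = suppress_along(x, d, strength=2.0)  # [-3.0, 4.0, 0.0]
```

The component orthogonal to the direction (the 4.0 here) is untouched in every case, which is why this form of suppression does not pass through the SAE encoder or decoder at all.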

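2.3 Adaptive Steering

class AdaptiveSteering:
    """Adaptive feature steering."""

    def __init__(self, sae, model):
        self.sae = sae
        self.model = model
        self.device = next(sae.parameters()).device

        # Per-feature statistics, filled by compute_feature_statistics()
        self.feature_stats = {}

    def compute_feature_statistics(
        self,
        dataset: torch.Tensor,
    ):
        """Compute per-feature statistics.

        `dataset` is iterated item by item, so a stacked tensor or any
        iterable of activation tensors works.
        """
        all_features = []

        with torch.no_grad():
            for x in dataset:
                features = self.sae.encode(x)
                # Flatten batch and sequence dims to [n_tokens, n_features];
                # median() and quantile() only accept a single dim
                all_features.append(features.reshape(-1, features.shape[-1]))

        all_features = torch.cat(all_features, dim=0)

        # Statistics over all tokens, one value per feature
        self.feature_stats = {
            "mean": all_features.mean(dim=0),
            "std": all_features.std(dim=0),
            "median": all_features.median(dim=0).values,
            "q25": all_features.quantile(0.25, dim=0),
            "q75": all_features.quantile(0.75, dim=0),
        }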
    
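    def adaptive_enhance(
        self,
        x: torch.Tensor,
        feature_idx: int,
        target_percentile: float = 0.9,
    ) -> torch.Tensor:
        """
        Adaptive enhancement: lift a feature up to a target level.

        Args:
            x: input activation
            feature_idx: index of the feature to steer
            target_percentile: target level in [0, 1]; used to scale the
                stored 75th percentile, so it is a heuristic and not an
                exact percentile
        """
        if not self.feature_stats:
            raise RuntimeError("call compute_feature_statistics() first")

        features = self.sae.encode(x).clone()

        # Heuristic target derived from the 75th percentile
        target_value = self.feature_stats["q75"][feature_idx] * (
            1 + target_percentile
        )

        # Raise only the positions that are below the target;
        # positions already above it are left unchanged
        current_values = features[:, :, feature_idx]
        features[:, :, feature_idx] = torch.maximum(
            current_values, target_value.to(current_values)
        )

        modified_x = self.sae.decode(features)
        return self.model(modified_x)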
    
    def clamp_feature(
        self,
        x: torch.Tensor,
        feature_idx: int,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Clamp a feature's value to a range.

        Either bound may be None, in which case it is not applied.
        """
        features = self.sae.encode(x).clone()

        if min_value is not None:
            features[:, :, feature_idx] = torch.clamp_min(
                features[:, :, feature_idx],
                min_value
            )

        if max_value is not None:
            features[:, :, feature_idx] = torch.clamp_max(
                features[:, :, feature_idx],
                max_value
            )

        modified_x = self.sae.decode(features)
        return self.model(modified_x)
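
Adaptive enhancement is easiest to reason about when the target really is a percentile of observed activations. The sketch below shows that variant on plain Python lists. It is an illustration, not the method used by the class above; `percentile` and `raise_to_percentile` are hypothetical helpers, and the reference activations are made-up numbers.

```python
def percentile(values, q):
    """Linear-interpolation percentile of `values` for q in [0, 1]."""
    ordered = sorted(values)
    if not ordered:
        raise ValueError("values must not be empty")
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    frac = pos - lo
    return ordered[lo] * (1 - frac) + ordered[hi] * frac


def raise_to_percentile(current, reference, q=0.9):
    """Lift activations below the q-th percentile of `reference` up to it."""
    target = percentile(reference, q)
    return [max(v, target) for v in current]


# Made-up activations of one feature over a reference dataset
reference = [0.0, 0.0, 0.1, 0.2, 0.4, 0.5, 0.8, 1.0, 1.5, 2.0, 3.0]

target = percentile(reference, 0.9)
steered = raise_to_percentile([0.0, 0.7, 2.5], reference, q=0.9)
```

Values already above the target (2.5 in this example) are left alone, so the operation only ever strengthens the feature.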

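3. Applications

3.1 Sentiment Control

def control_sentiment(
    steering: SAEFeatureSteering,
    model,
    tokenizer,
    text: str,
    target_sentiment: str,  # "positive" or "negative"
    feature_idx: int,
    strength: float = 2.0,
    max_new_tokens: int = 20,
) -> str:
    """
    Steer the sentiment of generated text with an SAE feature.

    Assumes a Hugging Face-style causal LM, and that steering.model maps the
    final hidden state to logits (for example model.get_output_embeddings()).
    """
    inputs = tokenizer(text, return_tensors="pt").to(steering.device)
    input_ids = inputs["input_ids"]

    # Greedy decoding; the final hidden state is steered at every step.
    # No KV cache is used, so this is simple rather than fast.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids=input_ids, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]

            if target_sentiment == "positive":
                # Boost the positive-sentiment feature
                logits = steering.enhance_feature(
                    hidden_states, feature_idx, strength
                )
            else:
                # Halve the positive-sentiment feature, which shifts the
                # output toward negative sentiment
                logits = steering.suppress_feature(
                    hidden_states, feature_idx, alpha=0.5
                )

            # Pick the most likely next token and append it
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=-1)

    return tokenizer.decode(input_ids[0])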

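3.2 Knowledge Editing

def edit_knowledge(
    steering: SAEFeatureSteering,
    model,
    tokenizer,
    subject: str,
    current_knowledge_idx: int,
    new_knowledge_idx: int,
    context: str,
    max_new_tokens: int = 50,
) -> str:
    """
    Edit what the model says about a subject with SAE features.

    Example: steer the completion away from "Paris is the capital of
    France" and toward "Rome is the capital of Italy".

    Assumes a Hugging Face-style causal LM, and that steering.model maps the
    final hidden state to logits (for example model.get_output_embeddings()).
    """
    inputs = tokenizer(context + " " + subject, return_tensors="pt").to(
        steering.device
    )
    input_ids = inputs["input_ids"]

    # Push the old-knowledge feature down and the new-knowledge feature up.
    # Both are additive offsets, so the old feature can become negative.
    modifications = {
        current_knowledge_idx: -2.0,  # suppress old knowledge
        new_knowledge_idx: 2.0,       # enhance new knowledge
    }

    # Greedy decoding with the final hidden state steered at every step
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids=input_ids, output_hidden_states=True)
            hidden = outputs.hidden_states[-1]

            logits = steering.multi_feature_steering(
                hidden, modifications, steering_type="enhance"
            )

            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=-1)

    return tokenizer.decode(input_ids[0])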

3.3 Safety Intervention

class SafetyIntervention:
    """Safety intervention built on SAE features."""

    def __init__(
        self,
        sae,
        model,
        safety_feature_indices: list[int],
        threshold: float = 0.5,
    ):
        self.steering = SAEFeatureSteering(sae, model)
        # Indices of the features that signal unsafe content
        self.safety_indices = safety_feature_indices
        self.threshold = threshold

    def check_and_intervene(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, bool]:
        """
        Check for unsafe content and intervene if needed.

        Returns:
            (activation, modified only if an intervention happened;
             whether an intervention happened)
        """
        # Encode
        features = self.steering.sae.encode(x).clone()

        # Strongest monitored feature at each position
        safety_activations = features[:, :, self.safety_indices].max(dim=-1).values

        if safety_activations.max() > self.threshold:
            # Unsafe content detected: zero out every monitored feature
            for idx in self.safety_indices:
                features[:, :, idx] = 0

            modified_x = self.steering.sae.decode(features)
            intervened = True
        else:
            # Below threshold: return the activation untouched
            modified_x = x
            intervened = False

        return modified_x, intervened

    def get_safety_score(
        self,
        x: torch.Tensor,
    ) -> dict:
        """
        Return a safety score for the content.

        The score of each monitored feature is its maximum activation;
        higher means less safe.
        """
        features = self.steering.sae.encode(x)

        scores = {}
        for idx in self.safety_indices:
            feature_name = f"feature_{idx}"
            scores[feature_name] = features[:, :, idx].max().item()

        scores["overall"] = max(scores.values()) if scores else 0.0
        scores["is_safe"] = scores["overall"] < self.threshold

        return scores

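3.4 Style Transfer

def convert_style(
    steering: SAEFeatureSteering,
    model,
    tokenizer,
    text: str,
    source_style_idx: int,
    target_style_idx: int,
    max_new_tokens: int = 20,
) -> str:
    """
    Style transfer, for example from formal to informal.

    Assumes a Hugging Face-style causal LM, and that steering.model maps the
    final hidden state to logits (for example model.get_output_embeddings()).
    """
    inputs = tokenizer(text, return_tensors="pt").to(steering.device)
    input_ids = inputs["input_ids"]

    # Push the source-style feature down and the target-style feature up
    modifications = {
        source_style_idx: -1.5,
        target_style_idx: 1.5,
    }

    # Greedy decoding with the final hidden state steered at every step
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids=input_ids, output_hidden_states=True)
            hidden = outputs.hidden_states[-1]

            logits = steering.multi_feature_steering(hidden, modifications)

            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=-1)

    return tokenizer.decode(input_ids[0])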

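4. Evaluating Steering

4.1 Evaluation Metrics

def evaluate_steering_effectiveness(
    steering: SAEFeatureSteering,
    model,
    test_cases: list[dict],
    target_feature_idx: int,
) -> dict:
    """
    Evaluate how well steering works.

    Test case format: {
        "input": str,
        "original_output": str,
        "expected_change": str,
    }

    get_model_output, get_hidden_states and get_model_output_from_hidden
    are helpers that this document does not define; they must be supplied
    by the surrounding code base.
    """
    delta = 2.0
    results = {
        "successful_changes": 0,
        "total_cases": len(test_cases),
        "behavioral_changes": [],
        "feature_activation_changes": [],
    }

    for case in test_cases:
        # Unsteered output
        original_out = get_model_output(model, case["input"])

        # Steered output, converted to text so it can be compared as a string
        hidden = get_hidden_states(model, case["input"])
        steered = steering.enhance_feature(
            hidden,
            target_feature_idx,
            delta=delta,
        )
        steered_out = get_model_output_from_hidden(model, steered, case["input"])

        # Record the feature's mean activation before steering
        with torch.no_grad():
            before = steering.sae.encode(hidden)[..., target_feature_idx].mean().item()
        results["feature_activation_changes"].append(
            {"before": before, "delta": delta}
        )

        # Check what changed
        changed = original_out != steered_out
        expected_change_achieved = (
            case["expected_change"] in steered_out
        )

        results["behavioral_changes"].append({
            "input": case["input"],
            "original": original_out,
            "steered": steered_out,
            "changed": changed,
            "expected_achieved": expected_change_achieved,
        })

        if expected_change_achieved:
            results["successful_changes"] += 1

    # Guard against an empty test set
    results["success_rate"] = (
        results["successful_changes"] / max(results["total_cases"], 1)
    )

    return results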

4.2 Side-Effect Detection

import difflib


def detect_steering_side_effects(
    steering: SAEFeatureSteering,
    model,
    neutral_inputs: list[str],
    target_feature_idx: int,
    delta: float = 2.0,
) -> dict:
    """
    Detect side effects of steering.

    Neutral inputs are used to check that steering does not change
    unrelated content. get_model_output, get_hidden_states and
    get_model_output_from_hidden are helpers that this document does not
    define.
    """
    side_effects = []

    for inp in neutral_inputs:
        # Unsteered output
        original = get_model_output(model, inp)

        # Steered output
        hidden = get_hidden_states(model, inp)
        steered_hidden = steering.enhance_feature(
            hidden, target_feature_idx, delta
        )
        steered = get_model_output_from_hidden(model, steered_hidden, inp)

        # Character-level similarity between the two outputs
        similarity = difflib.SequenceMatcher(
            None, original, steered
        ).ratio()

        if similarity < 0.9:  # a large change counts as a side effect
            side_effects.append({
                "input": inp,
                "original": original,
                "steered": steered,
                "similarity": similarity,
            })

    return {
        "n_side_effects": len(side_effects),
        # Guard against an empty input list
        "side_effect_rate": len(side_effects) / max(len(neutral_inputs), 1),
        "side_effect_examples": side_effects,
    }
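
The similarity check at the heart of side-effect detection can be tried without any model. The sketch below runs it on two hand-written output pairs; `similarity` and `side_effect_rate` are hypothetical helper names, and the sentences are made-up stand-ins for model outputs.

```python
import difflib


def similarity(a: str, b: str) -> float:
    """Character-level similarity in [0, 1]; 1.0 means identical strings."""
    return difflib.SequenceMatcher(None, a, b).ratio()


def side_effect_rate(pairs, threshold=0.9):
    """Return the fraction of (original, steered) pairs below the threshold,
    together with the flagged pairs themselves."""
    flagged = [(o, s) for o, s in pairs if similarity(o, s) < threshold]
    rate = len(flagged) / len(pairs) if pairs else 0.0
    return rate, flagged


pairs = [
    # Steering left this neutral output unchanged
    ("The capital of France is Paris.", "The capital of France is Paris."),
    # Steering rewrote this neutral output completely
    ("Water boils at 100 degrees Celsius.", "I love love love everything!"),
]

rate, flagged = side_effect_rate(pairs)
```

Character-level similarity is a crude proxy: a paraphrase with the same meaning can score low, and a one-word factual change can score high, so flagged examples are best reviewed by hand.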

5. Best Practices

5.1 Choosing the Steering Strength

def find_optimal_steering_strength(
    steering: SAEFeatureSteering,
    model,
    test_input: str,
    target_feature: int,
    target_behavior: str,
    strength_range: tuple = (0.1, 5.0),
) -> Optional[float]:
    """
    Find the smallest effective steering strength.

    Binary search for the smallest strength that produces the target
    behavior. This assumes the effect is monotonic in strength: once a
    strength works, every larger one works too. Returns None if no tested
    strength works.
    """
    low, high = strength_range
    best_strength = None

    # The hidden states do not depend on the strength, so compute them once
    hidden = get_hidden_states(model, test_input)

    for _ in range(10):  # at most 10 iterations
        mid = (low + high) / 2

        # Test the current strength
        steered_hidden = steering.enhance_feature(
            hidden, target_feature, mid
        )
        output = get_model_output_from_hidden(model, steered_hidden, test_input)

        achieved = target_behavior in output

        if achieved:
            best_strength = mid
            high = mid  # try a smaller strength
        else:
            low = mid  # a larger strength is needed

    return best_strength
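
The search logic can be separated from the model so that it is easy to test. In the sketch below the model call is replaced by a predicate `achieves(strength)`; that predicate and the function name `min_effective_strength` are hypothetical. Unlike the function above, this version tests the upper bound first, so it returns None only when even the strongest setting fails. It still assumes the effect is monotonic in strength.

```python
from typing import Callable, Optional


def min_effective_strength(
    achieves: Callable[[float], bool],
    low: float,
    high: float,
    n_iters: int = 10,
) -> Optional[float]:
    """Binary search for the smallest strength at which `achieves` is True."""
    # Without a success at the upper bound there is nothing to search for
    if not achieves(high):
        return None

    best = high
    for _ in range(n_iters):
        mid = (low + high) / 2
        if achieves(mid):
            best, high = mid, mid  # works: try something smaller
        else:
            low = mid  # fails: a larger strength is needed
    return best


# Toy predicate: the behavior appears once the strength reaches 1.7
found = min_effective_strength(lambda s: s >= 1.7, low=0.1, high=5.0)

# A behavior that never appears yields None
never = min_effective_strength(lambda s: False, low=0.1, high=5.0)
```

After ten iterations the search interval has shrunk by a factor of 1024, so `found` lies within about 0.005 of the true threshold for this range.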

5.2 Combined Steering Strategies

class CombinationSteering:
    """Combined feature steering."""

    def __init__(self, sae, model):
        self.sae = sae
        self.model = model
        self.steering = SAEFeatureSteering(sae, model)

    def multi_objective_steering(
        self,
        x: torch.Tensor,
        objectives: list[dict],  # [{feature_idx, strength, goal}, ...]
    ) -> torch.Tensor:
        """
        Multi-objective steering: apply several goals at once.

        objectives: [
            {"feature_idx": 123, "strength": 1.0, "goal": "enhance"},
            {"feature_idx": 456, "strength": 0.5, "goal": "suppress"},
        ]

        Returns the modified activation, not the model output.
        """
        features = self.sae.encode(x).clone()

        for obj in objectives:
            idx = obj["feature_idx"]
            strength = obj["strength"]
            goal = obj["goal"]

            if goal == "enhance":
                features[:, :, idx] = features[:, :, idx] + strength
            elif goal == "suppress":
                features[:, :, idx] = features[:, :, idx] * (1 - strength)
            elif goal == "set":
                features[:, :, idx] = strength
            else:
                raise ValueError(f"unknown goal: {goal!r}")

        modified_x = self.sae.decode(features)
        return modified_x

    def gradual_steering(
        self,
        x: torch.Tensor,
        modifications: dict,  # feature_idx -> full strength
        n_steps: int = 5,
    ) -> list[torch.Tensor]:
        """
        Gradual steering: increase the steering strength step by step.

        Each entry of `modifications` is treated as an "enhance" goal.
        Returns the modified activation for every step.
        """
        results = []

        for step in range(1, n_steps + 1):
            scale = step / n_steps

            # multi_objective_steering expects a list of objective dicts,
            # so build one from the {feature_idx: strength} mapping
            scaled_objectives = [
                {"feature_idx": idx, "strength": strength * scale, "goal": "enhance"}
                for idx, strength in modifications.items()
            ]

            modified = self.multi_objective_steering(x, scaled_objectives)
            results.append(modified)

        return results

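6. Limitations and Caveats

6.1 Known Limitations

Limitation            Description                                                   Impact
--------------------  ------------------------------------------------------------  --------------------------------
Feature interference  Changing one feature can affect other features                Steering is imprecise
Context dependence    The same feature has different effects in different contexts  Effects are unstable
Decoding error        SAE decoding is not fully invertible                          Information loss
Layer differences     Features mean different things at different layers            The layer must be chosen with care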
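
The information loss from decoding error can be limited by adding only the change in reconstruction to the original activation, instead of replacing the activation with the reconstruction. The derivation below sketches this option; it is not what the code in section 2 does. It assumes a linear decoder $\hat{x} = W_{\mathrm{dec}} f + b_{\mathrm{dec}}$ and writes $\varepsilon$ for the reconstruction error, both notation introduced here.

```latex
x = \hat{x} + \varepsilon, \qquad \hat{x} = W_{\mathrm{dec}} f + b_{\mathrm{dec}}

% Replacing x by the steered reconstruction discards \varepsilon:
x'_{\mathrm{replace}} = \hat{x}' = W_{\mathrm{dec}} f' + b_{\mathrm{dec}}

% Adding only the change in reconstruction keeps it:
x'_{\mathrm{keep}} = x + (\hat{x}' - \hat{x})
                   = \hat{x}' + \varepsilon
                   = x + W_{\mathrm{dec}} (f' - f)
```

With the second form, an edit that changes no feature ($f' = f$) returns $x$ exactly, whereas the first form still replaces $x$ by its imperfect reconstruction.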

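6.2 Safety Considerations

  1. Avoid over-steering: steering that is too strong can make the model behave abnormally
  2. Validate outputs: always check that the steered output is reasonable
  3. Log the steering: save the steering parameters so results can be reproduced
  4. Test the boundaries: check how steering behaves in extreme cases

6.3 Best Use Cases

Scenario                    Rating  Reason
--------------------------  ------  ----------------
Sentiment/style adjustment  ★★★★★   Stable effect
Knowledge enhancement       ★★★★☆   May hallucinate
Safety filtering            ★★★☆☆   Can be bypassed
Precise fact editing        ★★☆☆☆   Limited effect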

7. References

Tools