基于SAE的特征操控
概述
Feature Steering(特征操控)是利用SAE学到的可解释特征来精确控制模型行为的技术。与传统的提示工程或微调不同,特征操控直接修改模型的内部激活,从而实现细粒度的行为控制。
核心思想:
- SAE将叠加的表示分解为稀疏、可解释的特征
- 修改特征激活可以直接影响模型的计算和行为
- 通过操控特定特征,可以实现情感控制、知识编辑、安全干预等
1. 特征操控基础
1.1 操控原理
原始前向传播:
输入 → 激活 x → [模型] → 输出
SAE编码后:
输入 → 激活 x → [SAE编码器] → 特征 f(x)
↓
操控特征 f'(x)
↓
操控特征 f'(x) → [SAE解码器] → 修改激活 x' → [模型] → 输出
效果: 模型行为被精确修改
1.2 操控类型
| 操控类型 | 描述 | 典型应用 |
|---|---|---|
| 增强 | 增加特征激活值 | 增强某种能力 |
| 抑制 | 减少特征激活值 | 消除有害内容 |
| 替换 | 用另一个值替换 | 改变特征含义 |
| 组合 | 同时操控多个特征 | 复杂行为控制 |
1.3 数学形式
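设第 $i$ 个特征的原始激活为 $f_i$,操控后的激活为 $f_i'$:
增强:$f_i' = f_i + \delta$($\delta > 0$ 为增强量)
抑制:$f_i' = \alpha f_i$($0 \le \alpha < 1$,$\alpha = 0$ 表示完全抑制)
替换:$f_i' = v$($v$ 为指定的目标值)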
2. 操控方法
2.1 基础操控
import torch
import torch.nn.functional as F
from typing import Optional
class SAEFeatureSteering:
    """Steer a model by editing individual SAE feature activations.

    Every public method follows the same pipeline: encode the activations
    with the SAE, edit one or more feature coordinates, decode back to
    activation space and pass the decoded tensor to ``model``.

    The decoded tensor replaces ``x`` entirely, so whatever the SAE fails to
    reconstruct is also missing from the model input.
    """

    # Edit modes understood by ``multi_feature_steering``.
    _STEERING_TYPES = ("enhance", "suppress", "set")

    def __init__(self, sae, model):
        """Store the SAE and the model and remember the SAE's device.

        Args:
            sae: Object providing ``encode``, ``decode`` and ``parameters()``.
            model: Callable applied to the decoded activations.
        """
        self.sae = sae
        self.model = model
        self.device = next(sae.parameters()).device

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return a private copy that is safe to edit."""
        # Clone so the in-place edits never write into a tensor the SAE (or
        # the caller) may still be holding a reference to.
        return self.sae.encode(x).clone()

    def _decode_and_run(self, features: torch.Tensor) -> torch.Tensor:
        """Decode edited features and feed the result to the model."""
        return self.model(self.sae.decode(features))

    def enhance_feature(
        self,
        x: torch.Tensor,
        feature_idx: int,
        delta: float,
    ) -> torch.Tensor:
        """Add ``delta`` to one feature's activation.

        Args:
            x: Input activations; the feature axis of the encoding is assumed
                to be the last one (e.g. ``[batch, seq_len, d_model]`` in).
            feature_idx: Index of the feature to edit.
            delta: Amount added to the feature activation.

        Returns:
            Whatever ``model`` returns for the decoded, modified activations.
        """
        features = self._encode(x)
        # ``...`` instead of ``[:, :, idx]`` so 2-D and 3-D inputs both work.
        features[..., feature_idx] = features[..., feature_idx] + delta
        return self._decode_and_run(features)

    def suppress_feature(
        self,
        x: torch.Tensor,
        feature_idx: int,
        alpha: float = 0.0,
    ) -> torch.Tensor:
        """Scale one feature's activation by ``alpha``.

        Args:
            x: Input activations.
            feature_idx: Index of the feature to edit.
            alpha: Multiplier (0 = remove the feature, 1 = leave unchanged).

        Returns:
            Whatever ``model`` returns for the decoded, modified activations.
        """
        features = self._encode(x)
        features[..., feature_idx] = features[..., feature_idx] * alpha
        return self._decode_and_run(features)

    def set_feature(
        self,
        x: torch.Tensor,
        feature_idx: int,
        value: float,
    ) -> torch.Tensor:
        """Overwrite one feature's activation with the constant ``value``.

        Args:
            x: Input activations.
            feature_idx: Index of the feature to edit.
            value: Activation written at every position.

        Returns:
            Whatever ``model`` returns for the decoded, modified activations.
        """
        features = self._encode(x)
        features[..., feature_idx] = value
        return self._decode_and_run(features)

    def multi_feature_steering(
        self,
        x: torch.Tensor,
        modifications: dict[int, float],  # feature_idx -> amount
        steering_type: str = "enhance",
    ) -> torch.Tensor:
        """Apply the same kind of edit to several features at once.

        Args:
            x: Input activations.
            modifications: Mapping ``{feature_idx: amount}``.
            steering_type: One of
                ``"enhance"`` (add ``amount``),
                ``"suppress"`` (multiply by ``1 - amount``; note this differs
                from ``suppress_feature``, which multiplies by ``alpha``), or
                ``"set"`` (overwrite with ``amount``).

        Returns:
            Whatever ``model`` returns for the decoded, modified activations.

        Raises:
            ValueError: If ``steering_type`` is not a supported mode. The
                check exists because an unrecognised mode would otherwise
                leave every feature untouched without any warning.
        """
        if steering_type not in self._STEERING_TYPES:
            raise ValueError(
                f"unknown steering_type {steering_type!r}; "
                f"expected one of {self._STEERING_TYPES}"
            )

        features = self._encode(x)
        for feat_idx, amount in modifications.items():
            current = features[..., feat_idx]
            if steering_type == "enhance":
                features[..., feat_idx] = current + amount
            elif steering_type == "suppress":
                features[..., feat_idx] = current * (1 - amount)
            else:  # "set"
                features[..., feat_idx] = amount
        return self._decode_and_run(features)


# 2.2 Directional Steering
class DirectionalSteering:
    """Steer activations directly along SAE decoder directions.

    Unlike encode/edit/decode steering, these methods add vectors to the
    activations themselves and return the modified activations; ``model`` is
    stored but not called here.
    """

    def __init__(self, sae, model):
        """Store the SAE and the model and remember the SAE's device.

        Args:
            sae: Object exposing ``decoder.weight`` and ``parameters()``.
            model: Stored for callers; not used by the methods of this class.
        """
        self.sae = sae
        self.model = model
        self.device = next(sae.parameters()).device

    def compute_feature_direction(
        self,
        feature_idx: int,
        n_samples: int = 100,
    ) -> torch.Tensor:
        """Return the unit-norm decoder direction of one feature.

        Args:
            feature_idx: Column of the decoder weight to read.
            n_samples: Unused; kept so existing callers keep working.

        Returns:
            The selected decoder column divided by its L2 norm.
        """
        # The decoder weight is indexed as [d_model, n_features], so one
        # column is the activation-space direction written by one feature.
        decoder_weights = self.sae.decoder.weight.data
        column = decoder_weights[:, feature_idx]
        # Epsilon guards against division by zero for a dead (all-zero) column.
        return column / (column.norm() + 1e-8)

    def steering_by_direction(
        self,
        x: torch.Tensor,
        feature_indices: list[int],
        strengths: list[float],
        alpha: float = 0.5,
    ) -> torch.Tensor:
        """Shift activations along a weighted sum of feature directions.

        The steering vector is ``sum_i strengths[i] * direction_i`` and the
        result is ``x + alpha * steering_vector``.

        The previous implementation normalised the combined direction and
        then moved by ``sum(strengths)`` along it. That pushed a single
        feature with negative strength in the positive direction, and
        applied no shift at all when strengths of opposite sign cancelled.
        Using the weighted sum directly keeps the sign and magnitude of every
        contribution; for one feature with positive strength the result is
        unchanged.

        Args:
            x: Input activations, last axis of size ``d_model``.
            feature_indices: Features whose directions are combined.
            strengths: Signed strength for each entry of ``feature_indices``.
            alpha: Fraction of the steering vector applied
                (0 = return ``x`` unchanged, 1 = apply it fully).

        Returns:
            The modified activations, same shape as ``x``.

        Raises:
            ValueError: If the two lists differ in length (``zip`` would
                otherwise drop the extra entries silently).
        """
        if len(feature_indices) != len(strengths):
            raise ValueError(
                "feature_indices and strengths must have the same length "
                f"(got {len(feature_indices)} and {len(strengths)})"
            )

        # Build on x's device/dtype so x may live elsewhere than the SAE.
        steering_vector = torch.zeros(
            x.shape[-1], device=x.device, dtype=x.dtype
        )
        for feat_idx, strength in zip(feature_indices, strengths):
            direction = self.compute_feature_direction(feat_idx)
            direction = direction.to(device=x.device, dtype=x.dtype)
            steering_vector = steering_vector + strength * direction

        return x + alpha * steering_vector

    def directional_suppress(
        self,
        x: torch.Tensor,
        feature_idx: int,
        strength: float = 1.0,
    ) -> torch.Tensor:
        """Remove (part of) the component of ``x`` along a feature direction.

        Args:
            x: Input activations, last axis of size ``d_model``.
            feature_idx: Feature whose direction is suppressed.
            strength: Fraction of the projection that is subtracted; 1.0
                removes the component entirely, values above 1 overshoot to
                the opposite side.

        Returns:
            The modified activations, same shape as ``x``.
        """
        direction = self.compute_feature_direction(feature_idx)
        direction = direction.to(device=x.device, dtype=x.dtype)
        # Scalar projection of every activation vector onto the direction.
        projection = (x * direction).sum(dim=-1, keepdim=True)
        return x - strength * projection * direction


# 2.3 Adaptive Steering
class AdaptiveSteering:
    """Feature steering whose targets are derived from dataset statistics."""

    def __init__(self, sae, model):
        """Store the SAE and the model; statistics start out empty.

        Args:
            sae: Object providing ``encode``, ``decode`` and ``parameters()``.
            model: Callable applied to the decoded activations.
        """
        self.sae = sae
        self.model = model
        self.device = next(sae.parameters()).device
        # Filled by compute_feature_statistics(); maps stat name -> tensor
        # with one entry per feature.
        self.feature_stats = {}

    def compute_feature_statistics(
        self,
        dataset: torch.Tensor,
    ):
        """Compute per-feature statistics over a dataset of activations.

        Each item yielded by ``dataset`` is encoded and flattened to
        ``[n_positions, n_features]`` before concatenation. Flattening makes
        items with different batch/sequence sizes compatible, and lets
        ``median``/``quantile`` reduce over a single integer ``dim`` (they do
        not accept a tuple of dims).

        Args:
            dataset: Iterable of activation tensors whose last axis is
                ``d_model``.

        Raises:
            ValueError: If ``dataset`` yields no items.
        """
        flattened = []
        with torch.no_grad():
            for x in dataset:
                features = self.sae.encode(x)
                flattened.append(features.reshape(-1, features.shape[-1]))

        if not flattened:
            raise ValueError("dataset is empty; cannot compute statistics")

        all_features = torch.cat(flattened, dim=0)

        self.feature_stats = {
            "mean": all_features.mean(dim=0),
            "std": all_features.std(dim=0),
            "median": all_features.median(dim=0).values,
            "q25": all_features.quantile(0.25, dim=0),
            "q75": all_features.quantile(0.75, dim=0),
        }

    def adaptive_enhance(
        self,
        x: torch.Tensor,
        feature_idx: int,
        target_percentile: float = 0.9,
    ) -> torch.Tensor:
        """Raise a feature's activation up to a statistics-derived target.

        The target is the heuristic ``q75 * (1 + target_percentile)``; only
        the 25th/50th/75th percentiles are stored, so it is not an exact
        percentile lookup. Activations below the target are lifted to it and
        activations already at or above it are left alone.

        The previous implementation had the comparison inverted: it modified
        only activations already above the target and left low ones as is.

        Args:
            x: Input activations.
            feature_idx: Index of the feature to enhance.
            target_percentile: Value in ``[0, 1]`` scaling the target above
                the stored 75th percentile.

        Returns:
            Whatever ``model`` returns for the decoded, modified activations.

        Raises:
            RuntimeError: If ``compute_feature_statistics`` has not been run.
        """
        if "q75" not in self.feature_stats:
            raise RuntimeError(
                "feature statistics are missing; call "
                "compute_feature_statistics() first"
            )

        features = self.sae.encode(x).clone()
        current = features[..., feature_idx]
        target_value = self.feature_stats["q75"][feature_idx] * (
            1 + target_percentile
        )
        # Element-wise max == "lift everything below the target up to it".
        features[..., feature_idx] = torch.maximum(
            current, target_value.to(device=current.device, dtype=current.dtype)
        )

        modified_x = self.sae.decode(features)
        return self.model(modified_x)

    def clamp_feature(
        self,
        x: torch.Tensor,
        feature_idx: int,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None,
    ) -> torch.Tensor:
        """Clamp one feature's activation into ``[min_value, max_value]``.

        Args:
            x: Input activations.
            feature_idx: Index of the feature to clamp.
            min_value: Lower bound, or ``None`` for no lower bound.
            max_value: Upper bound, or ``None`` for no upper bound.

        Returns:
            Whatever ``model`` returns for the decoded, modified activations.
        """
        features = self.sae.encode(x).clone()
        # Bounds are applied separately because both may be None.
        if min_value is not None:
            features[..., feature_idx] = torch.clamp_min(
                features[..., feature_idx], min_value
            )
        if max_value is not None:
            features[..., feature_idx] = torch.clamp_max(
                features[..., feature_idx], max_value
            )

        modified_x = self.sae.decode(features)
        return self.model(modified_x)


# 3. Application Scenarios
3.1 情感控制
def control_sentiment(
    steering: SAEFeatureSteering,
    model,
    tokenizer,
    text: str,
    target_sentiment: str,  # "positive" or "negative"
    feature_idx: int,
    strength: float = 2.0,
) -> str:
    """Steer the sentiment of the model's prediction for ``text``.

    The last-layer hidden states are steered through ``steering`` and handed
    back to ``model``; the returned string is the decoded greedy (argmax)
    token at every position of the resulting logits.

    Args:
        steering: Steering helper built around the SAE.
        model: Model called with ``output_hidden_states=True`` and later with
            a ``hidden_states=`` keyword.
        tokenizer: Callable tokenizer that also provides ``decode``.
        text: Input text.
        target_sentiment: ``"positive"`` adds ``strength`` to the feature;
            ``"negative"`` halves it.
        feature_idx: Index of the positive-sentiment feature.
        strength: Amount added to the feature for ``"positive"``.

    Returns:
        The decoded argmax tokens of the first sequence in the batch.

    Raises:
        ValueError: If ``target_sentiment`` is neither ``"positive"`` nor
            ``"negative"``. Previously any other string was silently treated
            as ``"negative"``.
    """
    if target_sentiment not in ("positive", "negative"):
        raise ValueError(
            f"target_sentiment must be 'positive' or 'negative', "
            f"got {target_sentiment!r}"
        )

    inputs = tokenizer(text, return_tensors="pt").to(steering.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    hidden_states = outputs.hidden_states[-1]

    if target_sentiment == "positive":
        modified_hidden = steering.enhance_feature(
            hidden_states, feature_idx, strength
        )
    else:
        # Suppressing the positive feature stands in for "more negative".
        modified_hidden = steering.suppress_feature(
            hidden_states, feature_idx, alpha=0.5
        )

    # NOTE(review): this assumes `model` accepts a `hidden_states=` keyword
    # and that the steering call above returns hidden states rather than
    # final model output -- confirm against the model wrapper in use.
    with torch.no_grad():
        modified_outputs = model(
            inputs["input_ids"],
            hidden_states=(modified_hidden,),
        )

    generated_ids = modified_outputs.logits.argmax(-1)
    return tokenizer.decode(generated_ids[0])


# 3.2 Knowledge Editing
def edit_knowledge(
    steering: SAEFeatureSteering,
    model,
    tokenizer,
    subject: str,
    current_knowledge_idx: int,
    new_knowledge_idx: int,
    context: str,
) -> str:
    """Edit what the model recalls by trading one SAE feature for another.

    Example: move from "Paris is the capital of France" towards knowledge
    about "Rome is the capital of Italy".

    The prompt is ``context`` and ``subject`` joined by a single space. The
    last-layer hidden states get -2.0 added to the old-knowledge feature and
    +2.0 added to the new-knowledge feature (both via the additive
    ``"enhance"`` mode), and up to 50 new tokens are generated.

    Args:
        steering: Steering helper built around the SAE.
        model: Model providing ``__call__`` and ``generate``.
        tokenizer: Callable tokenizer that also provides ``decode``.
        subject: Subject text appended to the context.
        current_knowledge_idx: Feature index of the knowledge to weaken.
        new_knowledge_idx: Feature index of the knowledge to strengthen.
        context: Text placed before the subject.

    Returns:
        The decoded first generated sequence.
    """
    prompt = " ".join((context, subject))
    encoded = tokenizer(prompt, return_tensors="pt").to(steering.device)

    with torch.no_grad():
        forward_pass = model(**encoded, output_hidden_states=True)
    last_layer = forward_pass.hidden_states[-1]

    # Negative delta weakens the old fact, positive delta boosts the new one.
    feature_deltas = {
        current_knowledge_idx: -2.0,
        new_knowledge_idx: 2.0,
    }
    steered = steering.multi_feature_steering(
        last_layer, feature_deltas, steering_type="enhance"
    )

    # NOTE(review): assumes `model.generate` accepts `hidden_states=` --
    # confirm against the model wrapper in use.
    with torch.no_grad():
        generated = model.generate(
            encoded["input_ids"],
            hidden_states=(steered,),
            max_new_tokens=50,
        )

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence)


# 3.3 Safety Intervention
class SafetyIntervention:
    """Detect and zero out safety-relevant SAE features."""

    def __init__(
        self,
        sae,
        model,
        safety_feature_indices: list[int],
        threshold: float = 0.5,
    ):
        """Configure the monitored features and the trigger threshold.

        Args:
            sae: SAE handed to ``SAEFeatureSteering``.
            model: Model handed to ``SAEFeatureSteering``.
            safety_feature_indices: Indices of the features to monitor.
            threshold: Activation above which an intervention is triggered.
        """
        self.steering = SAEFeatureSteering(sae, model)
        self.safety_indices = safety_feature_indices
        self.threshold = threshold

    def check_and_intervene(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, bool]:
        """Zero the monitored features when any of them exceeds the threshold.

        Args:
            x: Input activations.

        Returns:
            ``(activations, intervened)``. When an intervention happens the
            activations are the SAE decoding with all monitored features set
            to zero; otherwise ``x`` itself is returned unchanged.
        """
        indices = list(self.safety_indices)
        # With nothing to monitor there is nothing to trigger on; a max over
        # an empty selection would raise, so return early.
        if not indices:
            return x, False

        features = self.steering.sae.encode(x)
        if features[..., indices].max() > self.threshold:
            # Work on a copy so the encoder output is not modified in place.
            features = features.clone()
            features[..., indices] = 0
            return self.steering.sae.decode(features), True

        return x, False

    def get_safety_score(
        self,
        x: torch.Tensor,
    ) -> dict:
        """Report the peak activation of every monitored feature.

        Args:
            x: Input activations.

        Returns:
            Dict with one ``"feature_<idx>"`` entry per monitored feature
            (its maximum activation), ``"overall"`` (the largest of those,
            0.0 when nothing is monitored) and ``"is_safe"`` (whether
            ``overall`` is below the threshold).
        """
        features = self.steering.sae.encode(x)
        scores = {
            f"feature_{idx}": features[..., idx].max().item()
            for idx in self.safety_indices
        }
        scores["overall"] = max(scores.values()) if scores else 0.0
        scores["is_safe"] = scores["overall"] < self.threshold
        return scores


# 3.4 Style Transfer
def convert_style(
    steering: SAEFeatureSteering,
    model,
    tokenizer,
    text: str,
    source_style_idx: int,
    target_style_idx: int,
) -> str:
    """Rewrite ``text`` in another style, e.g. formal to informal.

    The last-layer hidden states get -1.5 added to the source-style feature
    and +1.5 added to the target-style feature (additive ``"enhance"`` mode,
    the default of ``multi_feature_steering``) before generation.

    Args:
        steering: Steering helper built around the SAE.
        model: Model providing ``__call__`` and ``generate``.
        tokenizer: Callable tokenizer that also provides ``decode``.
        text: Input text.
        source_style_idx: Feature index of the style to weaken.
        target_style_idx: Feature index of the style to strengthen.

    Returns:
        The decoded first generated sequence.
    """
    inputs = tokenizer(text, return_tensors="pt").to(steering.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    hidden = outputs.hidden_states[-1]

    modifications = {
        source_style_idx: -1.5,
        target_style_idx: 1.5,
    }
    modified_hidden = steering.multi_feature_steering(
        hidden, modifications
    )

    # Generation is inference-only, so run it without autograd bookkeeping,
    # matching how edit_knowledge() calls generate().
    # NOTE(review): assumes `model.generate` accepts `hidden_states=` --
    # confirm against the model wrapper in use.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            hidden_states=(modified_hidden,),
        )

    return tokenizer.decode(outputs[0])


# 4. Evaluating Steering Effectiveness
4.1 评估指标
def evaluate_steering_effectiveness(
    steering: SAEFeatureSteering,
    model,
    test_cases: list[dict],
    target_feature_idx: int,
) -> dict:
    """Measure how often steering produces the expected change.

    Each test case is a dict of the form::

        {
            "input": str,
            "original_output": str,
            "expected_change": str,
        }

    Only ``"input"`` and ``"expected_change"`` are read. A case counts as
    successful when ``expected_change`` occurs in the steered output.

    Args:
        steering: Steering helper built around the SAE.
        model: Model passed to the output/hidden-state helpers.
        test_cases: Cases to evaluate.
        target_feature_idx: Feature enhanced (by a fixed delta of 2.0).

    Returns:
        Dict with ``successful_changes``, ``total_cases``,
        ``behavioral_changes`` (one record per case),
        ``feature_activation_changes`` (currently always empty) and
        ``success_rate`` (0.0 when there are no test cases).
    """
    results = {
        "successful_changes": 0,
        "total_cases": len(test_cases),
        "behavioral_changes": [],
        "feature_activation_changes": [],
    }

    for case in test_cases:
        prompt = case["input"]
        original_out = get_model_output(model, prompt)

        steered_hidden = steering.enhance_feature(
            get_hidden_states(model, prompt),
            target_feature_idx,
            delta=2.0,
        )
        # Convert the steered result with the same helper the other
        # evaluation functions in this file use, so that both sides of the
        # comparison below come from the same kind of helper.
        # NOTE(review): helper signature taken from its other call sites.
        steered_out = get_model_output_from_hidden(
            model, steered_hidden, prompt
        )

        changed = original_out != steered_out
        expected_change_achieved = case["expected_change"] in steered_out

        results["behavioral_changes"].append({
            "input": prompt,
            "original": original_out,
            "steered": steered_out,
            "changed": changed,
            "expected_achieved": expected_change_achieved,
        })
        if expected_change_achieved:
            results["successful_changes"] += 1

    # Guard the division: an empty test set has no meaningful rate.
    total = results["total_cases"]
    results["success_rate"] = (
        results["successful_changes"] / total if total else 0.0
    )
    return results


# 4.2 Side-Effect Detection
def detect_steering_side_effects(
    steering: SAEFeatureSteering,
    model,
    neutral_inputs: list[str],
    target_feature_idx: int,
    delta: float = 2.0,
) -> dict:
    """Check whether steering changes outputs for unrelated inputs.

    Every neutral input is run with and without steering; when the two
    outputs have a ``difflib.SequenceMatcher`` ratio below 0.9 the input is
    recorded as a side effect.

    Args:
        steering: Steering helper built around the SAE.
        model: Model passed to the output/hidden-state helpers.
        neutral_inputs: Inputs the steering is not supposed to affect.
        target_feature_idx: Feature that is enhanced.
        delta: Amount added to the feature.

    Returns:
        Dict with ``n_side_effects``, ``side_effect_rate`` (0.0 when
        ``neutral_inputs`` is empty) and ``side_effect_examples``.
    """
    # Imported here because the module-level imports do not include difflib.
    import difflib

    side_effects = []
    for inp in neutral_inputs:
        original = get_model_output(model, inp)

        hidden = get_hidden_states(model, inp)
        steered_hidden = steering.enhance_feature(
            hidden, target_feature_idx, delta
        )
        steered = get_model_output_from_hidden(model, steered_hidden, inp)

        similarity = difflib.SequenceMatcher(
            None, original, steered
        ).ratio()

        # Below 0.9 the outputs differ enough to count as a side effect.
        if similarity < 0.9:
            side_effects.append({
                "input": inp,
                "original": original,
                "steered": steered,
                "similarity": similarity,
            })

    n_inputs = len(neutral_inputs)
    return {
        "n_side_effects": len(side_effects),
        # Guard the division: with no inputs there is no meaningful rate.
        "side_effect_rate": len(side_effects) / n_inputs if n_inputs else 0.0,
        "side_effect_examples": side_effects,
    }


# 5. Best Practices
5.1 操控强度选择
def find_optimal_steering_strength(
    steering: SAEFeatureSteering,
    model,
    test_input: str,
    target_feature: int,
    target_behavior: str,
    strength_range: tuple = (0.1, 5.0),
) -> Optional[float]:
    """Bisect for the smallest steering strength that shows the behaviour.

    Runs ten bisection steps over ``strength_range``. A strength counts as
    effective when ``target_behavior`` occurs in the steered output. The
    search assumes effectiveness is monotonic in the strength, and only
    midpoints are evaluated -- the two range endpoints themselves are never
    tested.

    Args:
        steering: Steering helper built around the SAE.
        model: Model passed to the output/hidden-state helpers.
        test_input: Prompt used for every trial.
        target_feature: Feature that is enhanced.
        target_behavior: Substring expected in a successful output.
        strength_range: ``(low, high)`` bounds of the search.

    Returns:
        The smallest effective strength found, or ``None`` when no tested
        strength produced the behaviour.
    """
    low, high = strength_range
    best_strength = None

    # The hidden states depend only on the prompt, so compute them once
    # instead of on every bisection step.
    hidden = get_hidden_states(model, test_input)

    for _ in range(10):  # at most 10 bisection steps
        mid = (low + high) / 2

        steered_hidden = steering.enhance_feature(
            hidden, target_feature, mid
        )
        output = get_model_output_from_hidden(
            model, steered_hidden, test_input
        )

        if target_behavior in output:
            best_strength = mid
            high = mid  # effective: try something weaker
        else:
            low = mid  # not effective: a stronger push is needed

    return best_strength


# 5.2 Combined Steering Strategies
class CombinationSteering:
    """Combine several feature edits into one steering operation."""

    # Goals understood by ``multi_objective_steering``.
    _GOALS = ("enhance", "suppress", "set")

    def __init__(self, sae, model):
        """Store the SAE and the model and build a basic steering helper.

        Args:
            sae: Object providing ``encode`` and ``decode``.
            model: Model handed to ``SAEFeatureSteering``.
        """
        self.sae = sae
        self.model = model
        self.steering = SAEFeatureSteering(sae, model)

    def multi_objective_steering(
        self,
        x: torch.Tensor,
        objectives: list[dict],  # [{feature_idx, strength, goal}, ...]
    ) -> torch.Tensor:
        """Apply several differently-typed feature edits in one pass.

        Example ``objectives``::

            [
                {"feature_idx": 123, "strength": 1.0, "goal": "enhance"},
                {"feature_idx": 456, "strength": 0.5, "goal": "suppress"},
            ]

        ``"enhance"`` adds ``strength``, ``"suppress"`` multiplies by
        ``1 - strength`` and ``"set"`` overwrites with ``strength``.
        Objectives are applied in list order.

        Args:
            x: Input activations.
            objectives: Edits to apply.

        Returns:
            The decoded, modified activations (the model is not called).

        Raises:
            ValueError: If an objective has an unsupported ``goal``; such
                objectives used to be skipped silently.
        """
        # Copy so the encoder output is not modified in place.
        features = self.sae.encode(x).clone()

        for obj in objectives:
            idx = obj["feature_idx"]
            strength = obj["strength"]
            goal = obj["goal"]

            if goal == "enhance":
                features[..., idx] = features[..., idx] + strength
            elif goal == "suppress":
                features[..., idx] = features[..., idx] * (1 - strength)
            elif goal == "set":
                features[..., idx] = strength
            else:
                raise ValueError(
                    f"unknown goal {goal!r}; expected one of {self._GOALS}"
                )

        return self.sae.decode(features)

    def gradual_steering(
        self,
        x: torch.Tensor,
        modifications: dict,
        n_steps: int = 5,
    ) -> list[torch.Tensor]:
        """Apply the same edits at linearly increasing strength.

        Step ``k`` (1-based) applies every modification scaled by
        ``k / n_steps`` as an additive ``"enhance"`` edit to the original
        ``x``; steps do not build on one another.

        The ``{feature_idx: strength}`` mapping is converted into the
        objective list that ``multi_objective_steering`` expects. Passing
        the mapping through unchanged made that method index into integer
        keys and fail with a ``TypeError``.

        Args:
            x: Input activations.
            modifications: Mapping ``{feature_idx: full_strength}``.
            n_steps: Number of steps; the last one uses the full strength.

        Returns:
            One modified activation tensor per step, weakest first.
        """
        results = []
        for step in range(1, n_steps + 1):
            scale = step / n_steps
            objectives = [
                {
                    "feature_idx": idx,
                    "strength": strength * scale,
                    "goal": "enhance",
                }
                for idx, strength in modifications.items()
            ]
            results.append(self.multi_objective_steering(x, objectives))
        return results


# 6. Limitations and Caveats
6.1 已知局限性
| 局限性 | 描述 | 影响 |
|---|---|---|
| 特征干扰 | 修改一个特征可能影响其他特征 | 操控不精确 |
| 上下文依赖 | 同一特征在不同上下文效果不同 | 效果不稳定 |
| 解码误差 | SAE解码不是完全可逆的 | 信息损失 |
| 层级差异 | 不同层特征含义不同 | 需要选择合适层 |
6.2 安全考虑
- 避免过度操控:过强的操控可能导致模型行为异常
- 验证输出:始终检查操控后的输出是否合理
- 记录操控:保存操控参数以便复现
- 测试边界:测试操控在极端情况下的表现
6.3 最佳使用场景
| 场景 | 推荐程度 | 原因 |
|---|---|---|
| 情感/风格调整 | ★★★★★ | 效果稳定 |
| 知识增强 | ★★★★☆ | 可能有幻觉 |
| 安全过滤 | ★★★☆☆ | 可能绕过 |
| 精确事实编辑 | ★★☆☆☆ | 效果有限 |
7. 参考
论文
- Bricken et al. “Towards Monosemanticity” (2023)
- Templeton et al. “Scaling Monosemanticity” (2024)
- Lieberum et al. “Gemma Scope” (2024)
- “Feature Steering for LLMs” (various works)