LLM中的因果公平性

1. 引言

大型语言模型(LLMs)在训练过程中会从互联网数据中学习,这可能导致模型继承并放大社会偏见。因果视角为理解和解决LLM中的偏见问题提供了独特的工具。1

1.1 LLM中的偏见问题

LLM偏见来源与传播:

训练数据 (互联网文本)
        ↓
    偏见模式 (种族、性别、职业等)
        ↓
    模型学习
        ↓
    输出偏见 (文本生成、分类、问答)

1.2 为什么需要因果方法

方法类型能力局限
统计方法测量关联强度无法区分因果vs混杂
因果方法识别因果路径需要因果假设
混合方法两者结合计算复杂度更高

因果方法可以回答:

  • 偏见的真正来源是什么?
  • 改变训练数据分布能否减少偏见?
  • 模型内部如何编码偏见信息?

2. LLM偏见的因果建模

2.1 因果图视角

class LLM_Bias_Causal_Graph:
    """
    LLM偏见的因果图建模
    
    节点:
    - S: 社会群体 (敏感属性)
    - C: 上下文/提示
    - E: 实体提及
    - P: 模型内部表示
    - Y: 模型输出
    """
    
    def build_graph(self, scenario='occupation'):
        if scenario == 'occupation':
            # 职业偏见场景
            nodes = {
                'S': 'social_group',  # 敏感属性
                'C': 'context',       # 输入上下文
                'E': 'entity',        # 实体提及
                'P': 'internal_repr', # 内部表示
                'Y': 'output'        # 输出
            }
            
            edges = [
                # 偏见路径
                ('S', 'E'),  # 群体→实体(社会关联)
                ('E', 'Y'),  # 实体→输出(直接偏见)
                
                # 合法路径
                ('C', 'E'),  # 上下文→实体
                ('C', 'Y'),  # 上下文→输出
                
                # 混杂路径
                ('S', 'C'),  # 社会→上下文(选择性曝光)
            ]
        
        return nodes, edges

2.2 偏见类型分类

偏见类型因果路径示例
直接歧视直接使用群体信息
间接歧视通过刻板印象实体
历史偏见数据中的历史偏见
代理偏见通过代理变量

2.3 因果偏见的数学形式化

总体输出差异

因果分解


3. 因果偏见测试

3.1 Counterfactual Fairness Testing

核心思想:通过反事实查询测试模型对敏感属性的依赖程度。

class CausalBiasTester:
    def __init__(self, llm):
        self.llm = llm
    
    def counterfactual_fairness_test(self, prompt, sensitive_attr, 
                                     attr_values):
        """
        反事实公平性测试
        
        比较模型在不同敏感属性值下的输出
        """
        results = {}
        
        for value in attr_values:
            # 构建反事实提示
            counterfactual_prompt = self._swap_attribute(
                prompt, sensitive_attr, value
            )
            
            # 获取模型输出
            output = self.llm.generate(counterfactual_prompt)
            results[value] = output
        
        # 计算公平性得分
        fairness_score = self._compute_fairness_score(results)
        
        return {
            'outputs': results,
            'fairness_score': fairness_score
        }
    
    def _swap_attribute(self, prompt, attr, new_value):
        """替换提示中的敏感属性"""
        # 简单实现:替换特定词汇
        mappings = {
            'male': {'he': 'she', 'man': 'woman', 'his': 'her'},
            'female': {'she': 'he', 'woman': 'man', 'her': 'his'},
            'white': {'white': 'Black', 'European': 'African'},
            'Black': {'Black': 'white', 'African': 'European'}
        }
        
        result = prompt
        for old, new in mappings.get(attr, {}).items():
            result = result.replace(old, new)
        
        return result

3.2 因果归因测试

class CausalAttributionTest:
    """
    因果归因测试
    识别哪些因素导致输出偏见
    """
    
    def test_causal_attribution(self, prompt, output):
        """
        使用因果归因方法识别偏见来源
        """
        # 1. 定义因果图
        causal_graph = self._build_causal_graph(prompt)
        
        # 2. 识别关键路径
        key_paths = self._identify_key_paths(
            causal_graph, 
            target='output_bias'
        )
        
        # 3. 估计因果效应
        causal_effects = self._estimate_causal_effects(
            causal_graph, 
            key_paths
        )
        
        return {
            'causal_graph': causal_graph,
            'key_paths': key_paths,
            'causal_effects': causal_effects
        }
    
    def _build_causal_graph(self, prompt):
        """从提示构建因果图"""
        # 提取关键实体和关系
        entities = self._extract_entities(prompt)
        relations = self._extract_relations(prompt)
        
        # 构建图
        graph = nx.DiGraph()
        graph.add_nodes_from(entities)
        graph.add_edges_from(relations)
        
        return graph

3.3 CausalBias框架

NAACL 2025: Causally Testing Gender Bias in LLMs2

核心贡献

  1. 形式化因果偏见测量的定义
  2. 区分
    • 统计关联偏见(
    • 因果偏见(

测试方法

class CausalBiasFramework:
    """
    因果偏见测试框架
    
    基于反事实推理:
    P(Y_{A←a} = 1) vs P(Y_{A←a'} = 1)
    """
    
    def measure_causal_bias(self, model, test_pairs):
        """
        测量因果偏见
        
        Args:
            model: LLM模型
            test_pairs: 反事实测试对 [(prompt_a, prompt_a'), label]
        
        Returns:
            causal_bias_score: 因果偏见得分
        """
        biases = []
        
        for (prompt_a, prompt_a_prime), label in test_pairs:
            # 获取两个版本的输出
            prob_a = self._get_output_prob(model, prompt_a, label)
            prob_a_prime = self._get_output_prob(model, prompt_a_prime, label)
            
            # 因果偏见 = 反事实概率差异
            causal_diff = abs(prob_a - prob_a_prime)
            biases.append(causal_diff)
        
        return np.mean(biases)
    
    def _get_output_prob(self, model, prompt, label):
        """获取模型输出特定标签的概率"""
        # 使用模型API获取token概率
        tokens = model.tokenizer.encode(label)
        logits = model(prompt).logits
        
        # 计算标签概率
        label_prob = torch.softmax(logits[-1], dim=-1)[tokens].item()
        
        return label_prob

4. 因果去偏方法

4.1 因果Prompting

核心思想:通过精心设计的prompt引导模型进行因果推理,减少偏见。

class CausalPrompting:
    """
    因果Prompting去偏方法
    
    通过明确引导因果推理来减少偏见
    """
    
    def generate_fair_prompt(self, original_prompt, context_type):
        """
        生成公平性导向的prompt
        """
        causal_templates = {
            'hiring': """
            分析以下招聘决策时,请注意:
            1. 评估每个候选人的技能和经验,而非其背景
            2. 假设类似背景的候选人应有相似的能力
            3. 忽略与工作无关的人口统计特征
            
            候选人信息:{original_prompt}
            
            决策建议:
            """,
            
            'loan': """
            在评估贷款申请时:
            1. 基于申请人的信用历史和收入
            2. 不考虑种族、性别等受保护特征
            3. 假设类似财务状况的申请人应有相似结果
            
            申请信息:{original_prompt}
            
            审批建议:
            """,
            
            'housing': """
            在评估住房申请时:
            1. 基于申请人的资格和信用
            2. 忽略与住房需求无关的特征
            3. 假设类似需求的申请人应获同等机会
            
            申请信息:{original_prompt}
            
            决定:
            """
        }
        
        return causal_templates.get(context_type, original_prompt)

4.2 因果干预微调

class CausalInterventionFineTuning:
    """
    因果干预微调
    
    在微调过程中引入因果干预信号
    """
    
    def __init__(self, model, causal_graph):
        self.model = model
        self.causal_graph = causal_graph
        self.unfair_paths = self._identify_unfair_paths()
    
    def train(self, train_data, n_epochs=3, lr=1e-5):
        """
        因果干预训练
        
        目标:使模型对沿不公平路径的变化不敏感
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        
        for epoch in range(n_epochs):
            for batch in train_data:
                # 标准语言建模损失
                lm_loss = self._compute_lm_loss(batch)
                
                # 因果干预损失
                intervention_loss = self._compute_intervention_loss(batch)
                
                # 总损失
                total_loss = lm_loss + self.alpha * intervention_loss
                
                total_loss.backward()
                optimizer.step()
                optimizer.zero_grad()
    
    def _compute_intervention_loss(self, batch):
        """
        计算因果干预损失
        
        L_intervention = Σ ||f(X) - f(X_{A←a'})||²
        沿不公平路径干预后,输出应保持不变
        """
        X = batch['input']
        A = batch['sensitive_attr']
        
        # 获取原始输出
        output_orig = self.model(X)
        
        # 对敏感属性进行干预
        intervention_loss = 0
        for path in self.unfair_paths:
            X_intervened = self._intervene(X, path, A)
            output_intervened = self.model(X_intervened)
            
            # 鼓励输出不变
            intervention_loss += F.mse_loss(output_orig, output_intervened)
        
        return intervention_loss

4.3 因果对比学习

class CausalContrastiveLearning:
    """
    因果对比学习
    
    学习不依赖于敏感属性的公平表示
    """
    
    def __init__(self, encoder, projection_dim=128):
        self.encoder = encoder
        self.projector = nn.Linear(encoder.hidden_dim, projection_dim)
    
    def contrastive_loss(self, z1, z2, temperature=0.1):
        """
        因果对比损失
        
        正样本对:同一实体的不同公平视角
        负样本对:不同实体的表示
        """
        # 投影
        h1 = self.projector(z1)
        h2 = self.projector(z2)
        
        # L2归一化
        h1 = F.normalize(h1, dim=-1)
        h2 = F.normalize(h2, dim=-1)
        
        # 计算相似度
        sim = (h1 @ h2.T) / temperature
        
        # InfoNCE损失
        labels = torch.arange(len(h1)).to(h1.device)
        loss = F.cross_entropy(sim, labels)
        
        return loss
    
    def train_step(self, batch):
        """
        训练步骤
        """
        # 编码
        z1 = self.encoder(batch['text_a'])
        z2 = self.encoder(batch['text_b'])
        
        # 因果对比损失
        loss = self.contrastive_loss(z1, z2)
        
        # 反事实一致性损失
        cf_loss = self._counterfactual_consistency(batch)
        
        return loss + self.beta * cf_loss

5. 测试时公平性

5.1 Test-Time Fairness

Test-Time Fairness(Cotta & Maddison, 2024)3

核心思想:在推理时控制模型输出,而不修改模型权重。

class TestTimeFairness:
    """
    测试时公平性控制
    """
    
    def __init__(self, model, causal_graph):
        self.model = model
        self.causal_graph = causal_graph
    
    def generate_fair(self, prompt, target_group=None, 
                      fairness_constraint='counterfactual'):
        """
        生成公平输出
        
        通过解码策略实现测试时公平性
        """
        if fairness_constraint == 'counterfactual':
            return self._counterfactual_decoding(prompt, target_group)
        elif fairness_constraint == 'group_parity':
            return self._group_parity_decoding(prompt)
        elif fairness_constraint == 'calibration':
            return self._calibrated_decoding(prompt)
    
    def _counterfactual_decoding(self, prompt, target_group):
        """
        反事实解码
        
        对多个反事实版本采样,选择最公平的输出
        """
        # 生成反事实提示变体
        counterfactual_prompts = self._generate_counterfactuals(
            prompt, 
            sensitive_attr=target_group
        )
        
        # 获取所有输出
        outputs = []
        for cf_prompt in counterfactual_prompts:
            output = self.model.generate(cf_prompt)
            outputs.append(output)
        
        # 选择最一致的输出(公平性最高的)
        fair_output = self._select_most_consistent(outputs)
        
        return fair_output
    
    def _generate_counterfactuals(self, prompt, sensitive_attr):
        """
        生成反事实提示变体
        """
        counterfactuals = [prompt]
        
        # 性别变体
        gender_swap = {
            'he': 'she', 'she': 'he',
            'him': 'her', 'her': 'him',
            'his': 'her', 'woman': 'man', 'man': 'woman'
        }
        
        cf_prompt = prompt
        for old, new in gender_swap.items():
            cf_prompt = cf_prompt.replace(old, new)
        counterfactuals.append(cf_prompt)
        
        return counterfactuals

5.2 Causal Decoding

class CausalDecoding:
    """
    因果解码:引导生成远离偏见
    """
    
    def __init__(self, model, bias_direction):
        self.model = model
        # 学习到的偏见方向(高维空间)
        self.bias_direction = bias_direction
    
    def decode_with_bias_correction(self, prompt, max_length=100):
        """
        带偏见校正的解码
        
        在每一步的token概率上减去偏见方向的投影
        """
        input_ids = self.model.tokenizer.encode(prompt, 
                                               return_tensors='pt')
        
        generated = input_ids
        
        for _ in range(max_length):
            # 获取logits
            outputs = self.model(generated)
            logits = outputs.logits[:, -1, :]
            
            # 减去偏见方向
            bias_score = logits @ self.bias_direction
            corrected_logits = logits - self.alpha * bias_score.unsqueeze(0)
            
            # 采样
            probs = F.softmax(corrected_logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            
            generated = torch.cat([generated, next_token], dim=-1)
            
            if next_token.item() == self.model.tokenizer.eos_token_id:
                break
        
        return self.model.tokenizer.decode(generated[0])

6. 实体偏见与因果分析

6.1 实体偏见的因果视角

论文:“A Causal View of Entity Bias in LLMs”(Wang et al., 2023)4

核心洞察:将实体偏见建模为因果图中的特定路径。

class EntityBiasCausalModel:
    """
    实体偏见的因果模型
    
    分析实体如何影响模型输出的偏见
    """
    
    def __init__(self):
        self.bias_types = {
            'stereotype': '社会刻板印象',
            'representation': '群体代表性问题',
            'association': '敏感属性与实体的关联'
        }
    
    def analyze_entity_bias(self, entity_mentions, output):
        """
        分析实体偏见
        
        识别哪些实体的提及导致偏见输出
        """
        # 构建实体-输出因果图
        causal_graph = self._build_entity_output_graph(
            entity_mentions, 
            output
        )
        
        # 识别偏见路径
        bias_paths = []
        for entity in entity_mentions:
            for output_component in output:
                if self._is_bias_path(causal_graph, entity, output_component):
                    bias_paths.append((entity, output_component))
        
        return bias_paths
    
    def _is_bias_path(self, graph, source, target):
        """判断是否存在偏见路径"""
        # 检查是否存在从敏感属性到该实体的路径
        # ...
        pass

6.2 参数化偏见 vs 分布偏见

论文洞察

偏见类型描述因果解释
参数化偏见模型权重中的固有偏见训练过程中学习到的因果关联
分布偏见数据分布导致的偏见输入空间中敏感属性与目标的伪相关

7. 评估与基准

7.1 因果公平性评估指标

class CausalFairnessMetrics:
    """
    因果公平性评估指标
    """
    
    def counterfactual_fairness_score(self, model, test_set):
        """
        反事实公平性得分
        
        CF Score = 1 - E[|P(Y|A=a) - P(Y|A=a')|]
        """
        fairness_scores = []
        
        for item in test_set:
            # 原始和反事实版本
            prob_orig = self._get_label_prob(model, item['prompt'])
            prob_cf = self._get_label_prob(model, item['counterfactual'])
            
            score = 1 - abs(prob_orig - prob_cf)
            fairness_scores.append(score)
        
        return np.mean(fairness_scores)
    
    def causal_disparity_measure(self, model, test_set, causal_graph):
        """
        因果差异度量
        
        区分直接效应和间接效应
        """
        direct_effects = []
        indirect_effects = []
        
        for item in test_set:
            # 使用因果分解估计直接效应和间接效应
            de, ie = self._causal_decomposition(
                model, item, causal_graph
            )
            direct_effects.append(de)
            indirect_effects.append(ie)
        
        return {
            'avg_direct_effect': np.mean(direct_effects),
            'avg_indirect_effect': np.mean(indirect_effects),
            'total_disparity': np.mean(direct_effects) + np.mean(indirect_effects)
        }

7.2 评估基准

基准描述评估维度
BOLD文本生成偏见多个社会维度
RealToxicityPrompts毒性检测毒性偏见
CBBQ因果偏见测试因果公平性
CausalBiasBench因果偏见综合评估多维度因果分析

8. 最新研究进展

8.1 FairPFN (ICML 2025)

Tabular Foundation Model for Causal Fairness

核心思想

  • 将Transformer应用于表格数据的因果公平性
  • 实现可扩展的反事实公平分析

8.2 Causal Logistic Bandits (ICML 2025)

将公平性约束扩展到顺序决策场景

  • 在Bandit框架下建模公平性
  • 处理探索-利用权衡中的公平性约束

8.3 Prompting Fairness

Integration Causality to Debias LLMs

class FairPromptGenerator:
    """
    公平性导向的Prompt生成器
    """
    
    def generate(self, task_description, fairness_principles):
        """
        基于公平性原则生成Prompt
        """
        system_prompt = f"""
        You are an AI assistant focused on fairness and non-discrimination.
        
        Key principles to follow:
        {fairness_principles}
        
        When responding:
        1. Consider multiple perspectives
        2. Avoid stereotypes and assumptions
        3. Focus on relevant qualifications
        4. Apply consistent standards
        """
        
        return system_prompt + "\n\n" + task_description

9. 实践指南

9.1 LLM公平性测试清单

llm_fairness_testing_checklist = """
LLM公平性测试清单:
 
□ 因果偏见识别
  □ 定义敏感属性和目标变量
  □ 构建因果图
  □ 识别不公平路径
 
□ 反事实测试
  □ 生成反事实测试对
  □ 测量输出一致性
  □ 计算因果偏见得分
 
□ 偏见来源分析
  □ 识别关键偏见路径
  □ 估计因果效应大小
  □ 分析组内vs组间偏见
 
□ 去偏方法
  □ Prompt层面:因果Prompting
  □ 微调层面:因果干预训练
  □ 推理层面:测试时公平性
 
□ 评估验证
  □ 使用标准基准
  □ 测量公平性-实用性权衡
  □ 对比不同方法的效果

9.2 注意事项

  1. 因果假设的重要性:确保因果图假设合理
  2. 反事实的可行性:并非所有反事实都有意义
  3. 公平性的多维性:不同维度可能存在冲突的公平性目标
  4. 动态性:偏见可能随时间和上下文变化

10. 参考文献


相关主题

Footnotes

  1. Wu, A., et al. (2024). Causality for Large Language Models. arXiv.

  2. Chen, Y., et al. (2025). Causally Testing Gender Bias in LLMs. NAACL Findings.

  3. Cotta, L., & Maddison, C. (2024). Test-Time Fairness and Robustness in LLMs.

  4. Wang, F., et al. (2023). A Causal View of Entity Bias in LLMs.