Overview

Conditional Random Fields (CRFs) are discriminative probabilistic graphical models that are particularly well suited to sequence labeling tasks.[^1]

Comparison with Generative Models

| Property | Generative (HMM, Naive Bayes) | Discriminative (CRF) |
| --- | --- | --- |
| What is modeled | Joint $P(X, Y)$ | Conditional $P(Y \mid X)$ |
| Feature use | Limited | Arbitrary features |
| Computational cost | Lower | Higher |
| Flexibility | Restricted | High |
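A toy numpy example makes the first row concrete: a generative model estimates the joint $P(X, Y)$ and obtains $P(Y \mid X)$ by conditioning, whereas a discriminative model such as a CRF parameterizes the conditional directly (the numbers below are made up for illustration):

```python
import numpy as np

# Toy joint distribution P(X, Y): 3 words x 2 tags (made-up numbers)
joint = np.array([
    [0.20, 0.05],
    [0.10, 0.30],
    [0.25, 0.10],
])
assert np.isclose(joint.sum(), 1.0)

# Generative route: model the joint, then condition via P(Y|X) = P(X,Y) / P(X)
p_x = joint.sum(axis=1, keepdims=True)   # P(X)
p_y_given_x = joint / p_x                # P(Y | X)

# Every row is now a proper conditional distribution
print(p_y_given_x.sum(axis=1))           # -> [1. 1. 1.]
```

A discriminative model skips estimating $P(X)$ altogether and spends its capacity on the conditional, which is why it can use arbitrary, overlapping features of $x$.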

1. Linear-Chain CRF Basics

1.1 Model Definition

Given an input sequence $x = (x_1, \dots, x_T)$, a linear-chain CRF defines the conditional probability of an output sequence $y = (y_1, \dots, y_T)$:

$$P(y \mid x) = \frac{1}{Z(x)} \exp\left( \sum_{t=1}^{T} \sum_{k} \lambda_k f_k(y_{t-1}, y_t, x, t) \right)$$

where:

  • $f_k(y_{t-1}, y_t, x, t)$: feature functions
  • $\lambda_k$: feature weights
  • $Z(x) = \sum_{y'} \exp\left( \sum_{t=1}^{T} \sum_{k} \lambda_k f_k(y'_{t-1}, y'_t, x, t) \right)$: partition function (normalization constant)

1.2 Feature Functions

A CRF uses two types of features:

State Features

$s_l(y_t, x, t)$ depends on the current label and the observation. For example, $s_l = \mathbb{1}[x_t = \text{word},\ y_t = \text{tag}]$ means "the label is tag when the current word is word".

Transition Features

$t_k(y_{t-1}, y_t, x, t)$ depends on adjacent labels. For example, $t_k = \mathbb{1}[y_{t-1} = A,\ y_t = B]$ means "the current label is B when the previous label is A".
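Both feature types are just indicator functions of $(y_{t-1}, y_t, x, t)$. A minimal sketch in plain Python (the factory functions and tag names are illustrative, not from any library):

```python
def state_feature(word, tag):
    """Indicator: fires when the current word is `word` and the label is `tag`."""
    return lambda y_prev, y_t, x, t: 1.0 if x[t] == word and y_t == tag else 0.0

def transition_feature(tag_a, tag_b):
    """Indicator: fires when the previous label is `tag_a` and the current label is `tag_b`."""
    return lambda y_prev, y_t, x, t: 1.0 if y_prev == tag_a and y_t == tag_b else 0.0

f1 = state_feature("Paris", "LOC")
f2 = transition_feature("B-LOC", "I-LOC")

x = ["visit", "Paris"]
print(f1(None, "LOC", x, 1))        # 1.0: current word is "Paris" with label LOC
print(f2("B-LOC", "I-LOC", x, 1))   # 1.0: transition B-LOC -> I-LOC
```

Each feature is multiplied by its learned weight $\lambda_k$ and summed over all positions to score a candidate label sequence.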

1.3 Simplified Representation

Using a scoring function, the model can be written compactly as:

$$P(y \mid x) = \frac{\exp(\text{score}(x, y))}{\sum_{y'} \exp(\text{score}(x, y'))}, \qquad \text{score}(x, y) = \sum_{t=1}^{T} \psi_t(y_{t-1}, y_t, x)$$

where the potential function is:

$$\psi_t(y_{t-1}, y_t, x) = \sum_k \lambda_k f_k(y_{t-1}, y_t, x, t)$$


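Under this representation, $P(y \mid x)$ is simply a softmax over the scores of all $K^T$ possible label sequences, which can be checked by brute force for tiny $T$ and $K$ (random scores, assuming emission-plus-transition potentials):

```python
import itertools
import numpy as np

K, T = 2, 3
rng = np.random.default_rng(0)
emit = rng.normal(size=(T, K))     # unary (emission) scores per position
trans = rng.normal(size=(K, K))    # pairwise transition scores

def score(path):
    """score(x, y) = sum of emission and transition potentials along the path."""
    s = emit[0, path[0]]
    for t in range(1, T):
        s += trans[path[t - 1], path[t]] + emit[t, path[t]]
    return s

paths = list(itertools.product(range(K), repeat=T))
scores = np.array([score(p) for p in paths])
Z = np.exp(scores).sum()                     # partition function by enumeration
probs = np.exp(scores) / Z

print(np.isclose(probs.sum(), 1.0))          # True: softmax over all paths normalizes
```

With $K = 2$, $T = 3$ only 8 paths exist; real inference replaces this enumeration with the forward algorithm.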
2. Inference Algorithms

2.1 Forward-Backward Algorithm

Forward algorithm: compute $\alpha_t(j)$, the total unnormalized score of all label prefixes ending with $y_t = j$:

$$\alpha_t(j) = \sum_{i} \alpha_{t-1}(i)\, \psi_t(i, j, x)$$

Backward algorithm: compute $\beta_t(i)$, the total unnormalized score of all label suffixes starting from $y_t = i$:

$$\beta_t(i) = \sum_{j} \psi_{t+1}(i, j, x)\, \beta_{t+1}(j)$$

Marginal Probability

$$P(y_t = j \mid x) = \frac{\alpha_t(j)\, \beta_t(j)}{Z(x)}$$

where $Z(x) = \sum_j \alpha_T(j)$.
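The two recursions can be run in log space for numerical stability; a self-contained numpy sketch (random scores, no masking) that verifies the per-position marginals normalize:

```python
import numpy as np

def logsumexp(a, axis):
    """Numerically stable log-sum-exp along an axis."""
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)

T, K = 4, 3
rng = np.random.default_rng(1)
emit = rng.normal(size=(T, K))     # emission scores
trans = rng.normal(size=(K, K))    # transition scores

# Forward: log alpha_t(j) = emit[t, j] + logsumexp_i(log alpha_{t-1}(i) + trans[i, j])
alpha = np.zeros((T, K))
alpha[0] = emit[0]
for t in range(1, T):
    alpha[t] = emit[t] + logsumexp(alpha[t - 1][:, None] + trans, axis=0)

# Backward: log beta_t(i) = logsumexp_j(trans[i, j] + emit[t+1, j] + log beta_{t+1}(j))
beta = np.zeros((T, K))
for t in range(T - 2, -1, -1):
    beta[t] = logsumexp(trans + emit[t + 1][None, :] + beta[t + 1][None, :], axis=1)

log_Z = logsumexp(alpha[-1], axis=0)

# Marginal: P(y_t = j | x) = exp(log alpha_t(j) + log beta_t(j) - log Z)
marginals = np.exp(alpha + beta - log_Z)
print(np.allclose(marginals.sum(axis=1), 1.0))   # True
```

The same log-space pattern appears in the PyTorch `_forward_algorithm` below, with `logsumexp` provided by the framework.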

2.2 Viterbi Decoding

Find the most likely output sequence:

$$y^* = \arg\max_y \text{score}(x, y)$$

import numpy as np

def viterbi(scores, transition_scores):
    """
    Viterbi decoding

    Args:
        scores: (T, K) per-position scores for each tag
        transition_scores: (K, K) transition matrix

    Returns:
        best_path: the highest-scoring path
        best_score: its score
    """
    T, K = scores.shape

    # Dynamic programming tables
    dp = np.zeros((T, K))
    backptr = np.zeros((T, K), dtype=int)

    # Initialization
    dp[0] = scores[0]

    # Recursion
    for t in range(1, T):
        for j in range(K):
            # Transition from every previous state
            trans_scores = dp[t-1] + transition_scores[:, j]
            best_i = np.argmax(trans_scores)

            dp[t, j] = trans_scores[best_i] + scores[t, j]
            backptr[t, j] = best_i

    # Backtrace
    best_path = np.zeros(T, dtype=int)
    best_path[T-1] = np.argmax(dp[T-1])

    for t in range(T-2, -1, -1):
        best_path[t] = backptr[t+1, best_path[t+1]]

    best_score = dp[T-1, best_path[T-1]]

    return best_path, best_score
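For small inputs the DP solution can be validated against exhaustive search; the brute-force check below is self-contained and uses the same scoring convention (emission plus transition) as the function above:

```python
import itertools
import numpy as np

rng = np.random.default_rng(2)
T, K = 4, 3
scores = rng.normal(size=(T, K))
transition_scores = rng.normal(size=(K, K))

def path_score(path):
    """Total score of a label path: emissions plus transitions, as in the DP above."""
    s = scores[0, path[0]]
    for t in range(1, T):
        s += transition_scores[path[t - 1], path[t]] + scores[t, path[t]]
    return s

# Exhaustive search over all K**T paths
best_path = max(itertools.product(range(K), repeat=T), key=path_score)
print(best_path, path_score(best_path))
```

For $T = 4$, $K = 3$ this enumerates $3^4 = 81$ paths, while the DP needs only $O(TK^2)$ operations.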

2.3 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearCRF(nn.Module):
    """Linear-chain CRF implementation"""
    def __init__(self, num_tags, start_tag_idx=None, end_tag_idx=None):
        super().__init__()
        self.num_tags = num_tags
        self.start_tag_idx = start_tag_idx
        self.end_tag_idx = end_tag_idx

        # Transition matrix: (num_tags+2, num_tags+2) when start and end tags are used
        if start_tag_idx is not None and end_tag_idx is not None:
            self.transitions = nn.Parameter(torch.randn(num_tags + 2, num_tags + 2))
        else:
            self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def forward(self, emissions, tags=None, mask=None):
        """
        emissions: (batch, seq_len, num_tags) emission scores
        tags: (batch, seq_len) tag sequences
        mask: (batch, seq_len) valid-position mask
        """
        if tags is not None:
            # Training: return the mean negative log-likelihood
            nll = self.neg_log_likelihood(emissions, tags, mask)
            return nll.mean()
        else:
            # Inference: decode the best tag sequence
            return self.viterbi_decode(emissions, mask)
    
    def neg_log_likelihood(self, emissions, tags, mask):
        """Negative log-likelihood: log Z(x) minus the gold-path score.
        Note: for brevity the mask is not applied here, i.e. all sequences
        in the batch are assumed to have the same length."""
        batch_size, seq_len, num_tags = emissions.shape

        # Emission score of the first gold tag
        score = emissions[:, 0].gather(1, tags[:, 0:1]).squeeze(1)

        # Transition from the start tag, if one is used
        if self.start_tag_idx is not None:
            score = score + self.transitions[self.start_tag_idx, tags[:, 0]]

        # Accumulate emission and transition scores along the gold path
        for t in range(1, seq_len):
            emit_score = emissions[:, t].gather(1, tags[:, t:t + 1]).squeeze(1)
            trans_score = self.transitions[tags[:, t - 1], tags[:, t]]
            score = score + emit_score + trans_score

        # Transition into the end tag, if one is used
        if self.end_tag_idx is not None:
            score = score + self.transitions[tags[:, -1], self.end_tag_idx]

        # Partition function via the forward algorithm
        partition = self._forward_algorithm(emissions)

        return partition - score
    
    def _forward_algorithm(self, emissions):
        """Forward algorithm: compute log Z(x) for each sequence in the batch"""
        batch_size, seq_len, num_tags = emissions.shape

        # Initialization
        alpha = emissions[:, 0]

        if self.start_tag_idx is not None:
            alpha = alpha + self.transitions[self.start_tag_idx, :num_tags]

        # Recursion
        for t in range(1, seq_len):
            # Emission scores
            emit_score = emissions[:, t].unsqueeze(1)  # (batch, 1, num_tags)

            # Transition scores
            trans_score = self.transitions[:num_tags, :num_tags].unsqueeze(0)  # (1, num_tags, num_tags)

            # Dynamic programming step: logsumexp over the previous tag
            alpha = (alpha.unsqueeze(2) + trans_score + emit_score).logsumexp(dim=1)

        # Transition into the end tag
        if self.end_tag_idx is not None:
            alpha = alpha + self.transitions[:num_tags, self.end_tag_idx]

        # logsumexp over the final states gives log Z(x): shape (batch,)
        return alpha.logsumexp(dim=1)
    
    def viterbi_decode(self, emissions, mask=None):
        """Viterbi decoding"""
        batch_size, seq_len, num_tags = emissions.shape

        # Initialize the DP table and per-step backpointers
        viterbi = emissions[:, 0]
        backptr = torch.zeros(batch_size, seq_len, num_tags, dtype=torch.long,
                              device=emissions.device)

        if self.start_tag_idx is not None:
            viterbi = viterbi + self.transitions[self.start_tag_idx, :num_tags]

        # Recursion
        for t in range(1, seq_len):
            # (batch, num_tags, 1) + (1, num_tags, num_tags) -> (batch, num_tags, num_tags)
            scores = viterbi.unsqueeze(2) + self.transitions[:num_tags, :num_tags].unsqueeze(0)

            # Best previous tag for each current tag
            best_scores, best_idx = scores.max(dim=1)

            viterbi = best_scores + emissions[:, t]
            backptr[:, t] = best_idx

        # Transition into the end tag
        if self.end_tag_idx is not None:
            viterbi = viterbi + self.transitions[:num_tags, self.end_tag_idx]

        # Backtrace
        best_paths = torch.zeros(batch_size, seq_len, dtype=torch.long, device=emissions.device)
        best_paths[:, -1] = viterbi.argmax(dim=1)

        for t in range(seq_len - 2, -1, -1):
            best_paths[:, t] = backptr[:, t + 1].gather(1, best_paths[:, t + 1].unsqueeze(1)).squeeze(1)

        return best_paths, viterbi.max(dim=1)[0]

3. Neural CRFs

3.1 BiLSTM-CRF

The classic architecture combining a bidirectional LSTM with a CRF:

class BiLSTMCRF(nn.Module):
    """BiLSTM-CRF model"""
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_tags, pad_idx=0):
        super().__init__()

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)

        # BiLSTM encoder
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim // 2,
            batch_first=True, bidirectional=True
        )

        # Emission layer
        self.hidden2tag = nn.Linear(hidden_dim, num_tags)

        # CRF layer
        self.crf = LinearCRF(num_tags)

    def forward(self, x, tags=None, mask=None):
        """
        x: (batch, seq_len) input sequences
        tags: (batch, seq_len) tag sequences
        mask: (batch, seq_len) valid-position mask
        """
        # Embed
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)

        # BiLSTM encoding
        lstm_out, _ = self.lstm(embedded)  # (batch, seq_len, hidden_dim)

        # Emission scores
        emissions = self.hidden2tag(lstm_out)  # (batch, seq_len, num_tags)

        # CRF forward
        if tags is not None:
            loss = self.crf(emissions, tags, mask)
            return loss
        else:
            # Decode
            return self.crf.viterbi_decode(emissions, mask)
 
 
def train_bilstm_crf():
    """Train a BiLSTM-CRF"""
    # Hyperparameters
    VOCAB_SIZE = 10000
    EMBED_DIM = 100
    HIDDEN_DIM = 256
    NUM_TAGS = 9  # number of BIOES tags

    # Model
    model = BiLSTMCRF(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_TAGS)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop (train_loader / dev_loader are assumed to be defined elsewhere)
    for epoch in range(10):
        model.train()
        total_loss = 0

        for batch in train_loader:
            x, tags, mask = batch

            optimizer.zero_grad()
            loss = model(x, tags, mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}: Loss = {total_loss/len(train_loader):.4f}")

        # Validation
        model.eval()
        with torch.no_grad():
            for batch in dev_loader:
                x, gold_tags, mask = batch
                predictions, scores = model(x, mask=mask)
                # Entity-level F1 (see compute_ner_f1 in section 4.3)
                f1, precision, recall = compute_ner_f1(predictions, gold_tags, mask)
                print(f"F1: {f1:.4f}")

3.2 BERT-CRF

Replacing the BiLSTM with a pretrained language model:

from transformers import BertModel

class BERTCRF(nn.Module):
    """BERT-CRF model"""
    def __init__(self, model_name, num_tags):
        super().__init__()

        # BERT encoder
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size

        # Emission layer
        self.hidden2tag = nn.Linear(hidden_size, num_tags)

        # CRF layer
        self.crf = LinearCRF(num_tags)

    def forward(self, x, attention_mask=None, tags=None):
        # BERT encoding
        outputs = self.bert(input_ids=x, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # (batch, seq_len, hidden)

        # Emission scores
        emissions = self.hidden2tag(sequence_output)

        # CRF
        if tags is not None:
            mask = attention_mask.bool()
            return self.crf(emissions, tags, mask)
        else:
            return self.crf.viterbi_decode(emissions, attention_mask.bool())

4. Application: Named Entity Recognition

4.1 Tagging Schemes

Common tagging schemes:

| Scheme | Example tags | Description |
| --- | --- | --- |
| BIO | B-PER, I-PER, O | entity beginning / inside, O for outside |
| BIOES | B-PER, I-PER, E-PER, S-PER, O | adds entity end (E) and single-token entity (S) |
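A concrete BIO example in plain Python (the helper function and the sentence are illustrative):

```python
def spans_to_bio(tokens, spans):
    """Convert entity spans (start, end, type), end exclusive, to BIO tags."""
    tags = ["O"] * len(tokens)
    for start, end, etype in spans:
        tags[start] = f"B-{etype}"
        for i in range(start + 1, end):
            tags[i] = f"I-{etype}"
    return tags

tokens = ["John", "Smith", "visited", "New", "York", "."]
spans = [(0, 2, "PER"), (3, 5, "LOC")]
print(spans_to_bio(tokens, spans))
# -> ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC', 'O']
```

BIOES additionally marks the last token of each entity (E-) and single-token entities (S-), which gives the CRF's transition matrix more structure to exploit.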

4.2 Data Processing

from torch.utils.data import Dataset

class NERDataset(Dataset):
    """NER dataset"""
    def __init__(self, data, word2idx, tag2idx):
        self.data = data
        self.word2idx = word2idx
        self.tag2idx = tag2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sentence, tags = self.data[idx]

        # Convert tokens and tags to indices (index 0 = unknown word)
        x = torch.tensor([self.word2idx.get(w, 0) for w in sentence], dtype=torch.long)
        y = torch.tensor([self.tag2idx[t] for t in tags], dtype=torch.long)
        mask = torch.ones_like(x, dtype=torch.bool)

        return {
            'x': x,
            'y': y,
            'mask': mask
        }

def collate_fn(batch):
    """Collate function: pad each field to the batch maximum length"""
    x = torch.nn.utils.rnn.pad_sequence([b['x'] for b in batch], batch_first=True)
    y = torch.nn.utils.rnn.pad_sequence([b['y'] for b in batch], batch_first=True)
    mask = torch.nn.utils.rnn.pad_sequence([b['mask'] for b in batch], batch_first=True)

    return x, y, mask

4.3 Evaluation Metrics

def compute_ner_f1(predictions, targets, mask, id2tag=None):
    """Entity-level F1 for NER (exact entity match)"""
    predictions = predictions.cpu().numpy()
    targets = targets.cpu().numpy()
    mask = mask.cpu().numpy()

    # Convert predictions and targets to entity lists
    pred_entities = batch_to_entities(predictions, mask, id2tag)
    true_entities = batch_to_entities(targets, mask, id2tag)

    # Count TP, FP, FN
    tp = sum(1 for e in pred_entities if e in true_entities)
    fp = len(pred_entities) - tp
    fn = len(true_entities) - tp

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return f1, precision, recall

def batch_to_entities(tags, mask, id2tag=None):
    """Convert batched tag sequences into a list of (batch_idx, token_span, type)
    entities; the batch index keeps entities from different sentences distinct"""
    entities = []
    batch_size, seq_len = tags.shape

    for b in range(batch_size):
        seq_len_b = int(mask[b].sum())
        entity = None
        entity_type = None

        for t in range(seq_len_b):
            # Map a label id back to its string form (e.g. "B-PER") if needed
            tag = id2tag[tags[b, t]] if id2tag is not None else tags[b, t]

            if tag.startswith('B-'):
                if entity is not None:
                    entities.append((b, tuple(entity), entity_type))
                entity = [t]
                entity_type = tag[2:]
            elif tag.startswith('I-') and entity is not None:
                entity.append(t)
            else:
                if entity is not None:
                    entities.append((b, tuple(entity), entity_type))
                entity = None
                entity_type = None

        if entity is not None:
            entities.append((b, tuple(entity), entity_type))

    return entities

5. Graph-Structured CRFs

5.1 Fully Connected CRF

The fully connected CRF commonly used in image segmentation:

import numpy as np
from scipy.ndimage import gaussian_filter

class DenseCRF(nn.Module):
    """Fully connected CRF (simplified mean-field inference)"""
    def __init__(self, num_classes, spatial_std=1.0, colour_std=1.0, bilateral_std=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.spatial_std = spatial_std
        self.colour_std = colour_std
        self.bilateral_std = bilateral_std

    def forward(self, unary, image, max_iter=10):
        """
        unary: (H, W, num_classes) unary potentials (log-scores)
        image: (H, W, 3) RGB image
        """
        Q = self._softmax(unary)

        for _ in range(max_iter):
            # Message passing (simplified)
            message = self._message_pass(Q, image)

            # Mean-field update with renormalization
            Q = self._softmax(unary - message)

        return Q

    def _softmax(self, x):
        """Normalize scores into per-pixel class distributions"""
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    def _message_pass(self, Q, image):
        """Simplified message passing: a spatial Gaussian filter only;
        a full dense CRF would also use a bilateral (colour-dependent) kernel"""
        message = np.zeros_like(Q)
        for c in range(self.num_classes):
            message[:, :, c] = gaussian_filter(Q[:, :, c], sigma=3)

        return message

5.2 Neural CRF Layer

A learnable CRF layer:

class LearnableCRFLayer(nn.Module):
    """Learnable CRF layer"""
    def __init__(self, num_tags, feature_dim):
        super().__init__()
        self.num_tags = num_tags

        # Learnable transition matrix
        self.transition = nn.Parameter(torch.randn(num_tags, num_tags))

        # Map features to pairwise potentials
        self.feature_proj = nn.Linear(feature_dim, num_tags * num_tags)

    def forward(self, unary, features):
        """
        unary: (batch, seq_len, num_tags)
        features: (batch, seq_len, feature_dim)
        """
        batch_size, seq_len, num_tags = unary.shape

        # Learned, position-dependent pairwise potentials
        potentials = self.feature_proj(features)  # (batch, seq_len, num_tags*num_tags)
        potentials = potentials.view(batch_size, seq_len, num_tags, num_tags)

        # Combine unary scores, learned potentials, and the shared transition matrix
        combined = unary.unsqueeze(-1) + potentials + self.transition

        return combined

6. Practical Guide

6.1 When to Use a CRF

| Scenario | Recommendation |
| --- | --- |
| Sequence labeling (NER, POS) | BiLSTM-CRF / BERT-CRF |
| Image segmentation | CNN + fully connected CRF |
| POS tagging | CRF with feature engineering |
| Simple sequence tasks | a plain neural network |

6.2 Training Tips

# 1. Label smoothing
class LabelSmoothingCRF:
    def __init__(self, num_tags, smoothing=0.1):
        self.num_tags = num_tags
        self.smoothing = smoothing

    def __call__(self, emissions, targets, mask):
        # Smooth the target distribution
        smooth_targets = targets * (1 - self.smoothing) + self.smoothing / self.num_tags
        # Compute the loss
        # ...


# 2. Learning-rate scheduling
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=5, factor=0.5
)


# 3. Early stopping
class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience
        self.counter = 0
        self.best_score = 0

    def step(self, score):
        if score > self.best_score:
            self.best_score = score
            self.counter = 0
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience

6.3 Common Problems

| Problem | Solution |
| --- | --- |
| Unstable training | gradient clipping, more dropout |
| Slow decoding | beam search |
| OOV words | character-level CNN or BERT |
| Vanishing gradients on long sequences | Transformer or attention |
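The beam-search remedy from the table keeps only the top-$B$ partial paths at each step instead of all $K$, trading exactness for speed; a self-contained sketch (random scores, `beam_search` is an illustrative helper, not a library function):

```python
import numpy as np

def beam_search(emit, trans, beam_size=2):
    """Approximate decoding: keep only the top-`beam_size` partial paths per step."""
    T, K = emit.shape
    # Each beam entry is (score, path)
    beams = sorted(((emit[0, k], [k]) for k in range(K)),
                   key=lambda c: c[0], reverse=True)[:beam_size]
    for t in range(1, T):
        candidates = []
        for score, path in beams:
            for k in range(K):
                # Extend every surviving path by every possible next tag
                candidates.append((score + trans[path[-1], k] + emit[t, k], path + [k]))
        # Prune back down to the beam width
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_size]
    return beams[0]

rng = np.random.default_rng(3)
emit, trans = rng.normal(size=(5, 4)), rng.normal(size=(4, 4))
score, path = beam_search(emit, trans, beam_size=3)
print(path, score)
```

Each step considers at most $B \cdot K$ candidates instead of $K^2$ full DP cells; with a large enough beam the search becomes exhaustive, and small beams often suffice in practice.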

7. Related Topics

| Topic | Description |
| --- | --- |
| Hidden Markov Model | generative sequence model |
| Markov networks and CRFs | foundations of undirected graphical models |
| Recurrent neural networks | sequence-modeling basics |
| LSTM | long short-term memory networks |

References

[^1]: Lafferty, J., McCallum, A., & Pereira, F. C. N. "Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data." ICML 2001.