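Deep Learning for Speech Recognition

1. Overview

Automatic Speech Recognition (ASR) converts a speech signal into text. Deep learning has reshaped the field, moving it from the traditional GMM-HMM pipeline to end-to-end deep neural network models.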

1.1 Development Timeline

Traditional: GMM-HMM → DNN-HMM → end-to-end deep learning
                                      ↓
         CTC → RNN-T → Attention-based → Hybrid CTC/Attention → Whisper

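1.2 Comparison of Main Paradigms

Paradigm   Representative model  Characteristic                       Strengths
CTC        Deep Speech           Conditional independence assumption  Simple, parallel
RNN-T      RNN-T                 Sequence-to-sequence                 Strong modeling capacity
Attention  Listen-Attend-Spell   End-to-end                           High accuracy
Hybrid     LAS + CTC             Joint training                       Stable and efficient
Whisper    Whisper               Large-scale weak supervision         Robust, zero-shot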

2. CTC (Connectionist Temporal Classification)

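2.1 Problem Setup

Traditional ASR requires frame-level alignment labels. CTC removes that requirement by introducing a blank symbol, which lets the model map a variable-length input to a shorter, variable-length output:

Input:   [f1, f2, ..., f20]       (audio frames)
Output:  h  e  l  l  o            (text)
CTC:     hhh_eeee_ll_l_oo____     (one valid path, one symbol per frame)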
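
The mapping from a CTC path to its text can be sketched in a few lines: merge consecutive repeats first, then drop the blanks. This is a minimal illustration that assumes the blank is written as "_" as in the example above; the function name is hypothetical.

```python
from itertools import groupby

BLANK = "_"  # blank symbol, written as in the path example above


def collapse_ctc_path(path: str, blank: str = BLANK) -> str:
    """Merge consecutive repeated symbols, then remove blanks."""
    merged = [symbol for symbol, _ in groupby(path)]
    return "".join(symbol for symbol in merged if symbol != blank)


print(collapse_ctc_path("hhh_eeee_ll_l_oo____"))  # hello
print(collapse_ctc_path("hheelllloo"))            # helo
```

The second call shows why the blank matters: without a blank between the two runs of "l", the repeated letter is merged into one.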

2.2 CTC Loss

import torch
import torch.nn as nn


class CTCLoss(nn.Module):
    """Readable reference CTC loss (forward algorithm only).

    torch.nn.CTCLoss is faster and numerically safer for real training.
    """

    def __init__(self, blank=0, reduction='mean', zero_infinity=True):
        super().__init__()
        self.blank = blank
        self.reduction = reduction
        self.zero_infinity = zero_infinity

    def forward(self, log_probs, targets, input_lengths, target_lengths):
        """
        log_probs: (T, B, V) - T frames, B batch items, V vocabulary entries including blank
        targets: (B, S) - padded target sequences
        input_lengths: (B,) - number of valid frames per input
        target_lengths: (B,) - number of valid labels per target
        """
        T, B, V = log_probs.shape

        losses = []
        for b in range(B):
            T_b = int(input_lengths[b])
            S_b = int(target_lengths[b])

            # 1. Forward algorithm over the blank-extended target
            log_alpha = self.forward_pass(log_probs[:T_b, b], targets[b, :S_b])

            # 2. A valid path ends on the last label or on the trailing blank
            losses.append(-torch.logsumexp(log_alpha[-1, -2:], dim=0))

        loss = torch.stack(losses)

        # 3. Optionally zero out infinite losses (target cannot be aligned to the input)
        if self.zero_infinity:
            loss = torch.where(torch.isinf(loss), torch.zeros_like(loss), loss)

        # 4. Reduce over the batch (torch.nn.CTCLoss also divides by target length for 'mean')
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss

    def forward_pass(self, log_probs, target):
        """Forward algorithm for one utterance. log_probs: (T, V), target: (S,). Returns (T, 2S+1)."""
        T = log_probs.shape[0]
        S = target.shape[0]
        neg_inf = float('-inf')

        # Extended target: blank, y1, blank, y2, ..., yS, blank
        ext = torch.full((2 * S + 1,), self.blank, dtype=torch.long, device=log_probs.device)
        ext[1::2] = target
        L = ext.shape[0]

        # Skipping from s-2 to s is allowed only onto a label that differs from the previous label
        can_skip = torch.zeros(L, dtype=torch.bool, device=log_probs.device)
        can_skip[2:] = (ext[2:] != self.blank) & (ext[2:] != ext[:-2])

        # First frame: start on the leading blank or on the first label
        alpha = log_probs.new_full((L,), neg_inf)
        alpha[0] = log_probs[0, self.blank]
        if S > 0:
            alpha[1] = log_probs[0, ext[1]]
        alphas = [alpha]

        # Recursion: stay at s, advance from s-1, or skip from s-2
        for t in range(1, T):
            prev = alphas[-1]
            step = torch.cat([prev.new_full((1,), neg_inf), prev[:-1]])
            skip = torch.cat([prev.new_full((2,), neg_inf), prev[:-2]])[:L]
            skip = torch.where(can_skip, skip, torch.full_like(skip, neg_inf))
            alphas.append(torch.logsumexp(torch.stack([prev, step, skip]), dim=0) + log_probs[t, ext])

        return torch.stack(alphas)
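
To see what the forward algorithm computes, it helps to compare it against brute force on a tiny input. The sketch below uses plain probabilities instead of log-probabilities and only the standard library; the function names and the toy distribution are made up for illustration. Both functions should return the same total probability of all paths that collapse to the target.

```python
import math
from itertools import groupby, product


def collapse(path, blank=0):
    """Merge repeats, then drop blanks."""
    return [s for s, _ in groupby(path) if s != blank]


def brute_force_likelihood(probs, target, blank=0):
    """Sum the probability of every length-T path that collapses to target."""
    T, V = len(probs), len(probs[0])
    total = 0.0
    for path in product(range(V), repeat=T):
        if collapse(path, blank) == list(target):
            total += math.prod(probs[t][s] for t, s in enumerate(path))
    return total


def forward_likelihood(probs, target, blank=0):
    """CTC forward algorithm over the blank-extended target."""
    ext = [blank]
    for label in target:
        ext += [label, blank]
    alpha = [0.0] * len(ext)
    alpha[0] = probs[0][blank]
    if len(ext) > 1:
        alpha[1] = probs[0][ext[1]]
    for t in range(1, len(probs)):
        new = [0.0] * len(ext)
        for s, symbol in enumerate(ext):
            total = alpha[s]                      # stay on the same symbol
            if s >= 1:
                total += alpha[s - 1]             # advance by one position
            if s >= 2 and symbol != blank and symbol != ext[s - 2]:
                total += alpha[s - 2]             # skip the blank between two different labels
            new[s] = total * probs[t][symbol]
        alpha = new
    return alpha[-1] + (alpha[-2] if len(ext) > 1 else 0.0)


# Toy distribution: 4 frames, vocabulary {0: blank, 1, 2}; each row sums to 1
probs = [[0.5, 0.3, 0.2], [0.2, 0.5, 0.3], [0.6, 0.1, 0.3], [0.3, 0.3, 0.4]]
target = [1, 2]
print(brute_force_likelihood(probs, target), forward_likelihood(probs, target))
```

The brute-force version enumerates V^T paths, which is only feasible for toy sizes; the forward recursion needs O(T·S) work, which is why CTC training is practical.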

2.3 CTC Decoding

import math
from collections import defaultdict

import torch


def ctc_decode(log_probs, method='greedy', beam_width=10, blank=0):
    """
    Dispatch to a CTC decoding method. log_probs: (T, B, V).
    Greedy returns one label list per batch item; beam search decodes the first item only.
    """
    if method == 'greedy':
        return ctc_greedy_decode(log_probs, blank)
    elif method in ('beam', 'prefix'):
        # Both names map to prefix beam search, the only beam variant defined here
        return ctc_prefix_beam_decode(log_probs, beam_width, blank)
    raise ValueError(f"unknown decoding method: {method}")


def ctc_greedy_decode(log_probs, blank=0):
    """Greedy decoding: take the most likely symbol per frame, merge repeats, drop blanks"""
    predictions = torch.argmax(log_probs, dim=-1)  # (T, B)

    decoded = []
    for seq in predictions.t().tolist():
        collapsed = []
        prev = None
        for p in seq:
            # Keep a symbol only when it is not blank and differs from the previous frame
            if p != blank and p != prev:
                collapsed.append(p)
            prev = p
        decoded.append(collapsed)

    return decoded


def _logaddexp(a, b):
    """log(exp(a) + exp(b)) for Python floats, safe for -inf"""
    if a == float('-inf'):
        return b
    if b == float('-inf'):
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def ctc_prefix_beam_decode(log_probs, beam_width=10, blank=0):
    """Prefix beam search: sums the probability of all paths that share a prefix"""
    T, B, V = log_probs.shape
    neg_inf = float('-inf')
    # prefix -> (log-prob of paths ending in blank, log-prob of paths ending in non-blank)
    beams = {(): (0.0, neg_inf)}

    for t in range(T):
        new_beams = defaultdict(lambda: (neg_inf, neg_inf))
        frame = log_probs[t, 0].tolist()

        for prefix, (p_b, p_nb) in beams.items():
            p_total = _logaddexp(p_b, p_nb)
            for c in range(V):
                lp = frame[c]
                if c == blank:
                    # Blank keeps the prefix unchanged
                    nb, nnb = new_beams[prefix]
                    new_beams[prefix] = (_logaddexp(nb, p_total + lp), nnb)
                    continue
                ext = prefix + (c,)
                if prefix and prefix[-1] == c:
                    # Repeat without a blank in between: same prefix
                    nb, nnb = new_beams[prefix]
                    new_beams[prefix] = (nb, _logaddexp(nnb, p_nb + lp))
                    # Repeat after a blank: a genuinely new symbol
                    nb, nnb = new_beams[ext]
                    new_beams[ext] = (nb, _logaddexp(nnb, p_b + lp))
                else:
                    nb, nnb = new_beams[ext]
                    new_beams[ext] = (nb, _logaddexp(nnb, p_total + lp))

        # Keep the top-k prefixes by total probability
        ranked = sorted(new_beams.items(), key=lambda kv: _logaddexp(*kv[1]), reverse=True)
        beams = dict(ranked[:beam_width])

    # Return the highest-scoring prefix
    return max(beams.items(), key=lambda kv: _logaddexp(*kv[1]))[0]
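
A tiny worked example shows why greedy decoding and a search over labelings can disagree. The sketch below is a stdlib-only illustration with a made-up two-frame distribution: greedy decoding follows the single best path, while the most probable labeling sums over every path that collapses to it.

```python
import math
from collections import defaultdict
from itertools import groupby, product


def collapse(path, blank=0):
    """Merge repeats, then drop blanks."""
    return [s for s, _ in groupby(path) if s != blank]


def labeling_probabilities(probs, blank=0):
    """Total probability of each labeling, summed over all of its paths."""
    T, V = len(probs), len(probs[0])
    totals = defaultdict(float)
    for path in product(range(V), repeat=T):
        p = math.prod(probs[t][s] for t, s in enumerate(path))
        totals[tuple(collapse(path, blank))] += p
    return dict(totals)


def greedy(probs, blank=0):
    """Pick the best symbol per frame, then collapse."""
    best_path = [max(range(len(frame)), key=frame.__getitem__) for frame in probs]
    return tuple(collapse(best_path, blank))


# Two frames; index 0 is blank, index 1 is the label "a"
probs = [[0.6, 0.4], [0.6, 0.4]]
totals = labeling_probabilities(probs)
print(greedy(probs))                # ()
print(max(totals, key=totals.get))  # (1,)
```

The all-blank path has probability 0.36, so greedy decoding outputs the empty string. The labeling "a" is reached by three paths (a_, _a, aa) with total probability 0.64, which is what a prefix beam search is designed to recover.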

3. RNN-T (RNN Transducer)

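3.1 Architecture

RNN-T consists of three components:

  • Encoder: the acoustic model, which processes the input audio
  • Prediction Network: acts like a language model over the labels emitted so far
  • Joint Network: fuses the two representations and predicts the next token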

import torch
import torch.nn as nn


class RNNT(nn.Module):
    def __init__(self, vocab_size, enc_dim=512, pred_dim=512, joint_dim=512, blank=0):
        super().__init__()

        # Encoder
        self.encoder = SpeechEncoder(
            input_dim=80,  # number of Mel filterbank channels
            enc_dim=enc_dim,
            num_layers=6
        )

        # Prediction network (language-model-like)
        self.prediction = PredictionNetwork(
            vocab_size=vocab_size,
            embed_dim=512,
            pred_dim=pred_dim,
            blank=blank
        )

        # Joint network; the bidirectional encoder outputs 2 * enc_dim features
        self.joint = JointNetwork(
            enc_dim=enc_dim * 2,
            pred_dim=pred_dim,
            joint_dim=joint_dim,
            vocab_size=vocab_size
        )

    def forward(self, inputs, input_lengths, targets, target_lengths):
        # Encoder
        enc_out = self.encoder(inputs)  # (B, T, 2*enc_dim)

        # Prediction network
        pred_out = self.prediction(targets)  # (B, U+1, pred_dim)

        # Joint network
        logits = self.joint(enc_out, pred_out)  # (T, U+1, B, vocab_size)

        return logits


class SpeechEncoder(nn.Module):
    """Bidirectional LSTM encoder"""
    def __init__(self, input_dim, enc_dim, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(
            input_dim, enc_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, x):
        # x: (B, T, input_dim) -> (B, T, 2*enc_dim)
        out, _ = self.lstm(x)
        return out


class PredictionNetwork(nn.Module):
    """Prediction network: similar to a language model"""
    def __init__(self, vocab_size, embed_dim, pred_dim, blank=0):
        super().__init__()
        self.blank = blank
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, pred_dim,
            num_layers=2,
            batch_first=True
        )

    def forward(self, targets):
        # Prepend blank as the start symbol so that position u has seen the first u labels
        start = torch.full_like(targets[:, :1], self.blank)
        emb = self.embedding(torch.cat([start, targets], dim=1))
        out, _ = self.lstm(emb)
        return out  # (B, U+1, pred_dim)


class JointNetwork(nn.Module):
    """Joint network: fuses encoder output and prediction network output"""
    def __init__(self, enc_dim, pred_dim, joint_dim, vocab_size):
        super().__init__()
        self.enc_proj = nn.Linear(enc_dim, joint_dim)
        self.pred_proj = nn.Linear(pred_dim, joint_dim)
        self.fc = nn.Sequential(
            nn.Tanh(),
            nn.Linear(joint_dim, vocab_size)
        )

    def forward(self, enc_out, pred_out):
        # enc_out: (B, T, enc_dim)
        # pred_out: (B, U+1, pred_dim)

        # Broadcast over the (T, U+1) lattice
        enc = self.enc_proj(enc_out).unsqueeze(2)     # (B, T, 1, joint_dim)
        pred = self.pred_proj(pred_out).unsqueeze(1)  # (B, 1, U+1, joint_dim)

        # Fuse; self.fc applies the tanh non-linearity before the output projection
        logits = self.fc(enc + pred)  # (B, T, U+1, vocab_size)

        # Reorder to the (T, U+1, B, vocab_size) layout used by the loss in 3.2
        return logits.permute(1, 2, 0, 3)

3.2 RNN-T Loss

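import torch
import torch.nn.functional as F


def rnnt_loss(logits, targets, input_lengths, target_lengths, blank=0):
    """
    RNN-T loss: similar to CTC, but each output is conditioned on the previous labels
    logits: (T, U+1, B, V), targets: (B, U)
    """
    T, U1, B, V = logits.shape

    # Convert to log-probabilities
    log_probs = F.log_softmax(logits, dim=-1)

    # Forward algorithm, one utterance at a time
    loss = 0.0
    for b in range(B):
        T_b = input_lengths[b].item()
        U_b = target_lengths[b].item()
        target_b = targets[b, :U_b]

        # alpha[t][u]: log-probability of having emitted the first u labels on reaching frame t
        alpha = [[None] * (U_b + 1) for _ in range(T_b)]

        for t in range(T_b):
            for u in range(U_b + 1):
                terms = []
                if t == 0 and u == 0:
                    # Initialization: the lattice starts at (0, 0) with probability 1
                    terms.append(log_probs.new_zeros(()))
                if t > 0:
                    # Blank transition: blank emitted at (t-1, u) advances one frame, u unchanged
                    terms.append(alpha[t - 1][u] + log_probs[t - 1, u, b, blank])
                if u > 0:
                    # Label transition: label u emitted at (t, u-1) consumes one label
                    terms.append(alpha[t][u - 1] + log_probs[t, u - 1, b, target_b[u - 1]])
                alpha[t][u] = torch.logsumexp(torch.stack(terms), dim=0)

        # Termination: the final emission must be a blank at (T-1, U)
        loss = loss - (alpha[T_b - 1][U_b] + log_probs[T_b - 1, U_b, b, blank])

    return loss / B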
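
The lattice recursion behind the RNN-T loss can be checked on a toy example. The sketch below is a stdlib-only illustration in plain probabilities: `blank_p[t][u]` and `label_p[t][u]` are made-up numbers standing for the probability of emitting blank, or the next target label, at lattice node (t, u). The forward recursion should agree with explicit enumeration of all alignments.

```python
from itertools import combinations


def lattice_forward(blank_p, label_p, T, U):
    """alpha[t][u]: probability of reaching frame t after emitting the first u labels."""
    alpha = [[0.0] * (U + 1) for _ in range(T)]
    alpha[0][0] = 1.0
    for t in range(T):
        for u in range(U + 1):
            if t > 0:  # blank emitted at (t-1, u) advances time
                alpha[t][u] += alpha[t - 1][u] * blank_p[t - 1][u]
            if u > 0:  # label emitted at (t, u-1) advances the output position
                alpha[t][u] += alpha[t][u - 1] * label_p[t][u - 1]
    return alpha[T - 1][U] * blank_p[T - 1][U]  # final blank


def brute_force(blank_p, label_p, T, U):
    """Enumerate every alignment: U label steps and T blank steps, ending with a blank."""
    total = 0.0
    steps = T + U - 1  # moves before the mandatory final blank
    for label_steps in combinations(range(steps), U):
        t = u = 0
        p = 1.0
        for step in range(steps):
            if step in label_steps:
                p *= label_p[t][u]
                u += 1
            else:
                p *= blank_p[t][u]
                t += 1
        total += p * blank_p[t][u]  # here t == T-1 and u == U
    return total


T, U = 3, 2
blank_p = [[0.6, 0.5, 0.9], [0.7, 0.4, 0.8], [0.5, 0.6, 0.9]]  # shape T x (U+1)
label_p = [[0.3, 0.4], [0.2, 0.5], [0.4, 0.3]]                 # shape T x U
print(lattice_forward(blank_p, label_p, T, U), brute_force(blank_p, label_p, T, U))
```

With T = 3 and U = 2 there are only 6 alignments, but the count grows combinatorially, which is why the loss relies on dynamic programming over the lattice.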

4. Listen-Attend-Spell (LAS)

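4.1 Sequence-to-Sequence Architecture

LAS is a fully end-to-end attention model with three components:

  1. Listener: a CNN + RNN encoder that turns audio into a high-level representation
  2. Attender: an attention mechanism that aligns input and output
  3. Speller: an RNN decoder that generates the output text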
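
import torch
import torch.nn as nn


class ListenAttendSpell(nn.Module):
    def __init__(self, vocab_size, n_mels=80, enc_dim=256, att_dim=256, dec_dim=512, embed_dim=256):
        super().__init__()

        # Listener encoder: two stride-2 convolutions (4x downsampling) followed by a BiLSTM
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),  # downsample
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, stride=2, padding=1),
            nn.ReLU(),
        )
        freq = (n_mels + 3) // 4  # frequency bins left after the two convolutions
        self.rnn = nn.LSTM(
            32 * freq, enc_dim // 2,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )

        # Attender: query = decoder state, keys = encoder outputs (class defined in 4.2)
        self.attender = DotProductAttention(dec_dim, enc_dim, att_dim)

        # Speller decoder: input is the previous token embedding plus the previous context
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.speller = nn.LSTM(
            embed_dim + dec_dim, dec_dim,
            num_layers=2,
            batch_first=True
        )
        self.fc = nn.Linear(dec_dim, vocab_size)

    def forward(self, inputs, targets, target_lengths=None):
        # inputs: (B, T, n_mels); targets: (B, U), each row starting with <sos>
        # 1. Encode
        x = self.conv(inputs.unsqueeze(1))  # (B, 32, T', F')
        B, C, Tp, Fp = x.shape
        h, _ = self.rnn(x.permute(0, 2, 1, 3).reshape(B, Tp, C * Fp))  # (B, T', enc_dim)

        # 2. Decode step by step with teacher forcing
        emb = self._prepare_decoder_input(targets)  # (B, U-1, embed_dim)
        context = h.new_zeros(B, 1, self.fc.in_features)
        state, outputs = None, []
        for u in range(emb.size(1)):
            step_in = torch.cat([emb[:, u:u + 1], context], dim=-1)
            out, state = self.speller(step_in, state)  # (B, 1, dec_dim)
            context, _ = self.attender(out, h)         # (B, 1, dec_dim)
            # Simplification: LAS concatenates state and context; summing keeps fc at dec_dim
            outputs.append(out + context)

        # 3. Predict
        logits = self.fc(torch.cat(outputs, dim=1))  # (B, U-1, vocab_size)

        return logits

    def _prepare_decoder_input(self, targets):
        """Prepare decoder input: embed the targets shifted right (drop the last token)"""
        return self.embedding(targets[:, :-1])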

4.2 Attention Mechanism

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class DotProductAttention(nn.Module):
    """Scaled dot-product attention"""
    def __init__(self, query_dim, key_dim, att_dim=256):
        super().__init__()
        self.query_proj = nn.Linear(query_dim, att_dim)
        self.key_proj = nn.Linear(key_dim, att_dim)
        self.value_proj = nn.Linear(key_dim, query_dim)
        self.scale = math.sqrt(att_dim)

    def forward(self, query, keys, values=None, mask=None):
        """
        query: (B, 1, query_dim) - current decoder state
        keys: (B, T, key_dim) - encoder outputs
        mask: (B, 1, T) - True at padded positions that must be ignored
        """
        if values is None:
            values = keys

        # Projections
        Q = self.query_proj(query)  # (B, 1, att_dim)
        K = self.key_proj(keys)     # (B, T, att_dim)
        V = self.value_proj(values) # (B, T, query_dim)

        # Attention scores
        scores = torch.bmm(Q, K.transpose(1, 2)) / self.scale  # (B, 1, T)

        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))

        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)  # (B, 1, T)

        # Weighted sum of the values
        context = torch.bmm(attn_weights, V)  # (B, 1, query_dim)

        return context, attn_weights
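
The arithmetic inside the attention step (scaled scores, softmax, weighted sum) is easy to follow on a hand-sized example. The sketch below uses only the standard library and skips the learned projections; the vectors are made up for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def attend(query, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    context = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return context, weights


query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]  # aligned, orthogonal, opposite
values = [[10.0], [20.0], [30.0]]
context, weights = attend(query, keys, values)
print(weights, context)
```

The key that points in the same direction as the query receives the largest weight, so the context is pulled toward its value (10.0) and ends up below the unweighted mean of 20.0.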

5. Hybrid CTC/Attention

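5.1 Advantages of Joint Training

The hybrid approach combines the strengths of CTC and attention:

  • CTC imposes a monotonic alignment constraint, which speeds up training
  • Attention offers more flexible modeling capacity
  • The two losses are optimized jointly and regularize each other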
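
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionDecoder(nn.Module):
    """Minimal stand-in attention decoder (the original text does not define one)"""
    def __init__(self, vocab_size, enc_out_dim, dec_dim=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dec_dim)
        self.lstm = nn.LSTM(dec_dim, dec_dim, batch_first=True)
        self.key_proj = nn.Linear(enc_out_dim, dec_dim)
        self.fc = nn.Linear(dec_dim * 2, vocab_size)

    def forward(self, h, decoder_inputs):
        q, _ = self.lstm(self.embedding(decoder_inputs))  # (B, U, dec_dim)
        k = self.key_proj(h)                              # (B, T, dec_dim)
        attn = F.softmax(q @ k.transpose(1, 2) / k.size(-1) ** 0.5, dim=-1)  # (B, U, T)
        return self.fc(torch.cat([q, attn @ k], dim=-1))  # (B, U, vocab_size)


class HybridASR(nn.Module):
    def __init__(self, vocab_size, enc_dim=512, num_layers=6, pad_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.pad_id = pad_id  # assumed padding id; index 0 doubles as the CTC blank

        # Shared encoder
        self.encoder = nn.LSTM(
            80, enc_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )

        # CTC head
        self.ctc_fc = nn.Linear(enc_dim * 2, vocab_size)

        # Attention decoder
        self.decoder = AttentionDecoder(vocab_size, enc_dim * 2)

    def forward(self, inputs, input_lengths, targets, target_lengths):
        # targets: (B, U) rows of <sos> labels <eos> followed by padding;
        # target_lengths counts both special tokens
        # Encode
        h, _ = self.encoder(inputs)  # (B, T, enc_dim*2)

        # CTC loss on the plain labels (without <sos>/<eos>)
        ctc_logits = self.ctc_fc(h)  # (B, T, vocab_size)
        ctc_log_probs = F.log_softmax(ctc_logits, dim=-1)
        ctc_loss = F.ctc_loss(
            ctc_log_probs.transpose(0, 1),  # (T, B, V)
            targets[:, 1:], input_lengths, target_lengths - 2,
            blank=0, zero_infinity=True
        )

        # Attention loss with teacher forcing: input drops the last token, labels drop <sos>
        attn_logits = self.decoder(h, targets[:, :-1])  # (B, U-1, vocab_size)
        attn_loss = F.cross_entropy(
            attn_logits.reshape(-1, self.vocab_size),
            targets[:, 1:].reshape(-1),
            ignore_index=self.pad_id
        )

        # Joint loss
        loss = ctc_loss + attn_loss

        return loss, ctc_loss, attn_loss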
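
The class above adds the two losses with equal weight. A common variant interpolates them with a single weight instead, so that the balance between the alignment constraint and the attention decoder becomes a tunable hyperparameter. The sketch below is a minimal illustration on plain floats; the function name, the `ctc_weight` parameter, and its default of 0.3 are assumptions for this example, not values taken from the text.

```python
def hybrid_loss(ctc_loss: float, attn_loss: float, ctc_weight: float = 0.3) -> float:
    """Interpolate the CTC and attention losses: w * ctc + (1 - w) * attention."""
    if not 0.0 <= ctc_weight <= 1.0:
        raise ValueError("ctc_weight must lie in [0, 1]")
    return ctc_weight * ctc_loss + (1.0 - ctc_weight) * attn_loss


print(hybrid_loss(2.0, 4.0, ctc_weight=0.5))  # 3.0
```

Setting `ctc_weight` to 0 or 1 recovers a pure attention or pure CTC objective, which makes the two extremes easy to compare in an ablation.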

6. Whisper

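6.1 Core Design

OpenAI's Whisper is a representative example of large-scale weakly supervised pretraining:

  • 680,000 hours of audio: large-scale, weakly labeled data
  • Multitask learning: ASR + speech translation + language identification
  • Encoder-decoder: a standard Transformer architecture
  • Strong robustness: handles a wide range of audio without fine-tuning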
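
import torch
import torch.nn as nn
import torch.nn.functional as F


class Whisper(nn.Module):
    # WhisperEncoder and WhisperDecoder are assumed to be defined elsewhere
    def __init__(self, dims):
        super().__init__()
        self.dims = dims

        # Encoder
        self.encoder = WhisperEncoder(
            n_mels=dims.n_mels,
            n_ctx=dims.n_audio_ctx,
            n_state=dims.n_audio_state,
            n_head=dims.n_audio_head,
            n_layer=dims.n_audio_layer
        )

        # Decoder
        self.decoder = WhisperDecoder(
            n_vocab=dims.n_vocab,
            n_ctx=dims.n_text_ctx,
            n_state=dims.n_text_state,
            n_head=dims.n_text_head,
            n_layer=dims.n_text_layer
        )

    def forward(self, mel, tokens):
        """Forward pass"""
        # Encode
        x = self.encoder(mel)

        # Decode
        x = self.decoder(tokens, x)

        return x

    @torch.no_grad()
    def generate(self, mel, max_length=448, temperature=1.0, sot=50258, eot=50257):
        """Autoregressive generation.

        Default IDs are <|startoftranscript|> and <|endoftext|> of the multilingual
        tokenizer; English-only checkpoints use 50257 and 50256 instead.
        """
        B = mel.shape[0]
        tokens = torch.full((B, 1), sot, dtype=torch.long, device=mel.device)
        finished = torch.zeros(B, dtype=torch.bool, device=mel.device)
        audio = self.encoder(mel)  # encode the audio once and reuse it at every step

        for _ in range(max_length - 1):
            logits = self.decoder(tokens, audio)
            next_token = self._sample(logits[:, -1], temperature)  # (B, 1)
            next_token[finished] = eot  # pad sequences that have already ended
            tokens = torch.cat([tokens, next_token], dim=1)

            # Stop once every sequence in the batch has produced <|endoftext|>
            finished |= next_token.squeeze(1) == eot
            if finished.all():
                break

        return tokens

    @staticmethod
    def _sample(logits, temperature):
        """Greedy pick when temperature <= 0, otherwise sample from the softened distribution"""
        if temperature <= 0:
            return logits.argmax(dim=-1, keepdim=True)
        probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)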

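6.2 Multitask Format

Whisper uses special tokens to select the task and language:

# Token IDs of the multilingual tokenizer (models up to large-v2)
TASK_TOKENS = {
    '<|transcribe|>': 50359,  # transcription
    '<|translate|>': 50358,  # translation into English
    '<|en|>': 50259,         # English
    '<|zh|>': 50260,         # Chinese
    '<|es|>': 50262,         # Spanish (50261 is German)
    # ... more languages
}

# Example: English transcription
# Input: mel_spectrogram
# Output: <|startoftranscript|><|en|><|transcribe|>Hello world<|endoftext|>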

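6.3 Whisper Variants

Model           Parameters  Use case
Whisper tiny    39M         Fast inference
Whisper base    74M         Balanced speed and accuracy
Whisper small   244M        Higher accuracy
Whisper medium  769M        High accuracy
Whisper large   1550M       Highest accuracy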

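7. Streaming Recognition

7.1 Challenges

Streaming ASR must process audio in real time, which requires:

  • Low latency: < 200 ms
  • Incremental output: results are emitted without waiting for the full utterance
  • Unidirectional RNN: no access to future frames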

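7.2 Method

import torch
import torch.nn as nn


class StreamingASR(nn.Module):
    """Streaming ASR: processes audio chunk by chunk"""
    def __init__(self, model, compute_mel, chunk_size=30, context_size=15):
        # chunk_size=30 frames ≈ 300 ms at a 10 ms frame shift
        super().__init__()
        self.chunk_size = chunk_size
        self.context_size = context_size
        self.model = model              # assumed: any module mapping (B, T, n_mels) features to outputs
        self.compute_mel = compute_mel  # assumed: callable mapping an audio chunk to (B, T, n_mels)
        self.context_buffer = None

    def process_chunk(self, audio_chunk):
        """Process one audio chunk"""
        mel = self.compute_mel(audio_chunk)

        if self.context_buffer is None:
            self.context_buffer = mel
        else:
            # Keep the most recent frames as left (past) context
            self.context_buffer = torch.cat(
                [self.context_buffer[:, -self.context_size:], mel], dim=1
            )

        # Predict on context + new chunk
        tokens = self.model(self.context_buffer)

        return tokens

    def reset(self):
        """Reset the state between utterances"""
        self.context_buffer = None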
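
The buffering logic can be traced without any model. The sketch below is a stdlib-only illustration that treats a list of integers as stand-in feature frames and yields, for each new chunk, the window the model would see: up to 15 past frames of context followed by the 30 new frames. The function name is hypothetical.

```python
from typing import Iterator, List, Sequence, Tuple


def iter_chunks(frames: Sequence, chunk_size: int = 30,
                context: int = 15) -> Iterator[Tuple[int, List]]:
    """Yield (start_index, window); window = left context + the new chunk."""
    for start in range(0, len(frames), chunk_size):
        left = max(0, start - context)
        yield start, list(frames[left:start + chunk_size])


frames = list(range(70))  # stand-in for 70 feature frames
for start, window in iter_chunks(frames):
    print(start, window[0], window[-1], len(window))
# 0 0 29 30
# 30 15 59 45
# 60 45 69 25
```

Every window after the first overlaps the previous one by 15 frames, so the outputs that correspond to the context frames have already been emitted once and would normally be discarded by the caller.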

8. Practical Guide

8.1 Using HuggingFace Transformers

from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch

# Load the model
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

# Prepare inputs (for example with dataset.map when fine-tuning)
def prepare_dataset(batch):
    audio = batch["audio"]
    input_features = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt"
    ).input_features

    batch["input_features"] = input_features[0]
    batch["labels"] = processor.tokenizer(batch["text"]).input_ids
    return batch

# Generate; `sample` is assumed to be one dataset row whose audio is sampled at 16 kHz
audio = sample["audio"]
input_features = processor(
    audio["array"],
    sampling_rate=audio["sampling_rate"],
    return_tensors="pt"
).input_features

model.eval()
with torch.no_grad():
    generated_ids = model.generate(
        input_features,
        task="transcribe",
        language="en",
        max_new_tokens=128
    )

transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)

8.2 Evaluation Metrics

def compute_wer(predictions, references):
    """Compute the word error rate (WER)"""
    total_words = 0
    total_errors = 0

    for pred, ref in zip(predictions, references):
        pred_words = pred.split()
        ref_words = ref.split()

        # Edit distance between hypothesis and reference words
        errors = levenshtein_distance(pred_words, ref_words)
        total_errors += errors
        total_words += len(ref_words)

    return total_errors / total_words

def levenshtein_distance(s1, s2):
    """Compute the edit distance between two sequences"""
    m, n = len(s1), len(s2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if s1[i-1] == s2[j-1]:
                dp[i][j] = dp[i-1][j-1]
            else:
                dp[i][j] = min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]) + 1
    
    return dp[m][n]
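
WER is often reported together with its breakdown into substitutions, deletions, and insertions. The sketch below extends the same dynamic program to carry those counts along one minimum-edit alignment; it is a self-contained illustration, and the function name is hypothetical.

```python
def error_breakdown(ref_words, hyp_words):
    """Return (substitutions, deletions, insertions) for one minimum-edit alignment."""
    m, n = len(ref_words), len(hyp_words)
    # Each cell holds (total_errors, substitutions, deletions, insertions)
    dp = [[(0, 0, 0, 0)] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        dp[i][0] = (i, 0, i, 0)  # reference words with no hypothesis: deletions
    for j in range(1, n + 1):
        dp[0][j] = (j, 0, 0, j)  # hypothesis words with no reference: insertions
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if ref_words[i - 1] == hyp_words[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
                continue
            e, s, d, k = dp[i - 1][j - 1]
            substitution = (e + 1, s + 1, d, k)
            e, s, d, k = dp[i - 1][j]
            deletion = (e + 1, s, d + 1, k)
            e, s, d, k = dp[i][j - 1]
            insertion = (e + 1, s, d, k + 1)
            dp[i][j] = min(substitution, deletion, insertion)  # lowest total cost wins
    return dp[m][n][1:]


ref = "the cat sat on the mat".split()
hyp = "the cat sit on mat".split()
subs, dels, ins = error_breakdown(ref, hyp)
wer = (subs + dels + ins) / len(ref)
print(subs, dels, ins, round(wer, 3))  # 1 1 0 0.333
```

Here "sat" was replaced by "sit" and the second "the" was dropped, giving 2 errors over 6 reference words. For unsegmented text such as Chinese, passing `list(text)` instead of `text.split()` turns the same computation into a character error rate.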

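9. Summary

Key takeaways

  1. CTC solves the sequence alignment problem with a blank symbol and was the mainstream approach for early end-to-end models
  2. RNN-T models output dependencies through its prediction network, which increases modeling capacity
  3. Attention provides a more flexible input-output alignment
  4. Hybrid CTC/Attention combines the strengths of both approaches
  5. Whisper demonstrates the effectiveness of large-scale weak supervision

Trends

  • Larger-scale pretraining: exploiting massive amounts of unlabeled audio
  • Multitask unification: ASR, translation, and speaker recognition in one model
  • Real-time streaming: end-to-end streaming recognition
  • Multimodal fusion: joint visual and speech understanding

References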