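Self-Supervised Speech Representation Learning

1. Overview

Self-supervised speech representation learning aims to learn useful speech representations from large amounts of unlabeled audio. Unlike supervised learning, self-supervised methods derive the training signal from the structure of the audio itself, which greatly reduces the dependence on human annotation.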

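1.1 Timeline

2019: vq-wav2vec - an early discrete speech representation
2020: Wav2Vec 2.0 - contrastive learning framework
2021: HuBERT - masked prediction
2021: WavLM - improved pre-training objective
2022: XLS-R - cross-lingual self-supervised learning
2023: MMS - massively multilingual models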

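1.2 Core Ideas

Two paradigms dominate self-supervised speech learning:

Paradigm              | Representative method | Core mechanism                              | Strength
Contrastive learning  | Wav2Vec 2.0           | Distinguish positive from negative samples  | Simple and intuitive
Masked prediction     | HuBERT                | Predict targets at masked positions         | More semantic targets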

2. Wav2Vec 2.0

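2.1 Core Architecture

Wav2Vec 2.0 uses a contrastive objective. Spans of the latent sequence are masked, and at each masked time step the model must pick the true quantized latent out of a set of distractors (negatives) drawn from other masked steps of the same utterance:

Audio → CNN encoder → latent features Z
Z → span masking → Transformer → context representation C
Z → quantizer → quantized targets Q
(C, Q) → contrastive loss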
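
import torch
import torch.nn as nn
import torch.nn.functional as F


class Wav2Vec2Model(nn.Module):
    def __init__(self, config):
        super().__init__()

        # CNN feature encoder: 7 unpadded conv layers with total stride 320,
        # i.e. one frame per 20 ms at 16 kHz (normalization is simplified here)
        self.feature_encoder = nn.Sequential(
            nn.Conv1d(1, 512, kernel_size=10, stride=5),
            nn.GroupNorm(8, 512),
            nn.GELU(),
            *[
                nn.Sequential(
                    nn.Conv1d(512, 512, kernel_size=3, stride=2),
                    nn.GroupNorm(8, 512),
                    nn.GELU(),
                ) for _ in range(4)
            ],
            *[
                nn.Sequential(
                    nn.Conv1d(512, 512, kernel_size=2, stride=2),
                    nn.GELU(),
                ) for _ in range(2)
            ],
        )

        # Transformer context network
        # (width 512 keeps the sketch simple; the published BASE model uses 768)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=512, nhead=8, dim_feedforward=2048, batch_first=True
            ),
            num_layers=12,
        )

        # Quantizer (Gumbel softmax, 2 codebooks of 320 entries each).
        # GumbelVectorQuantizer is a placeholder name; it is not defined here.
        self.quantizer = GumbelVectorQuantizer(
            codebook_size=320,
            dim=512,
            n_codebooks=2
        )

    def forward(self, waveform, mask=True):
        # 1. CNN feature extraction: (B, L) -> (B, D, T) -> (B, T, D)
        features = self.feature_encoder(waveform.unsqueeze(1)).transpose(1, 2)

        # 2. Masking (apply_mask is a span-masking method like the one
        #    shown for HuBERT in Section 3.1; it is not defined in this class)
        masked_features, mask_indices = features, None
        if mask:
            masked_features, mask_indices = self.apply_mask(features)

        # 3. Transformer encoding
        contextual = self.transformer(masked_features)  # (B, T, D)

        # The unmasked features are returned because the quantizer needs them
        return contextual, features, mask_indices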
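
The link between waveform length and the number of frames the Transformer sees follows from plain convolution arithmetic. The sketch below assumes the published wav2vec 2.0 feature-encoder layout (kernel widths 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2, no padding) and 16 kHz input. The helper names `conv_output_length` and `receptive_field` are introduced here for illustration only.

```python
KERNELS = (10, 3, 3, 3, 3, 2, 2)
STRIDES = (5, 2, 2, 2, 2, 2, 2)


def conv_output_length(num_samples, kernels=KERNELS, strides=STRIDES):
    """Number of frames produced by a stack of unpadded 1-D convolutions."""
    length = num_samples
    for k, s in zip(kernels, strides):
        length = (length - k) // s + 1
    return length


def receptive_field(kernels=KERNELS, strides=STRIDES):
    """Return (receptive field, total stride) of the stack, in input samples."""
    field, jump = 1, 1
    for k, s in zip(kernels, strides):
        field += (k - 1) * jump  # each layer widens the field by (k - 1) strides
        jump *= s                # distance between neighboring outputs
    return field, jump


frames = conv_output_length(16000)   # one second of 16 kHz audio
field, stride = receptive_field()
print(frames, field, stride)  # 49 400 320
```

Under these assumptions one second of audio becomes 49 frames, each covering 400 samples (25 ms) with a hop of 320 samples (20 ms).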

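2.2 Contrastive Loss

The core loss distinguishes the positive sample (the true quantized latent at a masked position) from negative samples: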

def contrastive_loss(context, targets, negatives, temperature=0.1):
    """
    context:   (N, D)    - Transformer outputs at the N masked positions
    targets:   (N, D)    - true quantized latents at those positions (positives)
    negatives: (N, K, D) - K distractor latents per masked position
    """
    # Positive logits: cosine similarity with the true target
    pos_logits = F.cosine_similarity(context, targets, dim=-1).unsqueeze(-1)  # (N, 1)

    # Negative logits: cosine similarity with each distractor
    neg_logits = F.cosine_similarity(
        context.unsqueeze(1), negatives, dim=-1
    )  # (N, K)

    # InfoNCE loss: the positive always sits at index 0
    logits = torch.cat([pos_logits, neg_logits], dim=-1) / temperature
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

    loss = F.cross_entropy(logits, labels)

    return loss
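
To make the behavior of the loss concrete, here is a dependency-free sketch of InfoNCE for a single masked position, assuming cosine similarity and a temperature of 0.1. The function names and the toy vectors are illustrative and do not come from any library.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(context, positive, negatives, temperature=0.1):
    """InfoNCE loss for one position: -log softmax of the positive logit."""
    logits = [cosine(context, positive) / temperature]
    logits += [cosine(context, n) / temperature for n in negatives]
    # log-sum-exp with the maximum subtracted for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[0]


context = [1.0, 0.0]
negatives = [[0.0, 1.0], [-1.0, 0.0]]
easy = info_nce(context, [1.0, 0.0], negatives)  # positive aligned with context
hard = info_nce(context, [0.0, 1.0], negatives)  # positive orthogonal to context
print(easy < hard)  # True
```

The loss is close to zero when the context vector points at the positive and grows when a distractor is as similar as the positive.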

2.3 Two-Stage Training

Wav2Vec 2.0 is trained in two stages, pre-training followed by fine-tuning:

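# Pre-training stage: self-supervised contrastive learning
def pretrain_step(model, waveform):
    # CNN features: (B, L) -> (B, D, T) -> (B, T, D)
    features = model.feature_encoder(waveform.unsqueeze(1)).transpose(1, 2)

    # Mask, then encode the masked sequence
    masked_features, mask_indices = model.apply_mask(features)
    contextual = model.transformer(masked_features)

    # Quantize the unmasked features to obtain the targets
    quantized, indices = model.quantizer(features, mask_indices)

    # Gather what the contrastive loss needs at the masked positions.
    # sample_negatives is a hypothetical helper that draws distractors
    # from other masked steps of the same utterance.
    context = contextual[mask_indices]
    targets = quantized[mask_indices]
    negatives = sample_negatives(quantized, mask_indices)

    # Multi-task loss: contrastive term plus codebook diversity term
    # (diversity_loss is assumed to be defined elsewhere)
    loss = contrastive_loss(context, targets, negatives) + \
           diversity_loss(quantized)

    return loss

# Fine-tuning stage: supervised learning on the downstream task
def finetune_step(model, task_head, waveform, labels, label_lengths):
    contextual, _, _ = model(waveform, mask=False)  # (B, T, D)

    # Task head on top of the contextual representations
    logits = task_head(contextual)  # (B, T, num_classes)

    # CTC loss expects (T, B, C) log-probabilities and per-sample lengths
    B, T, _ = logits.shape
    input_lengths = torch.full((B,), T, dtype=torch.long)  # assumes no padding
    loss = F.ctc_loss(
        logits.log_softmax(-1).transpose(0, 1), labels,
        input_lengths, label_lengths
    )

    return loss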
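
The codebook diversity term used in pre-training can be sketched on its own. The version below assumes the form given in the wav2vec 2.0 paper: the negative entropy of the batch-averaged codebook distribution, averaged over G groups of V entries. The input layout (a nested list of averaged probabilities) is an assumption made for this example.

```python
import math


def diversity_loss(avg_probs):
    """avg_probs[g][v]: batch-averaged probability of entry v in group g.

    Returns (1 / (G * V)) * sum_g sum_v p * log(p), the negative entropy
    scaled by the codebook size. Lower values mean more uniform usage.
    """
    num_groups = len(avg_probs)
    num_entries = len(avg_probs[0])
    total = 0.0
    for group in avg_probs:
        for p in group:
            if p > 0.0:  # treat 0 * log(0) as 0
                total += p * math.log(p)
    return total / (num_groups * num_entries)


uniform = [[0.25] * 4, [0.25] * 4]
collapsed = [[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
print(diversity_loss(uniform) < diversity_loss(collapsed))  # True
```

Minimizing this term pushes the model away from codebook collapse, where only a handful of entries are ever selected.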

3. HuBERT

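3.1 Core Idea

HuBERT (Hidden-Unit BERT) uses masked prediction, much as BERT does in NLP. Its core assumption is that acoustically similar frames correspond to similar "hidden units" (phonemes, for example).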

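class HuBERTModel(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Feature encoder (same as Wav2Vec 2.0); assumed to return (B, T, D)
        self.feature_encoder = CNNEncoder()

        # Transformer encoder
        self.transformer = TransformerEncoder(...)

        # Masking strategy. mask_prob is the fraction of frames covered
        # before overlap: 0.80 / 10 means 8% of frames start a span.
        self.mask_prob = 0.80
        self.mask_len = 10  # length of each masked span, in frames

    def apply_mask(self, features):
        """Mask contiguous spans of frames (assumes T >= mask_len)."""
        B, T, D = features.shape

        # Number of spans per utterance, derived from mask_prob
        num_spans = max(1, int(self.mask_prob * T / self.mask_len))

        # Randomly choose span start positions (the upper bound is exclusive)
        mask_starts = torch.randint(
            0, T - self.mask_len + 1, (B, num_spans)
        )

        # Build the boolean mask; spans are allowed to overlap
        mask = torch.zeros(B, T, dtype=torch.bool, device=features.device)
        for i in range(B):
            for start in mask_starts[i].tolist():
                mask[i, start:start + self.mask_len] = True

        # Apply the mask (zeros keep the sketch simple; published models
        # substitute a learned mask embedding)
        masked_features = features.clone()
        masked_features[mask] = 0

        return masked_features, mask

    def forward(self, waveform, mask=True):
        features = self.feature_encoder(waveform)

        if mask:
            masked_features, mask_indices = self.apply_mask(features)
        else:
            masked_features = features
            mask_indices = None

        encoded = self.transformer(masked_features)

        return encoded, mask_indices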
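
Because spans may overlap, the fraction of frames that end up masked is lower than the number of spans times the span length. The simulation below is a minimal sketch, assuming 8% of frames are chosen as span starts and each span covers 10 frames. `sample_span_mask` is a name made up for this example.

```python
import random


def sample_span_mask(num_frames, start_prob=0.08, span_len=10, rng=None):
    """Return a list of booleans; True marks a masked frame."""
    rng = rng or random.Random()
    num_spans = max(1, int(start_prob * num_frames))
    mask = [False] * num_frames
    for _ in range(num_spans):
        # randrange excludes the upper bound, so every span fits in the sequence
        start = rng.randrange(0, num_frames - span_len + 1)
        for t in range(start, start + span_len):
            mask[t] = True
    return mask


mask = sample_span_mask(1000, rng=random.Random(0))
coverage = sum(mask) / len(mask)
print(f"masked fraction: {coverage:.2f}")
```

With 80 spans of 10 frames over 1000 frames, the upper bound on coverage is 0.8, but overlap typically brings the masked fraction down to a little over half.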

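3.2 Training Objective

HuBERT's prediction targets come from offline clustering. K-means is run on MFCC features in the first iteration and on intermediate Transformer features in later iterations, and the resulting cluster IDs serve as frame-level pseudo-labels: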

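def hubert_pretrain_loss(model, waveform, kmeans_model, num_clusters=500):
    """HuBERT pre-training loss; model.predictor must output num_clusters logits."""

    # 1. CNN features, assumed shape (B, T, D)
    features = model.feature_encoder(waveform)
    B, T, D = features.shape

    # 2. Pseudo-labels from the offline k-means model.
    #    Simplification: the labels are computed here from the CNN features;
    #    in practice they are precomputed once, before training starts.
    with torch.no_grad():
        flat = features.detach().reshape(B * T, D).cpu().numpy()
        cluster_ids = torch.from_numpy(kmeans_model.predict(flat)).long()
        cluster_ids = cluster_ids.view(B, T).to(features.device)

    # 3. Apply the mask
    masked_features, mask_indices = model.apply_mask(features)

    # 4. Transformer encoding
    encoded = model.transformer(masked_features)

    # 5. Predict the cluster ID at each masked position
    masked_encoded = encoded[mask_indices]  # (num_masked, D)

    # Classification loss over the masked frames only
    logits = model.predictor(masked_encoded)  # (num_masked, num_clusters)
    targets = cluster_ids[mask_indices]

    loss = F.cross_entropy(logits, targets)

    return loss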
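
The pseudo-labeling step amounts to assigning every frame to its nearest k-means centroid. The snippet below sketches that assignment without any dependencies; the centroids and frames are toy values, and `assign_clusters` is an illustrative helper rather than part of any HuBERT codebase.

```python
def assign_clusters(frames, centroids):
    """Label each frame with the index of its nearest centroid (squared L2)."""
    labels = []
    for frame in frames:
        dists = [
            sum((x - c) ** 2 for x, c in zip(frame, centroid))
            for centroid in centroids
        ]
        labels.append(dists.index(min(dists)))
    return labels


centroids = [[0.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
frames = [[0.1, -0.1], [0.9, 1.2], [0.2, 1.9], [0.8, 0.9]]
labels = assign_clusters(frames, centroids)
print(labels)  # [0, 1, 2, 1]
```

During pre-training these integer labels play the role that token IDs play in BERT: the model predicts them at masked positions.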

3.3 Wav2Vec 2.0 vs HuBERT

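Property                | Wav2Vec 2.0               | HuBERT
Pre-training objective  | Contrastive learning      | Masked prediction
Target source           | Online quantizer          | Offline clustering
Semantic level          | Lower                     | Higher
Convergence speed       | Slower                    | Faster
Suitable tasks          | Speech recognition (ASR)  | Speech recognition, speaker recognition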

4. WavLM

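4.1 Improvements

WavLM makes three main changes on top of HuBERT:

  1. Speaker-aware augmentation (utterance mixing): noise or a second, overlapping utterance is mixed into the input during pre-training, and the model still has to predict the targets of the original speech (masked speech denoising and prediction)
  2. Gated relative position bias: added to the Transformer self-attention so that attention can depend on relative position and content
  3. More pre-training data: the large model is trained on about 94,000 hours drawn from Libri-Light, GigaSpeech, and VoxPopuli

The pre-training objective remains masked prediction of cluster IDs; WavLM does not add a contrastive term.

class WavLMModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.feature_encoder = CNNEncoder()
        # The gated relative position bias lives inside the attention layers
        # and is not shown in this sketch
        self.transformer = TransformerEncoder(...)

        # Task head: masked prediction over 500 cluster IDs
        self.decoder = nn.Linear(512, 500)

    def forward(self, waveform, noise=None):
        # Noise augmentation: mix the interfering signal into the waveform,
        # not into the CNN features, and only during training
        if noise is not None and self.training:
            waveform = waveform + 0.1 * noise

        features = self.feature_encoder(waveform)

        # Masking (span_masking is assumed to behave like
        # HuBERTModel.apply_mask in Section 3.1)
        masked_features, mask = self.span_masking(features)

        # Encoding
        encoded = self.transformer(masked_features)

        # Masked prediction: logits for the masked frames only; the targets
        # are the cluster IDs of the clean (unmixed) utterance
        pred = self.decoder(encoded[mask])

        return pred, mask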
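
Mixing an interfering signal into the input can be illustrated at the sample level. The sketch below is not the published WavLM recipe: the overlap cap of 50%, the way the energy ratio is computed over the mixed region, and the name `mix_utterance` are all assumptions chosen to keep the example short.

```python
import math
import random


def mix_utterance(primary, secondary, snr_db=0.0, max_overlap=0.5, rng=None):
    """Mix a crop of `secondary` into a random region of `primary`.

    snr_db is the energy ratio (primary over secondary) in decibels.
    Returns the mixed signal and the (start, length) of the mixed region.
    """
    rng = rng or random.Random()
    max_len = min(len(secondary), int(len(primary) * max_overlap))
    mix_len = rng.randint(1, max_len)               # bounds are inclusive
    src = rng.randint(0, len(secondary) - mix_len)  # crop start in secondary
    dst = rng.randint(0, len(primary) - mix_len)    # region start in primary
    crop = secondary[src:src + mix_len]

    # Scale the crop so that the energy ratio in the region matches snr_db
    region = primary[dst:dst + mix_len]
    e_primary = sum(x * x for x in region) / mix_len
    e_crop = sum(x * x for x in crop) / mix_len
    scale = 0.0
    if e_crop > 0.0:
        scale = math.sqrt(e_primary / (e_crop * 10 ** (snr_db / 10)))

    mixed = list(primary)
    for i, x in enumerate(crop):
        mixed[dst + i] += scale * x
    return mixed, (dst, mix_len)


primary = [math.sin(0.01 * i) for i in range(1600)]
secondary = [math.sin(0.05 * i) for i in range(1000)]
mixed, (start, length) = mix_utterance(primary, secondary, snr_db=5.0,
                                       rng=random.Random(0))
print(len(mixed), length <= 800)  # 1600 True
```

Samples outside the mixed region are left untouched, so the primary speaker always dominates the utterance and its pseudo-labels remain a sensible target.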

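5. Cross-Lingual Model: XLS-R

5.1 Design Goals

XLS-R aims to learn a general-purpose speech representation that transfers across languages:

  • 128 languages: covering the world's major languages
  • About 436,000 hours of audio: large-scale pre-training
  • Adaptive fine-tuning: a small amount of labeled data is enough to adapt to a new language

class XLSRModel(nn.Module):
    def __init__(self, num_languages=128):
        super().__init__()

        # Shared encoder (wav2vec 2.0 architecture, shared by all languages)
        self.shared_encoder = Wav2Vec2Encoder()

        # Language-specific adapters. These are an illustrative extension:
        # the released XLS-R checkpoints share every parameter across languages.
        self.language_adapters = nn.ModuleDict({
            f"lang_{i}": LanguageAdapter(dim=512)
            for i in range(num_languages)
        })

    def forward(self, waveform, lang_id=None):
        features = self.shared_encoder(waveform)

        # Apply the adapter in both training and evaluation so that
        # inference sees the same computation as training
        if lang_id is not None:
            adapter = self.language_adapters[f"lang_{lang_id}"]
            features = adapter(features)

        return features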
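
Multilingual pre-training has to balance high-resource and low-resource languages when batches are drawn. A common approach, sketched here under the assumption of an exponent of 0.5, samples language l with probability proportional to (n_l / N) ** alpha, where n_l is the amount of audio for that language. The corpus sizes below are invented for the example.

```python
def language_sampling_probs(hours_per_language, alpha=0.5):
    """p_l proportional to (n_l / N) ** alpha; alpha < 1 upsamples small languages."""
    total = sum(hours_per_language.values())
    weights = {
        lang: (hours / total) ** alpha
        for lang, hours in hours_per_language.items()
    }
    norm = sum(weights.values())
    return {lang: w / norm for lang, w in weights.items()}


hours = {"en": 9000.0, "cy": 900.0, "sw": 100.0}  # hypothetical corpus sizes
probs = language_sampling_probs(hours)
for lang, p in probs.items():
    print(f"{lang}: raw share {hours[lang] / 10000:.2f}, sampled share {p:.2f}")
```

Setting alpha to 1 recovers the raw proportions, while smaller values give low-resource languages a larger share of each batch.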

6. Fine-Tuning on Downstream Tasks

6.1 Speech Recognition (ASR)

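class Wav2Vec2ForCTC(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model(config)
        # Linear CTC head: projects every frame onto the vocabulary
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)

        # Dropout before the CTC head
        self.dropout = nn.Dropout(0.1)

    def forward(self, waveform, labels=None, label_lengths=None):
        # Wav2Vec2Model returns (contextual, features, mask_indices)
        contextual, _, _ = self.wav2vec2(waveform, mask=False)  # (B, T, D)

        # Linear projection onto the vocabulary
        logits = self.decoder(self.dropout(contextual))  # (B, T, V)
        log_probs = F.log_softmax(logits, dim=-1)

        if labels is not None:
            # T is the frame count after CNN downsampling;
            # using it for every sample assumes the batch is unpadded
            B, T, _ = log_probs.shape
            input_lengths = torch.full((B,), T, dtype=torch.long)

            # CTC loss
            loss = F.ctc_loss(
                log_probs.transpose(0, 1),  # (T, B, V)
                labels,
                input_lengths=input_lengths,
                target_lengths=label_lengths,
                blank=0,
                reduction='mean'
            )
            return loss, log_probs

        return log_probs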
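
At inference time the per-frame log-probabilities still have to be turned into a label sequence. The simplest option is greedy CTC decoding: take the best label in each frame, merge consecutive repeats, then drop the blank symbol. The sketch assumes the blank has ID 0, matching `blank=0` in the loss above.

```python
from itertools import groupby


def ctc_greedy_decode(frame_ids, blank=0):
    """Collapse consecutive repeats, then remove blanks."""
    collapsed = [label for label, _ in groupby(frame_ids)]
    return [label for label in collapsed if label != blank]


# Per-frame argmax IDs; the blank between the two 3s keeps them separate
print(ctc_greedy_decode([0, 3, 3, 0, 3, 5, 5, 0]))  # [3, 3, 5]
```

Beam search with a language model usually gives lower error rates, but greedy decoding is a useful baseline and sanity check.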

6.2 Speaker Recognition

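class SpeakerRecognitionModel(nn.Module):
    def __init__(self, config, num_speakers):
        super().__init__()
        self.encoder = Wav2Vec2Model(config)
        # AttentiveStatisticsPooling is assumed to be defined elsewhere;
        # it maps (B, T, D) to (B, 2 * D)
        self.pooling = AttentiveStatisticsPooling()
        self.fc = nn.Linear(config.hidden_size * 2, 512)
        self.classifier = nn.Linear(512, num_speakers)

    def forward(self, waveform):
        # Wav2Vec2Model returns (contextual, features, mask_indices)
        contextual, _, _ = self.encoder(waveform, mask=False)  # (B, T, D)

        # Attentive pooling over time
        pooled = self.pooling(contextual)  # (B, D*2)

        # Speaker embedding, then classification
        x = F.relu(self.fc(pooled))
        logits = self.classifier(x)

        return logits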
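
The pooling layer is what turns a variable-length frame sequence into a fixed-size utterance vector. The following sketch shows one plausible form of attentive statistics pooling, assuming the attention scores are already given as one scalar per frame; in a real model they would come from a small learned network.

```python
import math


def attentive_stats_pool(frames, scores):
    """frames: T lists of D floats; scores: T unnormalized attention scores.

    Returns a list of length 2 * D: the attention-weighted mean followed by
    the attention-weighted standard deviation.
    """
    # Softmax over time, with the maximum subtracted for stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    weights = [e / z for e in exps]

    dim = len(frames[0])
    mean = [sum(w * f[d] for w, f in zip(weights, frames)) for d in range(dim)]
    var = [
        sum(w * (f[d] - mean[d]) ** 2 for w, f in zip(weights, frames))
        for d in range(dim)
    ]
    std = [math.sqrt(max(v, 0.0)) for v in var]
    return mean + std


frames = [[1.0, 0.0], [3.0, 0.0]]
pooled = attentive_stats_pool(frames, scores=[0.0, 0.0])
print(pooled)  # [2.0, 0.0, 1.0, 0.0]
```

With equal scores this reduces to ordinary mean and standard deviation pooling; unequal scores let the model emphasize frames that carry more speaker information.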

6.3 Fine-Tuning with HuggingFace

import librosa
from transformers import (
    Trainer,
    TrainingArguments,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

# Load the pre-trained model.
# Note: this checkpoint is already fine-tuned on 960 h of LibriSpeech;
# "facebook/wav2vec2-base" is the purely self-supervised starting point.
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Prepare the data
# (assumes each example exposes "array", "sampling_rate" and "text")
def prepare_dataset(batch):
    audio = batch["array"]
    # Resample to 16 kHz
    audio = librosa.resample(audio, orig_sr=batch["sampling_rate"], target_sr=16000)

    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    # Tokenize the transcript (replaces the deprecated as_target_processor)
    batch["labels"] = processor(text=batch["text"]).input_ids
    return batch

# Training
training_args = TrainingArguments(
    output_dir="./wav2vec2-finetuned",
    per_device_train_batch_size=8,
    evaluation_strategy="steps",  # named eval_strategy in recent transformers releases
    num_train_epochs=10,
    save_steps=500,
    eval_steps=500,
)

# data_collator, train_dataset, eval_dataset and compute_metrics
# are assumed to be defined by the caller
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    args=training_args,
)

trainer.train()
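
The `compute_metrics` callback is left undefined in the snippet above. For CTC fine-tuning it usually reports word error rate (WER), which is the word-level edit distance divided by the number of reference words. The function below is a minimal, dependency-free sketch of that metric; wiring it into `Trainer` (decoding predictions back to text) is not shown.

```python
def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] holds the distance between the processed reference prefix
    # and the first j hypothesis words
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            ))
        prev = curr
    return prev[-1] / max(len(ref), 1)


print(word_error_rate("the cat sat", "the cat sat"))  # 0.0
```

One substituted word out of three gives a WER of one third, and the value can exceed 1.0 when the hypothesis contains many insertions.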

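7. Theoretical Analysis

7.1 Why Does Self-Supervised Learning Work?

The effectiveness of self-supervised speech learning stems from the intrinsic structure of audio:

  1. Temporal continuity: adjacent frames are highly correlated
  2. Phoneme invariance: the same phoneme has similar acoustic features in different contexts
  3. Speaker consistency: the voice of a given speaker stays consistent
  4. Linguistic regularity: speech follows the regularities of language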

7.2 Convergence of Contrastive Learning

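# Gradient analysis of contrastive learning
def contrastive_gradient_analysis(ctx, pos, neg, temperature=0.1):
    """
    Gradient of the InfoNCE loss with respect to ctx, for one positive
    and one negative. ctx, pos, neg: (..., D)
    """
    # Similarity of the positive pair
    sim_pos = torch.sum(ctx * pos, dim=-1) / temperature

    # Similarity of the negative pair
    sim_neg = torch.sum(ctx * neg, dim=-1) / temperature

    # Softmax probability assigned to the negative:
    # exp(sim_neg) / (exp(sim_pos) + exp(sim_neg)), written stably
    p_neg = torch.sigmoid(sim_neg - sim_pos).unsqueeze(-1)

    # Contribution of the positive pair: pulls ctx toward pos
    grad_pos = -p_neg * pos / temperature

    # Contribution of the negative pair: pushes ctx away from neg
    grad_neg = p_neg * neg / temperature

    # dL/dctx = grad_pos + grad_neg; both vanish as p_neg approaches 0
    return grad_pos, grad_neg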
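
A quick way to validate a hand-derived gradient is to compare it with finite differences. The check below is a sketch for the single-negative case with dot-product similarity, for which the gradient with respect to the context vector is p_neg * (neg - pos) / temperature, where p_neg is the softmax probability of the negative. All names and values are illustrative.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def info_nce_pair(ctx, pos, neg, temperature=0.1):
    """-log softmax of the positive logit, with one negative."""
    s_pos = dot(ctx, pos) / temperature
    s_neg = dot(ctx, neg) / temperature
    m = max(s_pos, s_neg)
    return m + math.log(math.exp(s_pos - m) + math.exp(s_neg - m)) - s_pos


def analytic_grad(ctx, pos, neg, temperature=0.1):
    """Closed-form gradient of info_nce_pair with respect to ctx."""
    s_pos = dot(ctx, pos) / temperature
    s_neg = dot(ctx, neg) / temperature
    p_neg = 1.0 / (1.0 + math.exp(s_pos - s_neg))
    return [p_neg * (n - p) / temperature for p, n in zip(pos, neg)]


def numeric_grad(ctx, pos, neg, temperature=0.1, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    grad = []
    for i in range(len(ctx)):
        up, down = list(ctx), list(ctx)
        up[i] += eps
        down[i] -= eps
        diff = (info_nce_pair(up, pos, neg, temperature)
                - info_nce_pair(down, pos, neg, temperature))
        grad.append(diff / (2 * eps))
    return grad


ctx, pos, neg = [0.2, -0.1], [0.3, 0.1], [-0.2, 0.4]
exact = analytic_grad(ctx, pos, neg)
approx = numeric_grad(ctx, pos, neg)
print(all(abs(a - b) < 1e-5 for a, b in zip(exact, approx)))  # True
```

The factor p_neg explains the training dynamics: once the positive is clearly preferred, p_neg approaches zero and the update from that pair fades out.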

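8. Summary and Outlook

Key Takeaways

  1. Wav2Vec 2.0 established the contrastive framework for speech, learning representations by telling positive samples from negatives
  2. HuBERT showed that masked prediction is effective, with offline clustering supplying the semantic targets
  3. Large-scale pre-training followed by task-specific fine-tuning is the dominant paradigm today
  4. Cross-lingual learning broadens the range of applications for self-supervised methods

Future Directions

  • Modeling longer audio: handling minute-long recordings
  • Multimodal joint learning: combining text and visual information
  • Streaming adaptation: low-latency online recognition
  • Continual learning: adapting to new languages and new speakers over time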
