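Audio Diffusion Models

1. Overview

Diffusion models have been applied successfully to image generation, and the same core idea has been extended to audio. Audio diffusion models face challenges of their own: audio sequences are extremely long, sampling rates are high, and both semantic coherence and acoustic quality must be preserved at the same time.

1.1 Unique Challenges of Audio Generation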

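Challenge | Description | Typical solutions
Sequence length | Sampling rates of 16-48 kHz, so 1 second = 16,000-48,000 samples | Frequency-domain representations, compressed representations
Temporal coherence | Neighboring audio frames are highly correlated | Autoregressive decoding, hierarchical generation
Two-scale modeling | Semantic content and acoustic quality must both be preserved | Multi-stage generation, semantic tokens
Real-time generation | Voice interaction requires low latency | Distillation, speculative decoding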
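
To make the sequence-length row concrete, the sketch below counts how many time steps a model has to handle for the same clip under three representations. The 256-sample hop and the 50 Hz token rate are illustrative assumptions, not values taken from any specific model in this article.

```python
def sequence_lengths(duration_s, sample_rate, hop_length=256, token_rate_hz=50):
    """Number of time steps for one clip under three representations.

    hop_length and token_rate_hz are assumed, illustrative values.
    """
    n_samples = int(duration_s * sample_rate)    # raw waveform samples
    n_frames = 1 + n_samples // hop_length       # spectrogram frames (centered STFT)
    n_tokens = int(duration_s * token_rate_hz)   # compressed codec or latent steps
    return {"samples": n_samples, "frames": n_frames, "tokens": n_tokens}


for sr in (16000, 48000):
    print(sr, sequence_lengths(10, sr))
```

Under these assumptions a 10-second clip shrinks from hundreds of thousands of samples to a few hundred compressed steps, which is why most systems below avoid modeling the raw waveform directly.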

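1.2 Timeline

2021: DiffWave, one of the first diffusion-based neural vocoders
2022: DiffAudio, text-to-audio diffusion
2023: AudioGen, text-to-audio generation over discrete audio tokens
2023: MusicGen, a multi-codebook language model for music generation
2024: Stable Audio Open, openly released large-scale audio generation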

2. Diffusion Model Basics: A Review

2.1 Discrete-Time Diffusion (DDPM)
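
For reference, these are the closed-form relations that the code in this section relies on, written in standard DDPM notation. Here $\beta_t$ is the noise schedule and $\epsilon_\theta$ is the denoising network; the symbols are the usual ones from the DDPM literature rather than names used elsewhere in this article.

```latex
\begin{aligned}
\alpha_t &= 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s \\
x_t &= \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
\hat{x}_0 &= \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}
\end{aligned}
```

The second line is the forward (noising) process, and the third recovers an estimate of the clean signal from a noisy sample and the predicted noise.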

import torch
import torch.nn as nn


class AudioDiffusion(nn.Module):
    def __init__(self, n_mels=80, diffusion_steps=1000, beta_schedule='linear'):
        super().__init__()
        self.n_mels = n_mels
        self.diffusion_steps = diffusion_steps

        # Noise schedule
        betas = self.get_beta_schedule(beta_schedule, diffusion_steps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        # Precomputed coefficients used by the forward and reverse processes
        self.register_buffer('sqrt_alphas_cumprod', alphas_cumprod.sqrt())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', (1.0 - alphas_cumprod).sqrt())

        # Denoising network (AudioUnet is assumed to be defined elsewhere)
        self.denoiser = AudioUnet(n_mels)

    def get_beta_schedule(self, schedule, T):
        if schedule == 'linear':
            return torch.linspace(1e-4, 0.02, T)
        elif schedule == 'cosine':
            s = 0.008
            x = torch.linspace(0, T, T + 1)
            alphas_cumprod = torch.cos(((x / T) + s) / (1 + s) * torch.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            return torch.clip(1 - alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999)
        raise ValueError(f"unknown beta schedule: {schedule}")

    def forward_diffusion(self, x0, t, noise=None):
        """Forward process: add noise. x0: (B, n_mels, T), t: (B,) integer timesteps"""
        if noise is None:
            noise = torch.randn_like(x0)

        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1)

        xt = sqrt_alphas_cumprod_t * x0 + sqrt_one_minus_alphas_cumprod_t * noise
        return xt

    @torch.no_grad()
    def reverse_diffusion(self, xt, condition=None):
        """Reverse process: iteratively denoise, starting from pure noise xt"""
        batch_size = xt.size(0)
        for t in reversed(range(self.diffusion_steps)):
            t_tensor = torch.full((batch_size,), t, device=xt.device, dtype=torch.long)

            # Predict the noise
            noise_pred = self.denoiser(xt, t_tensor, condition)

            # Estimate the clean signal x0 from xt and the predicted noise
            x0_pred = (xt - self.sqrt_one_minus_alphas_cumprod[t] * noise_pred) / self.sqrt_alphas_cumprod[t]

            # Deterministic (DDIM-style) step to timestep t - 1
            if t > 0:
                xt = (self.sqrt_alphas_cumprod[t - 1] * x0_pred
                      + self.sqrt_one_minus_alphas_cumprod[t - 1] * noise_pred)
            else:
                xt = x0_pred

        return xt
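
To get a feel for the two schedule options, the following sketch recomputes the cumulative signal level for both in plain Python, mirroring the formulas in `get_beta_schedule` but without the clipping step. The helper names are hypothetical and exist only for this comparison.

```python
import math


def linear_alphas_cumprod(T, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear beta schedule."""
    out, prod = [], 1.0
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        prod *= 1.0 - beta
        out.append(prod)
    return out


def cosine_alphas_cumprod(T, s=0.008):
    """Cumulative signal level for the cosine schedule, normalized so step 0 equals 1."""
    f = [math.cos(((i / T) + s) / (1 + s) * math.pi * 0.5) ** 2 for i in range(T + 1)]
    return [v / f[0] for v in f[1:]]


T = 1000
lin, cos = linear_alphas_cumprod(T), cosine_alphas_cumprod(T)
for t in (0, 249, 499, 749, 999):
    print(f"t={t:4d}  linear={lin[t]:.4f}  cosine={cos[t]:.4f}")
```

The cosine schedule retains more of the signal at mid-range timesteps, so fewer steps are spent on inputs that are already almost pure noise.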

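3. DiffWave: A Waveform-Level Diffusion Vocoder

3.1 Architecture

DiffWave applies diffusion directly to the raw waveform, conditioned on a mel spectrogram: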

import math

import torch
import torch.nn as nn


class DiffWave(nn.Module):
    def __init__(self, n_mels=80, n_layers=30, n_channels=512, dilation_cycle=10):
        super().__init__()
        self.n_layers = n_layers

        # Condition projection
        self.condition_proj = nn.Sequential(
            nn.Conv1d(n_mels, n_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(n_channels, n_channels, 3, padding=1)
        )

        # Input projection
        self.input_proj = nn.Conv1d(1, n_channels, 3, padding=1)

        # Residual blocks; the dilation doubles per layer and resets every dilation_cycle layers
        self.resblocks = nn.ModuleList([
            ResidualBlock(n_channels, n_channels, dilation=2 ** (i % dilation_cycle))
            for i in range(n_layers)
        ])

        # Output layer
        self.output_proj = nn.Conv1d(n_channels, 1, 3, padding=1)

    def forward(self, x, mel, t):
        """
        x: (B, 1, T) - audio or noise
        mel: (B, n_mels, T) - mel-spectrogram condition, assumed already upsampled to length T
        t: (B,) - diffusion timestep
        """
        # Encode the condition
        h_cond = self.condition_proj(mel)  # (B, C, T)

        # Encode the input
        h = self.input_proj(x)  # (B, C, T)

        # Residual blocks: carry h forward and accumulate the skip outputs
        skip_sum = 0
        for block in self.resblocks:
            h, skip = block(h, h_cond, t)
            skip_sum = skip_sum + skip

        # Output: project the normalized skip sum
        output = self.output_proj(skip_sum / math.sqrt(self.n_layers))  # (B, 1, T)

        return output


class ResidualBlock(nn.Module):
    def __init__(self, channels, cond_channels, dilation=1):
        super().__init__()
        # With kernel size 3, padding = dilation keeps the sequence length unchanged
        self.dilated_conv = nn.Conv1d(
            channels, 2 * channels, 3, padding=dilation, dilation=dilation
        )
        self.condition_proj = nn.Conv1d(cond_channels, 2 * channels, 1)

        # Output projections
        self.res_conv = nn.Conv1d(channels, channels, 1)
        self.skip_conv = nn.Conv1d(channels, channels, 1)

    def forward(self, x, condition, t):
        # Timestep embedding, broadcast over time: (B, C, 1)
        t_emb = get_timestep_embedding(t, x.shape[1])

        # Dilated convolution on the timestep-shifted input
        h = self.dilated_conv(x + t_emb)

        # Condition modulation
        h = h + self.condition_proj(condition)

        # Gated activation (the name filt avoids shadowing the builtin filter)
        gate, filt = h.chunk(2, dim=1)
        h = torch.sigmoid(gate) * torch.tanh(filt)

        # Residual and skip connections
        x = self.res_conv(h) + x
        skip = self.skip_conv(h)

        return x, skip


def get_timestep_embedding(t, dim):
    """Sinusoidal timestep embedding. t: (B,) -> (B, dim, 1); dim is assumed to be even"""
    half_dim = dim // 2
    scale = math.log(10000.0) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
    args = t[:, None].float() * freqs[None, :]
    embeddings = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, dim)
    return embeddings.unsqueeze(-1)  # (B, dim, 1)
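
The reason a vocoder like this stacks many dilated convolutions is receptive field: each output sample needs to see enough context. The sketch below works through that arithmetic, assuming the dilation doubles per layer and resets every `dilation_cycle` layers. That is a common choice for WaveNet-style stacks; the cycle length of 10 is an assumption here, not something stated in the listing above.

```python
def receptive_field(n_layers=30, kernel_size=3, dilation_cycle=10):
    """Receptive field, in samples, of a stack of dilated 1-D convolutions.

    Layer i uses dilation 2 ** (i % dilation_cycle); dilation_cycle is an assumed setting.
    """
    rf = 1
    for i in range(n_layers):
        dilation = 2 ** (i % dilation_cycle)
        rf += (kernel_size - 1) * dilation  # each layer widens the field by (k - 1) * dilation
    return rf


rf = receptive_field()
print(rf, "samples, about", round(1000 * rf / 16000, 1), "ms at 16 kHz")
```

Without dilation, the same 30 layers with kernel size 3 would cover only 61 samples.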

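4. AudioGen: Autoregressive + Diffusion Hybrid

4.1 Core Design

The design in this section generates audio in two stages. Note that the published AudioGen model is purely autoregressive over discrete audio tokens; the diffusion second stage shown here is a hybrid variant of that idea.

  1. Semantic tokens: an autoregressive model generates the semantic representation
  2. Acoustic tokens: a diffusion model refines the acoustic detail

class AudioGen(nn.Module):
    def __init__(self, semantic_vocab_size, acoustic_codebook_size,
                 n_acoustic_layers=12):
        super().__init__()

        # Semantic model (autoregressive)
        self.semantic_model = AutoregressiveModel(
            vocab_size=semantic_vocab_size,
            d_model=1024
        )

        # Acoustic model (diffusion)
        self.acoustic_model = AcousticDiffusionModel(
            codebook_size=acoustic_codebook_size,
            n_layers=n_acoustic_layers
        )

        # Decoder from semantic and acoustic tokens to a waveform
        self.semantic_decoder = SemanticDecoder()

    def forward(self, text_emb, semantic_tokens=None, acoustic_tokens=None):
        # Stage 1: generate semantic tokens autoregressively
        if semantic_tokens is None:
            semantic_tokens = self.generate_semantic(text_emb)

        # Stage 2: generate acoustic tokens by diffusion, conditioned on the semantic tokens
        # (the tokens must be embedded to the acoustic model's input width before or inside sample())
        if acoustic_tokens is None:
            acoustic_tokens = self.acoustic_model.sample(
                condition=semantic_tokens
            )

        # Decode to a waveform
        waveform = self.semantic_decoder(semantic_tokens, acoustic_tokens)

        return waveform

    def generate_semantic(self, text_emb, max_len=600):
        """Generate semantic tokens autoregressively (greedy decoding, batch size 1)"""
        # Token id 0 doubles as the start token and as EOS in this sketch
        tokens = torch.zeros(1, 1, dtype=torch.long, device=text_emb.device)

        for _ in range(max_len):
            logits = self.semantic_model(tokens, text_emb)
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)

            if (next_token == 0).all():  # EOS
                break

        return tokens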
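
The greedy loop above can be tried without any neural network. In this minimal sketch, `next_token_scores` is a hypothetical callable standing in for the semantic model, and the start and end tokens get separate ids, which avoids the ambiguity of reusing id 0 for both.

```python
def greedy_decode(next_token_scores, bos_id, eos_id, max_len=600):
    """Greedy autoregressive decoding.

    next_token_scores(tokens) is a hypothetical callable that returns one score
    per vocabulary entry for the token that follows `tokens`.
    """
    tokens = [bos_id]
    for _ in range(max_len):
        scores = next_token_scores(tokens)
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens


def toy_scores(tokens):
    """Toy scorer over a 5-token vocabulary: emits 2, 3, 4, then EOS (id 1)."""
    last = tokens[-1]
    if last == 0:
        target = 2
    elif 2 <= last < 4:
        target = last + 1
    else:
        target = 1
    return [1.0 if i == target else 0.0 for i in range(5)]


print(greedy_decode(toy_scores, bos_id=0, eos_id=1))
```

Swapping the argmax for sampling from a temperature-scaled softmax gives more varied output at the cost of determinism.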

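4.2 Acoustic Diffusion Model

import math

import torch
import torch.nn as nn


class AcousticDiffusionModel(nn.Module):
    def __init__(self, codebook_size, n_layers=12, d_model=1024, train_steps=1000):
        super().__init__()

        self.d_model = d_model
        self.train_steps = train_steps

        # Linear noise schedule, needed by the sampler
        betas = torch.linspace(1e-4, 0.02, train_steps)
        self.register_buffer('alphas_cumprod', torch.cumprod(1.0 - betas, dim=0))

        # Timestep embedding (256-dim sinusoidal features -> d_model)
        self.time_emb = nn.Sequential(
            nn.Linear(256, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model)
        )

        # Transformer layers (TransformerLayer is assumed to be defined elsewhere)
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads=16)
            for _ in range(n_layers)
        ])

        # Noise-prediction head for the diffusion loop, and logits head for the final tokens
        self.noise_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, codebook_size)

    @staticmethod
    def get_timestep_embedding(t, dim=256):
        """Sinusoidal features for integer timesteps. t: (B,) -> (B, dim)"""
        half = dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / (half - 1))
        args = t[:, None].float() * freqs[None, :]
        return torch.cat([args.sin(), args.cos()], dim=-1)

    def forward(self, x, t, condition):
        """Predict the noise in x. x, condition: (B, T, d_model); t: (B,)"""
        # Timestep embedding, broadcast over the sequence
        t_emb = self.time_emb(self.get_timestep_embedding(t))  # (B, d_model)
        h = x + t_emb[:, None, :]

        # Condition injection (condition is assumed to be embedded to d_model already)
        if condition is not None:
            h = h + condition

        # Transformer stack
        for layer in self.layers:
            h = layer(h, condition)

        return self.noise_proj(h)

    @torch.no_grad()
    def sample(self, condition, n_steps=50):
        """Deterministic DDIM-style sampling over a strided subset of timesteps"""
        B, T = condition.shape[:2]

        # Start from pure noise
        x = torch.randn(B, T, self.d_model, device=condition.device)
        steps = torch.linspace(self.train_steps - 1, 0, n_steps).long().tolist()

        for i, t in enumerate(steps):
            t_batch = torch.full((B,), t, device=x.device, dtype=torch.long)
            eps = self.forward(x, t_batch, condition)

            # Estimate the clean latent, then move to the next (smaller) timestep
            a_t = self.alphas_cumprod[t]
            x0 = (x - (1 - a_t).sqrt() * eps) / a_t.sqrt()
            if i + 1 < len(steps):
                a_prev = self.alphas_cumprod[steps[i + 1]]
                x = a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps
            else:
                x = x0

        # Map the denoised latents to codebook logits and pick the most likely token
        return self.output_proj(x).argmax(dim=-1)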
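
Sampling with `n_steps=50` from a model trained with many more diffusion steps means visiting only a subset of the training timesteps. One common way to pick that subset is even spacing, sketched below; the function name and the default of 1000 training steps are assumptions for illustration.

```python
def sampling_timesteps(train_steps=1000, n_steps=50):
    """Evenly spaced timesteps from train_steps - 1 down to 0, both included."""
    if not 1 <= n_steps <= train_steps:
        raise ValueError("n_steps must be between 1 and train_steps")
    if n_steps == 1:
        return [train_steps - 1]
    stride = (train_steps - 1) / (n_steps - 1)
    return [round(i * stride) for i in reversed(range(n_steps))]


steps = sampling_timesteps()
print(steps[:3], "...", steps[-3:])
```

Each sampler iteration then jumps from one listed timestep to the next, using the cumulative noise level of the destination step rather than of `t - 1`.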

5. MusicGen: Music Generation

5.1 Core Architecture

Meta's MusicGen uses a multi-codebook language-model architecture:

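import torch
import torch.nn as nn
import torch.nn.functional as F


class MusicGen(nn.Module):
    def __init__(self, vocab_size=2048, n_codebooks=4, d_model=2048, text_dim=768):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_codebooks = n_codebooks

        # Text encoder (T5Encoder is a placeholder; text_dim is its assumed output width)
        self.text_encoder = T5Encoder()
        self.text_proj = nn.Linear(text_dim, d_model)

        # One embedding table per codebook; index vocab_size is reserved as the start token
        self.token_emb = nn.ModuleList([
            nn.Embedding(vocab_size + 1, d_model)
            for _ in range(n_codebooks)
        ])

        # Transformer decoder with cross-attention to the text
        layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=16,
            dim_feedforward=4 * d_model,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=48)

        # Output heads (one per codebook)
        self.codebook_heads = nn.ModuleList([
            nn.Linear(d_model, vocab_size)
            for _ in range(n_codebooks)
        ])

    def decode(self, tokens, text_emb):
        """tokens: (B, n_codebooks, T) -> list of logits, each (B, T, vocab_size)"""
        # Sum the per-codebook embeddings at every time step
        h = sum(emb(tokens[:, i]) for i, emb in enumerate(self.token_emb))

        # Causal mask: position t may only attend to positions <= t
        T = tokens.size(-1)
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=tokens.device), diagonal=1
        )
        output = self.decoder(h, text_emb, tgt_mask=causal_mask)

        return [head(output) for head in self.codebook_heads]

    def forward(self, text, audio_tokens):
        # 1. Encode the text and project it to the decoder width
        text_emb = self.text_proj(self.text_encoder(text))  # (B, T_text, D)

        # 2. Decode with teacher forcing; one logits tensor per codebook
        return self.decode(audio_tokens, text_emb)

    @torch.no_grad()
    def generate(self, text, max_len=1500, temperature=1.0):
        """Autoregressive generation; returns tokens of shape (B, n_codebooks, max_len)"""
        text_emb = self.text_proj(self.text_encoder(text))

        # Initialize every codebook stream with the start token
        B = text_emb.size(0)
        generated = torch.full(
            (B, self.n_codebooks, 1), self.vocab_size,
            dtype=torch.long, device=text_emb.device
        )

        for _ in range(max_len):
            logits_list = self.decode(generated, text_emb)

            # Sample one token per codebook from the last time step
            tokens_t = []
            for logits in logits_list:
                probs = F.softmax(logits[:, -1] / temperature, dim=-1)
                tokens_t.append(torch.multinomial(probs, 1))  # (B, 1)

            # Pack all codebooks and append along the time axis
            token_t = torch.stack(tokens_t, dim=1)  # (B, n_codebooks, 1)
            generated = torch.cat([generated, token_t], dim=-1)

        return generated[:, :, 1:]  # drop the start token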
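
The generation loop above samples all codebooks of a frame at the same step. MusicGen's paper discusses codebook interleaving patterns for this situation; the sketch below illustrates the "delay" pattern, in which codebook k is shifted right by k steps so that each codebook is predicted after the coarser ones for the same frame. It uses plain Python lists, and the function names and pad value are hypothetical.

```python
def apply_delay_pattern(codes, pad_id=-1):
    """Shift codebook k right by k steps.

    codes: K lists of length T -> K lists of length T + K - 1, padded with pad_id.
    """
    K = len(codes)
    return [[pad_id] * k + list(row) + [pad_id] * (K - 1 - k) for k, row in enumerate(codes)]


def undo_delay_pattern(delayed):
    """Inverse of apply_delay_pattern: realign the codebooks frame by frame."""
    K = len(delayed)
    T = len(delayed[0]) - (K - 1)
    return [row[k:k + T] for k, row in enumerate(delayed)]


codes = [[1, 2, 3], [4, 5, 6]]
delayed = apply_delay_pattern(codes)
print(delayed)
print(undo_delay_pattern(delayed))
```

With this layout, a sequence of T frames and K codebooks takes T + K - 1 decoding steps instead of T * K for fully flattened decoding.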

5.2 Training Objective

def musicgen_loss(model, text, audio_tokens):
    """
    Training loss for MusicGen
    audio_tokens: (B, n_codebooks, T)
    """
    # Forward pass
    logits_list = model(text, audio_tokens[:, :, :-1])  # teacher forcing

    # Cross-entropy loss for each codebook
    total_loss = 0
    for i, logits in enumerate(logits_list):
        target = audio_tokens[:, i, 1:]  # target tokens, shifted by one step
        loss = F.cross_entropy(
            logits.reshape(-1, model.vocab_size),
            target.reshape(-1)
        )
        total_loss += loss

    # Average over codebooks
    return total_loss / len(logits_list)
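
A useful sanity check for this loss is the value it takes when the model knows nothing: with uniform predictions over a codebook of size V, the cross-entropy is ln V. The sketch below verifies this in plain Python for V = 2048, the default `vocab_size` above; the helper is hypothetical and computes the loss for a single prediction.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy in nats of one prediction: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_norm - logits[target]


vocab_size = 2048
uniform_loss = cross_entropy([0.0] * vocab_size, target=5)
print(f"{uniform_loss:.3f} vs ln(V) = {math.log(vocab_size):.3f}")
```

A freshly initialized model should start near this value, and a training loss that stays there suggests the conditioning or the targets are not wired up correctly.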

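6. Stable Audio

6.1 Open Audio Generation

Stability AI's Stable Audio is a text-conditioned latent diffusion model. The simplified sketch below does not reproduce that design; it instead follows a MusicGen-like, token-based layout: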

from transformers import T5EncoderModel


class StableAudio(nn.Module):
    def __init__(self):
        super().__init__()

        # Text encoder (encoder-only T5)
        self.text_encoder = T5EncoderModel.from_pretrained('google/t5-v1_1-small')
        text_dim = self.text_encoder.config.d_model

        # Condition encoder (ResBlock is a placeholder module; torch.nn has no ResBlock)
        self.condition_encoder = nn.Sequential(
            nn.Conv1d(text_dim, 1536, 3, padding=1),
            ResBlock(1536, [3, 5]),
            ResBlock(1536, [3, 5]),
            ResBlock(1536, [3, 5]),
        )

        # Causal Transformer (placeholder, assumed to be defined elsewhere)
        self.transformer = CausalTransformer(
            d_model=1536,
            nhead=24,
            num_layers=36,
            max_len=4000
        )

        # Output head (placeholder, assumed to be defined elsewhere)
        self.output_head = AudioOutputHead(
            codebook_size=2048,
            n_codebooks=8
        )

    def forward(self, input_ids, audio_tokens=None):
        # Text encoding; input_ids are tokenized prompt ids of shape (B, T_text)
        text_emb = self.text_encoder(input_ids=input_ids).last_hidden_state  # (B, T_text, text_dim)

        # Condition encoding
        cond = text_emb.transpose(1, 2)  # (B, text_dim, T_text)
        cond = self.condition_encoder(cond).transpose(1, 2)  # (B, T_text, 1536)

        # Audio generation
        output = self.transformer(audio_tokens, cond)

        # Output logits
        logits = self.output_head(output)

        return logits

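7. Techniques Specific to Audio Diffusion

7.1 Frequency-Domain Diffusion

class FrequencyDiffusion(nn.Module):
    """Diffusion in the frequency domain"""
    def __init__(self, diffusion_steps=1000, n_fft=2048, hop_length=512):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.denoiser = FrequencyUnet()  # assumed to be defined elsewhere

        # Linear noise schedule and the coefficients derived from it
        alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, diffusion_steps), dim=0)
        self.register_buffer('sqrt_alphas_cumprod', alphas_cumprod.sqrt())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', (1.0 - alphas_cumprod).sqrt())
        self.register_buffer('window', torch.hann_window(n_fft))

    def forward_diffusion(self, waveform, t):
        """Add noise in the frequency domain. waveform: (B, 1, L), t: (B,)"""
        # STFT
        stft = torch.stft(
            waveform.squeeze(1),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            return_complex=True
        )  # (B, F, T)

        # Add noise to the magnitude spectrum
        magnitude = stft.abs()
        noise = torch.randn_like(magnitude)
        a = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1)
        s = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1)
        magnitude_noisy = a * magnitude + s * noise

        # Keep the phase unchanged
        phase = stft.angle()

        return magnitude_noisy, phase

    def forward(self, magnitude_noisy, t, condition):
        """Denoise: predict the noise, then estimate the clean magnitude"""
        noise_pred = self.denoiser(magnitude_noisy, t, condition)

        # Recover the original magnitude
        a = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1)
        s = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1)
        x0_pred = (magnitude_noisy - s * noise_pred) / a

        return x0_pred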
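
To see how much the STFT shortens the time axis, the sketch below computes the spectrogram shape for the `n_fft=2048`, `hop_length=512` settings used above. The frame-count formulas follow the conventions documented for `torch.stft`; the helper itself is hypothetical.

```python
def stft_shape(n_samples, n_fft=2048, hop_length=512, center=True):
    """(frequency bins, frames) of a one-sided STFT, following torch.stft's conventions."""
    n_bins = n_fft // 2 + 1
    if center:
        # The signal is padded on both sides, so every hop produces a frame
        n_frames = 1 + n_samples // hop_length
    else:
        n_frames = 1 + (n_samples - n_fft) // hop_length
    return n_bins, n_frames


print(stft_shape(16000))           # one second at 16 kHz
print(stft_shape(48000 * 10))      # ten seconds at 48 kHz
```

The time axis shrinks by roughly the hop length, while the frequency axis adds a second dimension that 2-D architectures can exploit.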

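7.2 Hierarchical Generation

class HierarchicalAudioDiffusion(nn.Module):
    """Hierarchical audio generation"""
    def __init__(self):
        super().__init__()

        # Global level (semantic)
        self.global_model = GlobalDiffusionModel()

        # Local level (acoustic)
        self.local_model = LocalDiffusionModel()

        # Decoder from the refined latent to a waveform (a placeholder, like the two models above)
        self.waveform_decoder = WaveformDecoder()

    def generate(self, text, n_local_steps=50):
        # 1. Global level: generate the semantic representation
        global_latent = self.global_model.sample(
            condition=text,
            n_steps=100
        )  # low resolution

        # 2. Local level: refine over n_local_steps successive passes
        local_latent = global_latent
        for _ in range(n_local_steps):
            local_latent = self.local_model.sample(
                condition=local_latent,
                guidance=text,
                n_steps=10
            )

        # 3. Decode to a waveform
        waveform = self.waveform_decoder(local_latent)

        return waveform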
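
Hierarchical sampling multiplies the number of network evaluations, which matters for latency. The following sketch adds up the denoiser calls implied by the step counts in `generate`. The per-call time in the usage line is a made-up figure for illustration only.

```python
def denoiser_calls(global_steps=100, n_local_stages=50, steps_per_stage=10):
    """Total network evaluations for one global pass plus repeated local refinement."""
    return global_steps + n_local_stages * steps_per_stage


calls = denoiser_calls()
assumed_ms_per_call = 20  # hypothetical cost of one forward pass
print(calls, "calls, about", calls * assumed_ms_per_call / 1000, "s at the assumed per-call cost")
```

Reducing the number of local stages or the steps per stage is therefore the first lever to pull when generation is too slow.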

8. Practical Guide

8.1 Using Hugging Face Diffusers

from diffusers import AudioLDMPipeline
import scipy.io.wavfile as wav
import torch

# Load the model (AudioLDMPipeline expects an AudioLDM checkpoint)
pipe = AudioLDMPipeline.from_pretrained(
    "cvssp/audioldm-s-full-v2",
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# Generate
prompt = "Techno music with a strong, unwavering beat and synth"
audio = pipe(
    prompt,
    num_inference_steps=50,
    audio_length_in_s=8.0
).audios[0]

# Save (AudioLDM produces 16 kHz audio)
wav.write("techno_music.wav", 16000, audio)
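
The pipeline returns floating-point samples, and `scipy.io.wavfile.write` stores a float array as a floating-point WAV file. If 16-bit PCM is needed instead, the samples have to be clipped and scaled first. A minimal standard-library sketch follows; both helper names are hypothetical.

```python
import struct
import wave


def float_to_pcm16(samples):
    """Convert floats in [-1.0, 1.0] to little-endian 16-bit PCM bytes, clipping out-of-range values."""
    ints = [int(round(max(-1.0, min(1.0, s)) * 32767)) for s in samples]
    return struct.pack(f"<{len(ints)}h", *ints)


def write_wav_pcm16(path, samples, sample_rate=16000):
    """Write mono 16-bit PCM audio using the standard-library wave module."""
    with wave.open(path, "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)  # 2 bytes = 16 bits per sample
        f.setframerate(sample_rate)
        f.writeframes(float_to_pcm16(samples))
```

For example, `write_wav_pcm16("techno_music_pcm16.wav", audio.tolist())` would save the generated clip, assuming `audio` is the 1-D NumPy array returned above.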

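8.2 Evaluating Audio Quality

import numpy as np


def evaluate_audio_generation(generated, ground_truth=None, sr=16000):
    """Evaluate the quality of generated audio (inputs are 1-D NumPy arrays)"""
    metrics = {}

    # 1. Fréchet Audio Distance (FAD)
    # Requires a pretrained audio embedding model
    # metrics['FAD'] = compute_fad(generated, ground_truth)

    if ground_truth is not None:
        # 2. KL divergence (compute_kl_divergence is assumed to be defined elsewhere)
        metrics['KL_Div'] = compute_kl_divergence(generated, ground_truth)

        # 3. Signal-level metrics; these assume time-aligned signals of equal length
        # Signal-to-noise ratio in dB
        metrics['SNR'] = 10 * np.log10(
            np.sum(ground_truth**2) / np.sum((generated - ground_truth)**2)
        )

        # Correlation
        metrics['Corr'] = np.corrcoef(generated, ground_truth)[0, 1]

    # 4. Subjective metrics (require human evaluation)
    # metrics['MOS'] = collect_mos_scores(generated)

    return metrics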
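
FAD is left as a comment above because it needs a pretrained embedding model. The distance itself is the Fréchet distance between two Gaussians fitted to the embeddings of reference and generated audio. The sketch below shows that formula for the simplified case of diagonal covariances; real FAD implementations use full covariance matrices, so this only illustrates the idea.

```python
import math


def frechet_distance_diag(mu1, var1, mu2, var2):
    """Fréchet distance between two Gaussians with diagonal covariances.

    With diagonal covariances, |mu1 - mu2|^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2))
    reduces to a sum over dimensions.
    """
    return sum(
        (m1 - m2) ** 2 + v1 + v2 - 2.0 * math.sqrt(v1 * v2)
        for m1, v1, m2, v2 in zip(mu1, var1, mu2, var2)
    )


# Identical statistics give zero distance; shifting the mean or variance increases it
print(frechet_distance_diag([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0]))
print(frechet_distance_diag([0.0], [1.0], [3.0], [4.0]))
```

Lower values mean the generated set is statistically closer to the reference set in embedding space.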

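9. Summary

Key Points

  1. Audio diffusion models face unique challenges in sequence length and in two-scale (semantic and acoustic) modeling
  2. DiffWave pioneered waveform-level diffusion vocoders
  3. AudioGen and MusicGen generate discrete audio tokens autoregressively; hybrid designs add a diffusion stage to refine acoustic detail
  4. Hierarchical generation and frequency-domain processing are common optimization strategies

Future Directions

  • Longer audio generation: minute-scale music and speech
  • Real-time generation: distillation and speculative decoding
  • Unified models: one model for speech, music, and sound effects
  • Controllable generation: precise control over style, emotion, and instrumentation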
