GAN Evaluation Metrics

Evaluating GANs is challenging because generative models do not expose an explicit likelihood. This document surveys the mainstream GAN evaluation metrics and their implementations.

1. Inception Score (IS)

1.1 Core Idea

IS [1] uses a pretrained Inception model to assess both the quality and the diversity of generated samples:

IS(G) = exp( E_{x ~ p_g} [ KL( p(y|x) || p(y) ) ] )

where:

  • p(y|x): the Inception model's class prediction distribution for a generated sample x
  • p(y) = E_{x ~ p_g}[ p(y|x) ]: the marginal class distribution

1.2 Intuition

| Component | Requirement | Implication |
|---|---|---|
| p(y|x) has low entropy | each generated sample belongs to a clear class | high quality |
| p(y) has high entropy | all classes are covered evenly | good diversity |

Ideal state:

  • Per sample: p(y|x) is close to one-hot (low entropy)
  • Overall: p(y) is close to uniform (high entropy)

1.3 Implementation

import torch
import torch.nn.functional as F
from torchvision.models import inception_v3
from scipy import linalg
import numpy as np
 
class InceptionScore:
    """Computes the Inception Score."""
    def __init__(self, device='cuda'):
        self.device = device
        # Load a pretrained Inception v3 model
        self.inception = inception_v3(pretrained=True, transform_input=False)
        self.inception.eval()
        self.inception.to(device)
    
    @torch.no_grad()
    def get_predictions(self, images, batch_size=50):
        """
        Run the Inception model over the images.
        
        Args:
            images: (N, 3, 299, 299) tensor
        Returns:
            predicted class probabilities (N, 1000)
        """
        preds = []
        
        for i in range(0, len(images), batch_size):
            batch = images[i:i+batch_size].to(self.device)
            pred = self.inception(batch)
            # Softmax over the 1000 ImageNet classes
            pred = F.softmax(pred, dim=-1)
            preds.append(pred.cpu())
        
        return torch.cat(preds, dim=0)
    
    def compute_is(self, images, splits=10):
        """
        Compute the Inception Score.
        
        Args:
            images: generated images (N, 3, H, W)
            splits: number of splits (used to report a mean and std)
        Returns:
            IS mean, IS standard deviation
        """
        preds = self.get_predictions(images)  # (N, 1000)
        eps = 1e-10  # avoid log(0)
        
        split_scores = []
        split_size = preds.size(0) // splits
        
        for i in range(splits):
            part = preds[i * split_size:(i + 1) * split_size]
            # Marginal distribution p(y), estimated within this split
            py = part.mean(dim=0)  # (1000,)
            # Per-sample KL(p(y|x) || p(y))
            kl = part * (torch.log(part + eps) - torch.log(py.unsqueeze(0) + eps))
            kl = kl.sum(dim=1)  # (split_size,)
            split_scores.append(torch.exp(kl.mean()))
        
        scores = torch.stack(split_scores)
        return scores.mean().item(), scores.std().item()

1.4 Limitations of IS

| Problem | Explanation |
|---|---|
| Insensitive to intra-class variation | cannot detect intra-class mode collapse |
| Depends on the Inception model | may fail under domain shift |
| Ignores the real distribution | does not measure similarity to real data |
| Can be gamed adversarially | generators exist that exploit IS |

2. Fréchet Inception Distance (FID)

2.1 Core Idea

FID [2] measures the difference between the generated and real sample distributions in a feature space:

FID = || μ_r − μ_g ||² + Tr( Σ_r + Σ_g − 2 (Σ_r Σ_g)^{1/2} )

where:

  • μ_r, μ_g: means of the real / generated features
  • Σ_r, Σ_g: covariance matrices of the real / generated features

Core assumption: the Inception features of each distribution are approximately multivariate Gaussian.

2.2 Intuition

| Case | FID value | Meaning |
|---|---|---|
| identical distributions | 0 | ideal state |
| distribution shift | > 0 | larger means bigger shift |
| poor generation quality | high | blur / mode collapse |
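
For one-dimensional Gaussians the Fréchet distance has a simple closed form, which makes the table above concrete. A minimal sketch (the helper `fid_1d` is illustrative, not a library function):

```python
def fid_1d(mu1, sigma1, mu2, sigma2):
    """For 1-D Gaussians the general FID formula
    (mu1 - mu2)^2 + sigma1^2 + sigma2^2 - 2*sigma1*sigma2
    collapses to (mu1 - mu2)^2 + (sigma1 - sigma2)^2."""
    return (mu1 - mu2) ** 2 + (sigma1 - sigma2) ** 2

print(fid_1d(0.0, 1.0, 0.0, 1.0))  # identical distributions -> 0.0
print(fid_1d(0.0, 1.0, 3.0, 1.0))  # mean shift of 3 -> 9.0
print(fid_1d(0.0, 1.0, 0.0, 0.2))  # variance collapse (mode collapse) -> 0.64
```

The last line shows why FID reacts to mode collapse: even with matched means, a generator whose samples are too concentrated shrinks σ and the (σ₁ − σ₂)² term grows.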

2.3 Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy import linalg
from torchvision.models import inception_v3
 
def calculate_fid(real_features, fake_features, eps=1e-6):
    """
    Compute the FID between two feature sets.
    
    Args:
        real_features: real-sample features (N, D)
        fake_features: generated-sample features (N, D)
    Returns:
        FID score
    """
    # Convert to numpy
    if isinstance(real_features, torch.Tensor):
        real_features = real_features.cpu().numpy()
    if isinstance(fake_features, torch.Tensor):
        fake_features = fake_features.cpu().numpy()
    
    # Means
    mu_real = np.mean(real_features, axis=0)
    mu_fake = np.mean(fake_features, axis=0)
    
    # Covariances
    sigma_real = np.cov(real_features, rowvar=False)
    sigma_fake = np.cov(fake_features, rowvar=False)
    
    # Matrix square root of the covariance product
    covmean, _ = linalg.sqrtm(sigma_real @ sigma_fake, disp=False)
    
    # If the product is near-singular, retry with a small diagonal offset
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma_real + offset) @ (sigma_fake + offset),
                                  disp=False)
    
    # Discard tiny imaginary components caused by numerical error
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    # FID formula
    diff = mu_real - mu_fake
    fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
    
    return fid
 
 
class FIDCalculator:
    """FID calculator based on Inception v3 pooled features."""
    def __init__(self, device='cuda'):
        self.device = device
        self.inception = inception_v3(pretrained=True, transform_input=False)
        # Replace the classifier head with an identity so the forward pass
        # returns the 2048-dim pooled features instead of class logits
        self.inception.fc = nn.Identity()
        self.inception.eval()
        self.inception.to(device)
    
    @torch.no_grad()
    def extract_features(self, images, batch_size=50):
        """Extract 2048-dim Inception features."""
        features = []
        
        for i in range(0, len(images), batch_size):
            batch = images[i:i+batch_size].to(self.device)
            feat = self.inception(batch)  # (B, 2048)
            features.append(feat.cpu())
        
        return torch.cat(features, dim=0)
    
    def compute_fid(self, real_images, fake_images):
        """Compute FID between real and generated images."""
        real_feats = self.extract_features(real_images)
        fake_feats = self.extract_features(fake_images)
        return calculate_fid(real_feats.numpy(), fake_feats.numpy())

2.4 Strengths and Limitations of FID

| Strength | Limitation |
|---|---|
| compares generated and real distributions | depends on a pretrained Inception model |
| sensitive to mode collapse | needs a sufficiently large sample size |
| widely used and accepted | assumes Gaussian feature distributions |
| continuous and differentiable | does not capture perceptual quality |

3. Precision, Recall, and F1

3.1 Motivation

IS and FID cannot disentangle quality problems from diversity problems. The precision-recall framework provides a finer-grained evaluation.

3.2 Definitions

| Metric | Definition | Ideal value |
|---|---|---|
| Precision | fraction of generated samples that lie close to the real distribution | 1 |
| Recall | fraction of real samples the generator can reproduce | 1 |
| F1 | harmonic mean of Precision and Recall | 1 |
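
A toy sketch of how manifold-based precision and recall behave under mode collapse, using plain NumPy (the helpers `knn_radii` and `in_manifold` are illustrative simplifications of a k-NN manifold estimate, not library functions):

```python
import numpy as np

def knn_radii(X, k):
    """Distance from each point in X to its k-th nearest neighbor."""
    d = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    d.sort(axis=1)
    return d[:, k]  # column 0 is each point's distance to itself

def in_manifold(reference, test, k=3):
    """Fraction of test points inside the k-NN ball of some reference point."""
    radii = knn_radii(reference, k)
    d = np.linalg.norm(test[:, None, :] - reference[None, :, :], axis=-1)
    return np.mean((d <= radii[None, :]).any(axis=1))

rng = np.random.default_rng(0)
real = np.concatenate([rng.normal(0, 1, (100, 2)),   # mode A
                       rng.normal(8, 1, (100, 2))])  # mode B
fake = rng.normal(0, 1, (200, 2))                    # covers only mode A

precision = in_manifold(real, fake)  # fake points lie near real data -> high
recall = in_manifold(fake, real)     # only half the real data is covered -> ~0.5
print(precision, recall)
```

The collapsed generator produces realistic samples (high precision) but misses mode B entirely, so recall drops to roughly the fraction of modes it covers.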

3.3 Implementation

from sklearn.neighbors import NearestNeighbors
import numpy as np
 
class PrecisionRecall:
    """k-NN-manifold Precision and Recall (in the spirit of Kynkäänniemi et al., 2019)."""
    def __init__(self, k=5):
        self.k = k  # number of nearest neighbors
    
    def _knn_radii(self, features):
        """Distance from each point to its k-th nearest neighbor (excluding itself)."""
        nn = NearestNeighbors(n_neighbors=self.k + 1, metric='euclidean')
        nn.fit(features)
        distances, _ = nn.kneighbors(features)
        return distances[:, -1]
    
    def _fraction_in_manifold(self, reference, test):
        """
        Fraction of test points that fall inside the k-NN ball of their
        nearest reference point (a cheap approximation of checking all balls).
        """
        radii = self._knn_radii(reference)
        nn = NearestNeighbors(n_neighbors=1, metric='euclidean')
        nn.fit(reference)
        dist, idx = nn.kneighbors(test)
        return float(np.mean(dist[:, 0] <= radii[idx[:, 0]]))
    
    def compute_pr(self, real_features, fake_features, manifold=None):
        """
        Compute Precision and/or Recall.
        
        Args:
            manifold: 'real' returns Precision only, 'fake' returns Recall
                      only, None returns (precision, recall)
        """
        # Precision: how many generated samples lie on the real manifold
        precision = self._fraction_in_manifold(real_features, fake_features)
        # Recall: how many real samples lie on the generated manifold
        recall = self._fraction_in_manifold(fake_features, real_features)
        
        if manifold == 'real':
            return precision
        if manifold == 'fake':
            return recall
        return precision, recall
 
 
class ModeScore:
    """Mode Score: an IS variant that also penalizes drift from the real label distribution."""
    def compute_mode_score(self, real_preds, fake_preds, eps=1e-10):
        """
        Mode Score = exp( E_x[KL(p(y|x) || p(y*))] - KL(p(y) || p(y*)) ),
        where p(y*) is the real marginal and p(y) the generated marginal.
        """
        # p(y*): marginal label distribution of the real data
        py_star = real_preds.mean(dim=0)
        # p(y): marginal label distribution of the generated data
        py = fake_preds.mean(dim=0)
        
        # E_x[ KL(p(y|x) || p(y*)) ] over generated samples
        kl_cond = (fake_preds * (torch.log(fake_preds + eps)
                                 - torch.log(py_star.unsqueeze(0) + eps))).sum(dim=1).mean()
        # KL(p(y) || p(y*))
        kl_marg = (py * (torch.log(py + eps) - torch.log(py_star + eps))).sum()
        
        mode_score = torch.exp(kl_cond - kl_marg)
        
        return mode_score.item()

4. Perceptual Path Length (PPL)

4.1 Core Idea

PPL [3] measures the smoothness of the latent space.

A well-behaved latent space supports smooth interpolation: the generated image should change gradually as the latent vector moves along a path.

4.2 Definition

PPL = E_{z1, z2, t} [ (1/ε²) · d( G(lerp(z1, z2; t)), G(lerp(z1, z2; t + ε)) ) ]

where:

  • lerp(z1, z2; t): an interpolation path between latent codes z1 and z2, sampled at position t with a small step ε
  • d(·, ·): a perceptual distance (e.g. LPIPS) between the generated images
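
To make the 1/ε² normalization concrete, here is a toy sketch in which the "generator" is a fixed linear map and the "perceptual distance" is squared L2 (both are stand-ins for a real generator and LPIPS):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 4))  # toy "generator": a fixed linear map z -> image

def g(z):
    return z @ W.T

def d(a, b):
    # stand-in "perceptual" distance: squared L2 between images
    return np.sum((a - b) ** 2, axis=-1)

def toy_ppl(num_paths=1000, epsilon=1e-2):
    z1 = rng.normal(size=(num_paths, 4))
    z2 = rng.normal(size=(num_paths, 4))
    t = rng.uniform(size=(num_paths, 1))
    zt = z1 + t * (z2 - z1)                  # lerp(z1, z2; t)
    zt_eps = z1 + (t + epsilon) * (z2 - z1)  # lerp(z1, z2; t + epsilon)
    return np.mean(d(g(zt), g(zt_eps)) / epsilon ** 2)

# For a linear generator the squared step distance is exactly
# epsilon^2 * |W(z2 - z1)|^2, so the epsilon^2-normalized value neither
# blows up nor vanishes as epsilon shrinks:
print(toy_ppl(epsilon=1e-2), toy_ppl(epsilon=1e-4))
```

For a perfectly smooth (here, linear) generator the normalized distance is independent of the step size; a generator with abrupt transitions in latent space would instead produce large, ε-dependent spikes, which is what PPL is designed to expose.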

4.3 Implementation

import lpips
import torch
import numpy as np
 
class PerceptualPathLength:
    """Perceptual Path Length."""
    def __init__(self, net='vgg', device='cuda'):
        # LPIPS perceptual distance
        self.loss_fn = lpips.LPIPS(net=net).to(device)
        self.device = device
    
    def compute_ppl(self, generator, num_samples=10000, epsilon=1e-4):
        """
        Compute PPL.
        
        Args:
            generator: generator model (expects z of shape (n, latent_dim))
            num_samples: number of sampled paths
            epsilon: interpolation step size
        """
        ppls = []
        batch_size = 100
        
        for i in range(0, num_samples, batch_size):
            n = min(batch_size, num_samples - i)
            
            # Sample two latent endpoints and a random position on the path
            z1 = torch.randn(n, generator.latent_dim, device=self.device)
            z2 = torch.randn(n, generator.latent_dim, device=self.device)
            t = torch.rand(n, 1, device=self.device)
            
            # Two nearby points on the linear interpolation path
            zt = torch.lerp(z1, z2, t)
            zt_eps = torch.lerp(z1, z2, t + epsilon)
            
            with torch.no_grad():
                img_a = generator(zt)
                img_b = generator(zt_eps)
                # Perceptual distance, normalized by epsilon^2
                dist = self.loss_fn(img_a, img_b).squeeze() / (epsilon ** 2)
            
            ppls.append(dist.mean().item())
        
        return np.mean(ppls)

5. Metric Comparison

| Metric | Measures | Range | Ideal value |
|---|---|---|---|
| IS | quality + diversity | [1, #classes] | high |
| FID | distance to the real distribution | [0, ∞) | low (0 is best) |
| Precision | sample quality | [0, 1] | high |
| Recall | mode coverage | [0, 1] | high |
| PPL | latent-space smoothness | [0, ∞) | low |

6. Practical Evaluation Guidelines

6.1 Recommended Metric Combination

def comprehensive_gan_eval(real_images, fake_images, generator):
    """
    Evaluate a GAN with a combination of metrics.
    """
    results = {}
    
    # 1. FID (primary metric)
    fid_calc = FIDCalculator()
    results['fid'] = fid_calc.compute_fid(real_images, fake_images)
    
    # 2. IS
    is_calc = InceptionScore()
    results['is_mean'], results['is_std'] = is_calc.compute_is(fake_images)
    
    # 3. Precision & Recall
    pr_calc = PrecisionRecall()
    real_feats = fid_calc.extract_features(real_images).numpy()
    fake_feats = fid_calc.extract_features(fake_images).numpy()
    results['precision'] = pr_calc.compute_pr(real_feats, fake_feats, manifold='real')
    results['recall'] = pr_calc.compute_pr(real_feats, fake_feats, manifold='fake')
    
    # 4. PPL
    ppl_calc = PerceptualPathLength()
    results['ppl'] = ppl_calc.compute_ppl(generator)
    
    return results

6.2 Evaluation Caveats

| Caveat | Explanation |
|---|---|
| Sample size | for FID, use at least 5,000 real/generated pairs |
| Preprocessing | apply identical preprocessing to real and generated images |
| Repeatability | average over multiple evaluation runs |
| Metric combination | no single metric is sufficient on its own |

7. Recent Developments (2024): Limitations of FID

7.1 Criticisms of FID

Research presented at CVPR 2024 points out several limitations of FID:

  1. Dependence on a pretrained model: Inception features may not be the best choice
  2. Blindness to specific artifacts: local distortions can be masked by a global statistic
  3. Mismatch with perception: a low FID does not guarantee high perceptual quality

7.2 Alternatives

| Method | Improvement |
|---|---|
| DISTS | VGG-based perceptual metric |
| LPIPS | learned perceptual similarity |
| SSIM / MS-SSIM | structural similarity |
| Human Preference Score | direct human evaluation |

8. References

  1. Salimans T, Goodfellow I, Zaremba W, et al. Improved techniques for training GANs. NeurIPS, 2016.

  2. Heusel M, Ramsauer H, Unterthiner T, et al. GANs trained by a two time-scale update rule converge to a local Nash equilibrium. NeurIPS, 2017.

  3. Karras T, Laine S, Aila T, et al. A style-based generator architecture for generative adversarial networks. CVPR, 2019.