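GAN Evaluation Metrics
Evaluating GANs is hard because the generator does not define an explicit likelihood function. This document covers the mainstream GAN evaluation metrics and how to implement them.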
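1. Inception Score (IS)
1.1 Core idea
IS[^1] uses a pretrained Inception model to assess both the quality and the diversity of generated samples:
$$\mathrm{IS} = \exp\left(\mathbb{E}_{x \sim p_g}\left[D_{\mathrm{KL}}\big(p(y \mid x) \,\|\, p(y)\big)\right]\right)$$
where:
- $p(y \mid x)$: the class distribution the Inception model predicts for a generated sample $x$
- $p(y) = \mathbb{E}_{x \sim p_g}[p(y \mid x)]$: the marginal class distribution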
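1.2 Intuition
| Component | Requirement | Meaning |
|---|---|---|
| $p(y \mid x)$ has low entropy | Each generated sample should belong to a clear class | High quality |
| $p(y)$ has high entropy | Classes should be evenly represented | Good diversity |
Ideal case:
- Per sample: $p(y \mid x)$ concentrates on a single class (low entropy)
- Overall: $p(y)$ is close to uniform (high entropy)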
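
The two entropy requirements can be checked on hand-made predictions. The following is a minimal sketch in plain Python, not the production implementation: the helper names `inception_score` and `one_hot` and the toy 4-class setup are assumptions made for illustration. It shows that sharp, balanced predictions reach the upper bound (the number of classes), while either collapsing to one class or predicting uniformly gives the lower bound of 1.

```python
import math


def inception_score(preds):
    """IS = exp(mean_x KL(p(y|x) || p(y))) for a list of probability rows."""
    n, k = len(preds), len(preds[0])
    # Marginal p(y): average of the per-sample predictions
    py = [sum(row[j] for row in preds) / n for j in range(k)]
    kl_sum = 0.0
    for row in preds:
        # Terms with p(y|x) = 0 contribute 0 to the KL sum
        kl_sum += sum(p * math.log(p / py[j]) for j, p in enumerate(row) if p > 0)
    return math.exp(kl_sum / n)


def one_hot(j, k):
    """Probability row that puts all mass on class j."""
    return [1.0 if i == j else 0.0 for i in range(k)]


K = 4  # toy number of classes (illustrative)
confident_diverse = [one_hot(j, K) for j in range(K)] * 5  # sharp and balanced
confident_collapsed = [one_hot(0, K)] * 20                 # sharp, single class
uncertain = [[1.0 / K] * K] * 20                           # every prediction uniform

print(inception_score(confident_diverse))    # approximately 4.0 (= K)
print(inception_score(confident_collapsed))  # 1.0
print(inception_score(uncertain))            # 1.0
```

Note that the collapsed and the uncertain generator receive the same score even though they fail in opposite ways, which is one reason IS is usually reported alongside other metrics.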
1.3 Implementation
```python
import torch
import torch.nn.functional as F
from torchvision.models import inception_v3, Inception_V3_Weights


class InceptionScore:
    """Inception Score computation."""

    def __init__(self, device='cuda'):
        self.device = device
        # Load the ImageNet-pretrained Inception v3 classifier
        self.inception = inception_v3(
            weights=Inception_V3_Weights.IMAGENET1K_V1, transform_input=False
        )
        self.inception.eval()
        self.inception.to(device)

    @torch.no_grad()
    def get_predictions(self, images, batch_size=50):
        """
        Get the Inception model's class predictions.

        Args:
            images: (N, 3, H, W) tensor, normalized as the model expects
        Returns:
            Predicted class probabilities, shape (N, 1000)
        """
        preds = []
        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size].to(self.device)
            # Inception v3 expects 299x299 inputs
            if batch.shape[-2:] != (299, 299):
                batch = F.interpolate(batch, size=(299, 299),
                                      mode='bilinear', align_corners=False)
            logits = self.inception(batch)
            # Softmax turns logits into probabilities
            preds.append(F.softmax(logits, dim=-1).cpu())
        return torch.cat(preds, dim=0)

    def compute_is(self, images, splits=10, eps=1e-12):
        """
        Compute the Inception Score.

        Args:
            images: generated images, (N, 3, H, W)
            splits: number of splits used to estimate mean and std
            eps: small constant that guards against log(0)
        Returns:
            (mean IS, standard deviation of IS across splits)
        """
        preds = self.get_predictions(images)  # (N, 1000)

        split_scores = []
        split_size = preds.size(0) // splits
        for i in range(splits):
            part = preds[i * split_size:(i + 1) * split_size]
            # Marginal distribution p(y), estimated within this split
            py = part.mean(dim=0, keepdim=True)
            # KL(p(y|x) || p(y)) for each sample
            kl = (part * (torch.log(part + eps) - torch.log(py + eps))).sum(dim=1)
            split_scores.append(torch.exp(kl.mean()))

        scores = torch.stack(split_scores)
        return scores.mean().item(), scores.std().item()
```
1.4 Limitations of IS
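| Problem | Explanation |
|---|---|
| Insensitive to intra-class variation | Cannot detect mode collapse within a class |
| Depends on the Inception model | May break down under domain shift |
| No comparison with the real distribution | Does not measure similarity to real data |
| Can be fooled by adversarial examples | Generators that game IS exist |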
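2. Fréchet Inception Distance (FID)
2.1 Core idea
FID[^2] compares the distributions of generated and real samples in feature space:
$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)$$
where:
- $\mu_r, \mu_g$: means of the real / generated features
- $\Sigma_r, \Sigma_g$: covariance matrices of the real / generated features
Key assumption: the Inception features are approximately multivariate Gaussian.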
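
Under the Gaussian assumption, FID is the squared Fréchet distance between two Gaussians: a squared mean difference plus a covariance term $\mathrm{Tr}(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2})$. When both covariance matrices are diagonal, the covariance term reduces to the sum of squared differences of the per-dimension standard deviations, which makes the distance easy to verify by hand. The sketch below covers only that special case; the function name `frechet_distance_diag` is an assumption made for illustration.

```python
import math


def frechet_distance_diag(mu1, var1, mu2, var2):
    """Squared Frechet distance between Gaussians with diagonal covariances."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    # For diagonal covariances, Tr(S1 + S2 - 2 (S1 S2)^{1/2}) equals the
    # sum over dimensions of (std1 - std2)^2.
    cov_term = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(var1, var2))
    return mean_term + cov_term


unit = ([0.0, 0.0], [1.0, 1.0])  # mean and per-dimension variance

print(frechet_distance_diag(*unit, [0.0, 0.0], [1.0, 1.0]))  # 0.0, identical
print(frechet_distance_diag(*unit, [3.0, 4.0], [1.0, 1.0]))  # 25.0, mean shift only
print(frechet_distance_diag(*unit, [0.0, 0.0], [4.0, 9.0]))  # 5.0, spread change only
```

The last two cases show that the metric responds both to a shift in the feature mean and to a change in feature spread, even when the other statistic matches exactly.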
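2.2 Intuition
| Case | FID | Meaning |
|---|---|---|
| Identical distributions | 0 | Ideal |
| Distribution shift | > 0 | The larger the value, the larger the shift |
| Poor generation quality | High | Blur / mode collapse |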
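2.3 Implementation
```python
import numpy as np
import torch
from scipy import linalg


def calculate_fid(real_features, fake_features, eps=1e-6):
    """
    Compute FID.

    Args:
        real_features: features of real samples, (N_real, D)
        fake_features: features of generated samples, (N_fake, D)
        eps: diagonal offset used if the matrix square root is not finite
    Returns:
        FID score
    """
    # Convert to numpy
    if isinstance(real_features, torch.Tensor):
        real_features = real_features.cpu().numpy()
    if isinstance(fake_features, torch.Tensor):
        fake_features = fake_features.cpu().numpy()

    # Means
    mu_real = np.mean(real_features, axis=0)
    mu_fake = np.mean(fake_features, axis=0)

    # Covariances (rows are samples, columns are feature dimensions)
    sigma_real = np.cov(real_features, rowvar=False)
    sigma_fake = np.cov(fake_features, rowvar=False)

    # FID formula
    diff = mu_real - mu_fake
    covmean = linalg.sqrtm(sigma_real @ sigma_fake)
    if not np.isfinite(covmean).all():
        # Nearly singular product: regularize the diagonals and retry
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = linalg.sqrtm((sigma_real + offset) @ (sigma_fake + offset))
    if np.iscomplexobj(covmean):
        # Discard the small imaginary part introduced by numerical error
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return float(fid)
```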
```python
import torch
from torchvision.models import inception_v3, Inception_V3_Weights


class FIDCalculator:
    """FID calculator."""

    def __init__(self, device='cuda'):
        self.device = device
        self.inception = inception_v3(
            weights=Inception_V3_Weights.IMAGENET1K_V1, transform_input=False
        )
        # Replace the classifier head with the identity so that forward()
        # returns the 2048-d pooled features (the "pool3" layer).
        # Note: torchvision's weights differ from the TensorFlow Inception
        # used for published FID numbers, so scores are not directly comparable.
        self.inception.fc = torch.nn.Identity()
        self.inception.eval()
        self.inception.to(device)

    @torch.no_grad()
    def extract_features(self, images, batch_size=50):
        """Extract 2048-d Inception features from (N, 3, 299, 299) images."""
        features = []
        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size].to(self.device)
            feat = self.inception(batch)  # (B, 2048) in eval mode
            features.append(feat.cpu())
        return torch.cat(features, dim=0)

    def compute_fid(self, real_images, fake_images):
        """Compute FID between two image sets."""
        real_feats = self.extract_features(real_images)
        fake_feats = self.extract_features(fake_images)
        # calculate_fid is the function defined earlier in this section
        return calculate_fid(real_feats.numpy(), fake_feats.numpy())
```
2.4 Strengths and limitations of FID
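| Strengths | Limitations |
|---|---|
| Compares the real and generated distributions | Depends on a pretrained Inception model |
| Sensitive to mode collapse | Needs a sufficiently large sample |
| Widely used and accepted | Assumes Gaussian features |
| Continuous and differentiable | Does not capture perceptual quality |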
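
The sample-size requirement can be made concrete with synthetic features. In the sketch below both feature sets are drawn from the same distribution, so the true distance is 0, yet the estimate is clearly positive for small samples and shrinks as the sample grows. This is an illustration under stated assumptions: the features are random Gaussian vectors rather than Inception activations, the function name `fid_from_features` is hypothetical, and the trace of the matrix square root is computed from eigenvalues so that only NumPy is needed.

```python
import numpy as np


def fid_from_features(a, b):
    """FID between two feature sets of shape (N, D), using NumPy only."""
    mu_a, mu_b = a.mean(axis=0), b.mean(axis=0)
    cov_a, cov_b = np.cov(a, rowvar=False), np.cov(b, rowvar=False)
    # Tr((cov_a cov_b)^{1/2}) equals the sum of the square roots of the
    # eigenvalues of cov_a @ cov_b, which are real and non-negative
    # (up to round-off) for covariance matrices.
    eigvals = np.linalg.eigvals(cov_a @ cov_b).real
    tr_sqrt = np.sqrt(np.clip(eigvals, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_sqrt)


rng = np.random.default_rng(0)
dim = 16  # toy feature dimension; real Inception features have 2048
for n in (50, 500, 5000):
    # Both sets come from the SAME standard normal distribution
    a = rng.standard_normal((n, dim))
    b = rng.standard_normal((n, dim))
    print(n, round(fid_from_features(a, b), 4))
```

Because this bias depends on the sample count, FID values are only comparable when they are computed with the same number of samples.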
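3. Precision, Recall, and F1
3.1 Motivation
IS and FID each fold quality and diversity into a single number, so they cannot tell which of the two is at fault. The precision-recall framework gives a finer-grained view.
3.2 Definitions
| Metric | Definition | Ideal |
|---|---|---|
| Precision | Fraction of generated samples that lie close to the real distribution | High |
| Recall | Fraction of real samples that the generator is able to produce | High |
| F1 | Harmonic mean of precision and recall | High |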
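
A toy 2-D example shows how the two numbers separate quality from coverage. The sketch below estimates each manifold with k-nearest-neighbour balls, in the style of the improved precision and recall estimator of Kynkäänniemi et al.; the function names (`knn_radii`, `coverage`, `f1`), the choice `k=3`, and the synthetic two-mode data are assumptions made for illustration. The generator here covers only one of the two real modes, so precision stays high while recall cannot exceed one half.

```python
import numpy as np


def knn_radii(points, k):
    """Distance from each point to its k-th nearest neighbour (excluding itself)."""
    d = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    return np.sort(d, axis=1)[:, k]  # column 0 is the point itself


def coverage(reference, test, k=3):
    """Fraction of test points inside at least one k-NN ball of the reference set."""
    radii = knn_radii(reference, k)
    d = np.linalg.norm(test[:, None, :] - reference[None, :, :], axis=-1)
    return float((d <= radii[None, :]).any(axis=1).mean())


def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


rng = np.random.default_rng(0)
# Real data: two well-separated modes. Generator: covers only the first mode.
real = np.concatenate([rng.normal(0.0, 0.5, (200, 2)),
                       rng.normal(10.0, 0.5, (200, 2))])
fake = rng.normal(0.0, 0.5, (200, 2))

precision = coverage(real, fake)  # generated samples inside the real manifold
recall = coverage(fake, real)     # real samples inside the generated manifold
print(precision, recall, f1(precision, recall))
```

The dense pairwise distance matrix keeps the sketch short but scales quadratically with the number of samples, so larger feature sets would need chunking or a nearest-neighbour index.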
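3.3 Implementation
```python
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors


class PrecisionRecall:
    """Precision and recall based on k-NN manifold estimates."""

    def __init__(self, k=5):
        self.k = k  # number of nearest neighbours

    def _coverage(self, reference, test, chunk=1024):
        """Fraction of test points inside the k-NN manifold of `reference`."""
        # Radius of each reference point = distance to its k-th neighbour
        # (k + 1 because every point's nearest neighbour is itself)
        nn = NearestNeighbors(n_neighbors=self.k + 1, metric='euclidean')
        nn.fit(reference)
        radii = nn.kneighbors(reference)[0][:, -1]  # (N_ref,)

        # A test point is covered if it falls inside at least one ball;
        # process in chunks to bound memory use
        covered = np.zeros(len(test), dtype=bool)
        for start in range(0, len(test), chunk):
            d = pairwise_distances(test[start:start + chunk], reference)
            covered[start:start + chunk] = (d <= radii[None, :]).any(axis=1)
        return float(covered.mean())

    def compute_pr(self, real_features, fake_features, manifold='real'):
        """
        Compute precision or recall.

        Args:
            manifold: 'real' or 'fake', the set that defines the reference manifold
        Returns:
            Precision if manifold == 'real', otherwise recall
        """
        real = np.asarray(real_features, dtype=np.float64)
        fake = np.asarray(fake_features, dtype=np.float64)
        if manifold == 'real':
            # Precision: fraction of generated samples inside the real manifold
            return self._coverage(real, fake)
        # Recall: fraction of real samples inside the generated manifold
        return self._coverage(fake, real)
```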
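```python
import torch


class ModeScore:
    """Mode Score: an IS variant that also uses predictions on real data."""

    def compute_mode_score(self, real_preds, fake_preds, eps=1e-12):
        """
        Compute the Mode Score.

        Args:
            real_preds: class probabilities for real samples, (N_real, C)
            fake_preds: class probabilities for generated samples, (N_fake, C)
        """
        # p(y*): marginal class distribution of the real data
        py_real = real_preds.mean(dim=0)
        # p(y): marginal class distribution of the generated data
        py_fake = fake_preds.mean(dim=0)

        # E_x[KL(p(y|x) || p(y))] over generated samples (the IS exponent)
        kl_cond = (fake_preds * (torch.log(fake_preds + eps)
                                 - torch.log(py_fake + eps))).sum(dim=1).mean()
        # KL(p(y) || p(y*)): penalty when the generated class marginal
        # drifts away from the real one. This is one common statement of
        # the Mode Score; formulations differ on the direction of this term.
        kl_marginal = (py_fake * (torch.log(py_fake + eps)
                                  - torch.log(py_real + eps))).sum()

        return torch.exp(kl_cond - kl_marginal).item()
```
4. Perceptual Path Length (PPL)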
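4.1 Core idea
PPL[^3] measures how smooth the latent space is.
A well-behaved latent space should support smooth interpolation: the generated image should change roughly linearly as the latent vector moves.
4.2 Definition
Following the definition in the StyleGAN paper[^3] for latent vectors $z_1, z_2$ and a position $t \sim U(0, 1)$ on the path between them:
$$\mathrm{PPL} = \mathbb{E}_{z_1, z_2, t}\left[\frac{1}{\epsilon^2}\, d\Big(G\big(\mathrm{slerp}(z_1, z_2; t)\big),\; G\big(\mathrm{slerp}(z_1, z_2; t + \epsilon)\big)\Big)\right]$$
where:
- $\epsilon$: length of the step taken along the interpolation path in latent space
- $d(\cdot, \cdot)$: perceptual distance between the generated images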
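
When latent vectors are drawn from a Gaussian, the interpolation path is usually taken along the sphere rather than along a straight line, because a straight line passes through low-norm vectors that the generator rarely sees in training. The NumPy sketch below compares the two; the function name `slerp` and the 512-dimensional latent size are assumptions made for illustration.

```python
import numpy as np


def slerp(a, b, t):
    """Spherical linear interpolation between vectors a and b at fraction t."""
    a_unit = a / np.linalg.norm(a)
    b_unit = b / np.linalg.norm(b)
    omega = np.arccos(np.clip(a_unit @ b_unit, -1.0, 1.0))  # angle between a and b
    if np.isclose(omega, 0.0):
        return (1.0 - t) * a + t * b  # nearly parallel: fall back to lerp
    return (np.sin((1.0 - t) * omega) * a + np.sin(t * omega) * b) / np.sin(omega)


rng = np.random.default_rng(0)
z1, z2 = rng.standard_normal(512), rng.standard_normal(512)
mid_slerp = slerp(z1, z2, 0.5)
mid_lerp = 0.5 * (z1 + z2)

# Gaussian latents in high dimensions have norm close to sqrt(dim).
# The lerp midpoint has a visibly smaller norm; slerp roughly preserves it.
print(np.linalg.norm(z1), np.linalg.norm(mid_slerp), np.linalg.norm(mid_lerp))
```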
4.3 Implementation
```python
import lpips
import torch


def slerp(a, b, t):
    """Spherical interpolation between latent batches a and b at fraction t."""
    a_n = a / a.norm(dim=-1, keepdim=True)
    b_n = b / b.norm(dim=-1, keepdim=True)
    omega = torch.acos((a_n * b_n).sum(dim=-1, keepdim=True).clamp(-1, 1))
    so = torch.sin(omega)
    return torch.sin((1 - t) * omega) / so * a + torch.sin(t * omega) / so * b


class PerceptualPathLength:
    """Perceptual path length."""

    def __init__(self, net='vgg', device='cuda'):
        # LPIPS perceptual distance; expects images scaled to [-1, 1]
        self.loss_fn = lpips.LPIPS(net=net).to(device)
        self.device = device

    @torch.no_grad()
    def compute_ppl(self, generator, num_samples=10000, epsilon=1e-4,
                    batch_size=100):
        """
        Compute PPL.

        Args:
            generator: generator model; assumed to expose `latent_dim`
            num_samples: number of sampled path segments
            epsilon: step size along the interpolation path
        """
        dists = []
        for i in range(0, num_samples, batch_size):
            n = min(batch_size, num_samples - i)
            # Sample two latent vectors per path
            z1 = torch.randn(n, generator.latent_dim, device=self.device)
            z2 = torch.randn(n, generator.latent_dim, device=self.device)
            # Random position on the path, and the same position plus epsilon
            t = torch.rand(n, 1, device=self.device)
            img_t = generator(slerp(z1, z2, t))
            img_te = generator(slerp(z1, z2, t + epsilon))
            # Perceptual distance of the small step, scaled by 1 / epsilon^2
            d = self.loss_fn(img_t, img_te).flatten() / epsilon ** 2
            dists.append(d.cpu())
        return torch.cat(dists).mean().item()
```
5. Comparison of evaluation metrics
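| Metric | What it measures | Range | Ideal |
|---|---|---|---|
| IS | Quality + diversity | $[1, K]$ for a $K$-class classifier | High |
| FID | Distance to the real distribution | $[0, \infty)$ | Low (0 is optimal) |
| Precision | Generation quality | $[0, 1]$ | High |
| Recall | Mode coverage | $[0, 1]$ | High |
| PPL | Latent-space smoothness | $[0, \infty)$ | Low |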
6. Recommendations for a combined evaluation
6.1 Recommended metric combination
```python
def comprehensive_gan_eval(real_images, fake_images, generator):
    """Evaluate a GAN with several complementary metrics."""
    results = {}

    # 1. FID (primary metric); features are extracted once and reused below
    fid_calc = FIDCalculator()
    real_feats = fid_calc.extract_features(real_images)
    fake_feats = fid_calc.extract_features(fake_images)
    results['fid'] = calculate_fid(real_feats, fake_feats)

    # 2. IS
    is_calc = InceptionScore()
    results['is_mean'], results['is_std'] = is_calc.compute_is(fake_images)

    # 3. Precision & recall on the same Inception features
    pr_calc = PrecisionRecall()
    results['precision'] = pr_calc.compute_pr(real_feats, fake_feats, manifold='real')
    results['recall'] = pr_calc.compute_pr(real_feats, fake_feats, manifold='fake')

    # 4. PPL
    ppl_calc = PerceptualPathLength()
    results['ppl'] = ppl_calc.compute_ppl(generator)

    return results
```
6.2 Evaluation caveats
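| Consideration | Details |
|---|---|
| Sample count | For FID, use at least 5,000 real and 5,000 generated samples |
| Image preprocessing | Apply identical preprocessing to real and generated images |
| Repeatability | Repeat the evaluation several times and average the results |
| Combining metrics | No single metric is sufficient on its own |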
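
Repeating an evaluation and averaging can be wrapped in a small helper that reports the spread alongside the mean. The following is a minimal sketch in plain Python: `repeated_metric` and `toy_metric` are hypothetical names, and the toy metric (difference of sample means on 1-D data) merely stands in for an expensive metric such as FID.

```python
import random
import statistics


def repeated_metric(metric_fn, real, fake, n_repeats=5, subset_size=None, seed=0):
    """Evaluate metric_fn on random subsets and return (mean, std)."""
    rng = random.Random(seed)
    size = subset_size or min(len(real), len(fake))
    scores = []
    for _ in range(n_repeats):
        # Same subset size for both sets, so the sample counts always match
        real_sub = rng.sample(real, size)
        fake_sub = rng.sample(fake, size)
        scores.append(metric_fn(real_sub, fake_sub))
    return statistics.mean(scores), statistics.stdev(scores)


def toy_metric(real, fake):
    """Stand-in for a real metric: absolute difference of the sample means."""
    return abs(statistics.mean(real) - statistics.mean(fake))


data_rng = random.Random(1)
real = [data_rng.gauss(0.0, 1.0) for _ in range(2000)]
fake = [data_rng.gauss(0.5, 1.0) for _ in range(2000)]

mean, std = repeated_metric(toy_metric, real, fake, n_repeats=10, subset_size=500)
print(f"{mean:.3f} +/- {std:.3f}")
```

Reporting the standard deviation next to the mean makes it easier to judge whether a difference between two models is larger than the evaluation noise.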
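7. Recent developments (2024): limitations of FID
7.1 Criticisms of FID
Research presented at CVPR 2024 points out several limitations of FID:
- Dependence on a pretrained model: Inception features may not be the best choice
- Misses specific artifacts: local distortions can be masked by a global metric
- Mismatch with perception: a low FID does not imply high perceptual quality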
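7.2 Alternatives
| Method | Improvement |
|---|---|
| DISTS | Perceptual metric built on VGG features |
| LPIPS | Learned perceptual similarity |
| SSIM / MS-SSIM | Structural similarity |
| Human Preference Score | Based on human evaluation |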
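
Unlike FID, the similarity measures in this table compare a pair of images rather than two distributions. As an illustration, the sketch below evaluates the SSIM formula once over the whole image. This is a deliberate simplification: standard SSIM averages the same expression over local sliding windows, and the function name `global_ssim` and the random test images are assumptions made for the example.

```python
import numpy as np


def global_ssim(x, y, data_range=1.0):
    """Single-window SSIM between two grayscale images in [0, data_range]."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    # Stabilizing constants from the SSIM definition
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    return ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )


rng = np.random.default_rng(0)
img = rng.random((64, 64))
noisy = np.clip(img + rng.normal(0.0, 0.2, img.shape), 0.0, 1.0)

print(global_ssim(img, img))    # 1.0 (up to round-off) for identical images
print(global_ssim(img, noisy))  # below 1.0
```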
8. References
[^1]: Salimans T, Goodfellow I, Zaremba W, et al. Improved techniques for training GANs. NeurIPS, 2016.
[^2]: Heusel M, Ramsauer H, Unterthiner T, et al. GANs trained by a two time-scale update rule converge to a local Nash equilibrium. NeurIPS, 2017.
[^3]: Karras T, Laine S, Aila T, et al. A style-based generator architecture for generative adversarial networks. CVPR, 2019.