GAN应用场景
生成对抗网络在计算机视觉和机器学习的各个领域都有广泛应用。本文档介绍GAN的主要应用场景及其代表性方法。
1. 图像生成
1.1 人脸生成
StyleGAN系列是高质量人脸生成的代表性工作:
| 模型 | 分辨率 | 特点 |
|---|---|---|
| StyleGAN | 1024² | 风格化生成 |
| StyleGAN2 | 1024² | 去除伪影 |
| StyleGAN3 | 1024² | 消除纹理粘连 |
| StyleGAN-XL | 1024² | 超大规模 |
应用:
- 虚拟人物创建
- 游戏角色生成
- 艺术创作
1.2 场景生成
BigGAN在ImageNet场景生成上取得突破:
# BigGAN条件生成示例
def generate_scene(generator, class_id, z=None):
"""生成特定类别的场景"""
if z is None:
z = torch.randn(1, generator.latent_dim)
# 注入类别信息
y = torch.zeros(1, generator.num_classes)
y[0, class_id] = 1
# 生成
with torch.no_grad():
img = generator(z, y)
return img1.3 艺术生成
GAN在艺术创作中的应用:
| 应用 | 方法 | 代表工作 |
|---|---|---|
| 风格迁移 | cGAN | GauGAN |
| 画作生成 | StyleGAN | Artbreeder |
| 3D生成 | 3D-GAN | DreamFusion |
2. 图像到图像转换
2.1 Pix2Pix
Pix2Pix1使用条件GAN进行配对的图像转换:
class Pix2PixGenerator(nn.Module):
"""Pix2Pix U-Net风格生成器"""
def __init__(self, in_channels=3, out_channels=3):
super().__init__()
# 编码器
self.enc1 = self._conv_block(in_channels, 64)
self.enc2 = self._conv_block(64, 128)
self.enc3 = self._conv_block(128, 256)
self.enc4 = self._conv_block(256, 512)
# 解码器(带跳跃连接)
self.dec4 = self._upconv_block(512, 256)
self.dec3 = self._upconv_block(512, 128)
self.dec2 = self._upconv_block(256, 64)
self.dec1 = self._upconv_block(128, out_channels)
def forward(self, x):
# 编码
e1 = self.enc1(x)
e2 = self.enc2(e1)
e3 = self.enc3(e2)
e4 = self.enc4(e3)
# 解码(带跳跃连接)
d4 = self.dec4(e4)
d4 = torch.cat([d4, e3], dim=1)
d3 = self.dec3(d4)
d3 = torch.cat([d3, e2], dim=1)
d2 = self.dec2(d3)
d2 = torch.cat([d2, e1], dim=1)
return torch.tanh(self.dec1(d2))应用场景:
- 卫星图→地图
- 素描→照片
- 白天→夜景
2.2 CycleGAN
CycleGAN2实现无配对的图像转换:
循环一致性损失:
class CycleGAN:
"""
CycleGAN: Unpaired Image-to-Image Translation
"""
def __init__(self, G_F, G_G, D_X, D_Y):
# G_F: X → Y
# G_G: Y → X
# D_X: 区分真实X和生成X
# D_Y: 区分真实Y和生成Y
def compute_losses(self, x, y):
"""计算所有损失"""
# 对抗损失
loss_G_F = self.gan_loss(self.D_Y(self.G_F(x)), True)
loss_G_G = self.gan_loss(self.D_X(self.G_G(y)), True)
# 循环一致性损失
x_recon = self.G_G(self.G_F(x))
loss_cycle_x = self.recon_loss(x_recon, x)
y_recon = self.G_F(self.G_G(y))
loss_cycle_y = self.recon_loss(y_recon, y)
# 身份损失(可选)
loss_id_x = self.recon_loss(self.G_G(x), x)
loss_id_y = self.recon_loss(self.G_F(y), y)
return {
'G_F': loss_G_F + loss_cycle_x + 0.5 * loss_id_y,
'G_G': loss_G_G + loss_cycle_y + 0.5 * loss_id_x,
'D_X': self.discriminator_loss(self.D_X, x, self.G_G(y)),
'D_Y': self.discriminator_loss(self.D_Y, y, self.G_F(x))
}应用场景:
- 马→斑马
- 夏季→冬季
- 照片→油画
2.3 StarGAN
StarGAN3实现多域之间的统一转换:
class StarGAN:
"""
StarGAN: Multi-Domain Image-to-Image Translation
"""
def __init__(self):
self.generator = StarGANGenerator()
self.discriminator = StarGANDiscriminator()
def forward(self, img, target_domain):
"""前向传播"""
# 生成目标域图像
fake_img = self.generator(img, target_domain)
# 重构回原始图像
recon_img = self.generator(fake_img, original_domain)
return fake_img, recon_img优势:
- 一个模型支持多域转换
- 可控的属性变换
- 节约计算资源
3. 数据增强
3.1 类别不平衡问题
GAN可以生成少数类样本:
class GANAugmentation:
"""基于GAN的数据增强"""
def __init__(self, minority_class, num_samples):
self.minority_class = minority_class
self.num_samples = num_samples
self.generator = None
def train(self, dataloader):
"""训练生成器"""
# 在少数类数据上训练GAN
for epoch in range(100):
for batch in dataloader:
# 训练判别器和生成器
pass
def augment(self, original_dataset):
"""生成新样本"""
with torch.no_grad():
z = torch.randn(self.num_samples, self.latent_dim)
synthetic_samples = self.generator(z)
# 添加到数据集
augmented_dataset = original_dataset + synthetic_samples
return augmented_dataset3.2 医学影像增强
GAN在医学影像领域的应用:
| 领域 | 应用 | 方法 |
|---|---|---|
| CT/MRI | 合成影像 | CT-SGAN |
| 病理切片 | 细胞生成 | Pathology-GAN |
| 眼底 | 病变模拟 | Retinal-GAN |
| X光 | 缺陷生成 | Chest-Xray-GAN |
3.3 对抗样本增强
class AdvGAN:
"""对抗GAN:生成对抗样本"""
def __init__(self, target_model):
self.target_model = target_model
self.G = Generator()
self.D = Discriminator()
def craft_adversarial(self, x, target_label):
"""
生成对抗样本
"""
# 生成扰动
noise = self.G(x)
# 添加扰动
x_adv = x + noise
# 确保攻击成功
pred = self.target_model(x_adv)
return x_adv4. 图像编辑与操作
4.1 潜在空间编辑
StyleGAN的潜在空间具有良好的语义结构:
class LatentSpaceEditor:
"""潜在空间编辑"""
def __init__(self, generator):
self.G = generator
def edit_attribute(self, z, direction, alpha):
"""
在潜在空间中沿某个方向移动来编辑属性
Args:
z: 基础潜在向量
direction: 属性方向(如"微笑"方向)
alpha: 编辑强度
"""
return z + alpha * direction
def interpolate(self, z1, z2, steps=10):
"""潜在空间插值"""
alphas = torch.linspace(0, 1, steps)
interpolations = []
for alpha in alphas:
z = (1 - alpha) * z1 + alpha * z2
interpolations.append(z)
return interpolations
def apply_style_mixing(self, content_z, style_z):
"""
风格混合:使用内容z的粗粒度特征和风格z的细粒度特征
"""
# 混合不同层级的风格
styles = {}
for i, (c_feat, s_feat) in enumerate(zip(content_features, style_features)):
if i < 4: # 粗粒度层级
styles[i] = c_feat
else: # 细粒度层级
styles[i] = s_feat
return styles4.2 图像修复(Inpainting)
class GANInpainting:
"""基于GAN的图像修复"""
def __init__(self):
self.generator = InpaintingGenerator()
self.discriminator = InpaintingDiscriminator()
def inpaint(self, image, mask):
"""
修复图像
Args:
image: 原始图像
mask: 缺失区域掩码(1=缺失)
"""
# 填充缺失区域为0或平均值
masked_image = image * (1 - mask)
# 生成修复
with torch.no_grad():
completed = self.generator(masked_image, mask)
# 混合修复区域和原始区域
result = image * (1 - mask) + completed * mask
return result4.3 超分辨率(SRGAN)
SRGAN4将GAN应用于图像超分辨率:
class SRGANGenerator(nn.Module):
"""SRGAN生成器"""
def __init__(self, scale_factor=4):
super().__init__()
# 特征提取
self.conv_first = nn.Conv2d(3, 64, 9, padding=4)
# 残差块
self.residual_blocks = nn.Sequential(*[
ResidualBlock(64) for _ in range(16)
])
# 后处理
self.conv_second = nn.Conv2d(64, 64, 3, padding=1)
self.upsample = nn.Sequential(
nn.Conv2d(64, 256, 3, padding=1),
nn.PixelShuffle(scale_factor),
nn.PReLU()
)
self.conv_last = nn.Conv2d(64, 3, 9, padding=4)
def forward(self, x):
feat = F.relu(self.conv_first(x))
feat = self.residual_blocks(feat)
feat = self.conv_second(feat)
feat = self.upsample(feat)
return torch.tanh(self.conv_last(feat))5. 其他应用
5.1 文本生成
| 方法 | 说明 | 局限 |
|---|---|---|
| GAN-GAN | GAN直接生成离散文本 | 梯度不可导 |
| SeqGAN | 强化学习+GAN | 训练困难 |
| TextGAN | 使用CNN生成文本 | 效果有限 |
5.2 语音合成
| 应用 | 方法 | 代表工作 |
|---|---|---|
| 语音生成 | WaveGAN | Parallel WaveGAN |
| 语音转换 | CycleGAN-VC | StarGAN-VC |
| 说话人合成 | GE2E-GAN | - |
5.3 图生成
class GRANGenerator(nn.Module):
"""Graph Random Neural Network GAN for Graph Generation"""
def __init__(self):
self.G = GraphGenerator()
self.D = GraphDiscriminator()
def generate_graph(self, num_nodes):
"""生成分子图"""
adj = self.G(num_nodes)
return adj6. 实际应用案例
6.1 电商图像生成
应用流程:
1. 收集商品图片
2. 训练StyleGAN生成新商品
3. 应用于商品展示、虚拟试穿
6.2 游戏资产生成
应用流程:
1. 定义资产类型(角色、道具、场景)
2. 训练条件GAN
3. 按需生成新资产
6.3 视频生成
| 方法 | 说明 |
|---|---|
| VID2VID | 视频到视频转换 |
| MoCoGAN | 运动内容分离 |
| DVD-GAN | 视频生成 |
7. 参考资料
扩展阅读:
Footnotes
-
Isola P, Zhu J Y, Zhou T, et al. Image-to-image translation with conditional adversarial networks. CVPR, 2017. ↩
-
Zhu J Y, Park T, Isola P, et al. Unpaired image-to-image translation using cycle-consistent adversarial networks. ICCV, 2017. ↩
-
Choi Y, Uh M R, Yoo J, et al. Stargan v2: Diverse image synthesis for multiple domains. CVPR, 2020. ↩
-
Ledig C, Theis L, Huszár F, et al. Photo-realistic single image super-resolution using a generative adversarial network. CVPR, 2017. ↩