GAN存在训练困难、训练得到的loss无法表示训练进程等问题。大多数的GAN都是从模型的结构上进行修改,如DCGAN,用卷积神经网络设计生成器和判别器,并进行了一些调整,但这些终究是治标不治本,没有从根本上解决问题。原始GAN是用JS散度来评判两个分布的相似程度,而JS散度存在一个问题,就是真实数据$P_{data}$和生成数据$P_{G}$没有数据重叠的话,不管真实数据和生成数据之间的距离多远,JS散度计算出来的值都是log2,这就导致了生成器梯度消失的问题。WGAN就是用来解决这个问题,其优点有以下几个方面:
- 彻底解决GAN训练不稳定和mode collapse的问题,不再需要小心平衡生成器和判别器的训练程度
- 不需要精心设计网络结构,MLP网络就能实现很好的结果
- WGAN定义的损失函数可以反映训练的进程,从而我们可以根据这个指标来判断模型训练的效果
WGAN实现以上好处使用的技术有:
- 损失函数不取log
- 每次更新判别器的参数都截断到-c到c之间
- 使用RMSProp参数优化算法,而不是Adam
- 判别器的最后一层去掉sigmoid
模型实现
具体为什么要这么做,请查看原论文,本文主要关注的是WGAN模型以及训练过程Pytorch代码实现。
首先来看模型部分,这里使用的是DCGAN的模型,只不过将判别器的最后一层的sigmoid去掉

import torch
import torch.nn as nn
# 模型参数初始化
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02) # 均值为0,标准差为0.02的正态分布
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0) # 均值为1,标准差为0.02的正态分布
# 生成器
class Generator(nn.Module):
"""
input (N, in_dim)
output (N, 3, 64, 64)
"""
def __init__(self, in_dim, dim=64):
super(Generator, self).__init__()
def dconv_bn_relu(in_dim, out_dim):
return nn.Sequential(
nn.ConvTranspose2d(in_dim, out_dim, 3, 2,
padding=1, output_padding=1, bias=False),
nn.BatchNorm2d(out_dim),
nn.ReLU()
)
self.l1 = nn.Sequential(
nn.Linear(in_dim, dim * 16 * 4 * 4, bias=False),
nn.BatchNorm1d(dim * 16 * 4 * 4),
nn.ReLU()
)
self.l2_5 = nn.Sequential(
dconv_bn_relu(dim * 16, dim * 8), # (N, 512, 8, 8)
dconv_bn_relu(dim * 8, dim * 4), # (N, 256, 16, 16)
dconv_bn_relu(dim * 4, dim * 2), # (N, 128, 32, 32)
nn.ConvTranspose2d(dim, 3, 3, 2, padding=1, output_padding=1),
nn.Tanh() # (N, 3, 64, 64)
)
self.apply(weights_init)
def forward(self, x):
out = self.l1(x) # (N, 1024 * 4 * 4)
out = out.view(out.size(0), -1, 4, 4) # (N, 1024, 4, 4)
out = self.l2_5(out) # (N, 3, 64, 64)
return out
# 判别器
class Discriminator(nn.Module):
"""
input (N, 3, 64, 64)
output (N, )
"""
def __init__(self, in_dim, dim=64):
super(Discriminator, self).__init__()
def conv_bn_lrelu(in_dim, out_dim):
return nn.Sequential(
nn.Conv2d(in_dim, out_dim, 3, 2, 1),
nn.BatchNorm2d(out_dim),
nn.LeakyReLU(0.2)
)
self.ls = nn.Sequential(
nn.Conv2d(in_dim, dim * 2, 3, 2, 1), nn.LeakyReLU(0.2), # (N, 128, 32, 32)
conv_bn_lrelu(dim * 2, dim * 4), # (N, 256, 16, 16)
conv_bn_lrelu(dim * 4, dim * 8), # (N, 512, 8, 8)
conv_bn_lrelu(dim * 8, dim * 16), # (N, 1024, 4, 4)
nn.Conv2d(dim * 16, 1, 4), # (N, 1)
)
self.apply(weights_init)
def forward(self, x):
y = self.ls(x)
y = y.view(-1)
return y
这个部分与前面笔记HW6-GAN生成器和判别器的模型基本一样,有关数据处理可以看这个笔记,下面我们主要看WGAN的训练过程,可以看前面模型的训练过程并进行对比。
模型训练
准备好model和optimizer
# 设置超参数
batch_size = 64
z_dim = 100
lr = 0.0001
n_epoch = 10
cliping = 0.02
critic_iter = 1
# model
G = Generator(in_dim=z_dim).cuda()
D = Discriminator(in_dim=3).cuda()
G.train()
D.train()
# optimizer,注意这里用的是RMSprop
opt_D = torch.optim.RMSprop(D.parameters(), lr=lr)
opt_G = torch.optim.RMSprop(G.parameters(), lr=lr)
# 用于后向传播更新模型参数
one = torch.FloatTensor([1])
mone = one * -1
开始训练
for epoch in range(n_epoch):
for i, data in enumerate(dataloader):
imgs = data
imgs = imgs.cuda()
bs = imgs.size(0) # 这里不用batch_size是为了防止最后一个并不是
"""Train D"""
# Requires grad, Generator requires_grad = False
for p in D.parameters():
p.requires_grad = True
for d_iter in range(critic_iter):
D.zero_grad()
# Clamp parameters
for p in D.parameters():
p.data.clamp_(-cliping, cliping)
z = Variable(torch.randn(bs, z_dim)).cuda()
r_imgs = Variable(imgs).cuda()
f_imgs = G(z)
# Train with real images
d_loss_real = D(r_imgs).mean().view(1)
d_loss_real.backward(mone.cuda())
# Train with fake images
d_loss_fake = D(f_imgs).mean().view(1)
d_loss_fake.backward(one.cuda())
d_loss = d_loss_fake - d_loss_real
Wasserstein_D = d_loss_real - d_loss_fake
opt_D.step()
""" Train G"""
for p in D.parameters():
p.requires_grad = False # to avoid computation
G.zero_grad()
z = Variable(torch.randn(bs, z_dim)).cuda()
f_imgs = G(z)
g_loss = D(f_imgs).mean().view(1)
g_loss.backward(mone.cuda())
g_cost = -g_loss
opt_G.step()
# 打印当前模型训练的状态
print(f'\rEpoch [{epoch + 1}/{n_epoch}] {i + 1}/{len(dataloader)} '
f'Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}', end='')
下面展示一张我训练迭代十次后出来的照片,整体看起来还是不错的,这里我没有使用原始DCGAN的模型,而是用了它的简化版,即前面笔记HW6-GAN中的模型,只是将判别器最后一层的sigmoid去掉,这个视觉上比HW6-GAN第10次迭代的效果要好些,但其中还是有一些瑕疵,个人觉得是我迭代的次数太少,也可能是代码问题。
