WGAN的Pytorch实现


GAN存在训练困难、训练得到的loss无法表示训练进程等问题。大多数的GAN都是从模型的结构上进行修改,如DCGAN,用卷积神经网络设计生成器和判别器,并进行了一些调整,但这些终究是治标不治本,没有从根本上解决问题。原始GAN是用JS散度来评判两个分布的相似程度,而JS散度存在一个问题,就是真实数据$P_{data}$和生成数据$P_{G}$没有数据重叠的话,不管真实数据和生成数据之间的距离多远,JS散度计算出来的值都是log2,这就导致了生成器梯度消失的问题。WGAN就是用来解决这个问题,其优点有以下几个方面:

  • 彻底解决GAN训练不稳定和mode collapse的问题,不再需要小心平衡生成器和判别器的训练程度
  • 不需要精心设计网络结构,MLP网络就能实现很好的结果
  • WGAN定义的损失函数可以反映训练的进程,从而我们可以根据这个指标来判断模型训练的效果

WGAN实现以上好处使用的技术有:

  • 损失函数不取log
  • 每次更新判别器的参数都截断到-c到c之间
  • 使用RMSProp参数优化算法,而不是Adam
  • 判别器的最后一层去掉sigmoid

模型实现

具体为什么要这么做,请查看原论文,本文主要关注的是WGAN模型以及训练过程Pytorch代码实现。

首先来看模型部分,这里使用的是DCGAN的模型,只不过将判别器的最后一层的sigmoid去掉

import torch
import torch.nn as nn

# 模型参数初始化
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)  # 均值为0,标准差为0.02的正态分布
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)             # 均值为1,标准差为0.02的正态分布

# 生成器
class Generator(nn.Module):
    """
    input (N, in_dim)
    output (N, 3, 64, 64)
    """
    def __init__(self, in_dim, dim=64):
        super(Generator, self).__init__()
        def dconv_bn_relu(in_dim, out_dim):
            return nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, 3, 2,
                                   padding=1, output_padding=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU()
            )
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, dim * 16 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 16 * 4 * 4),
            nn.ReLU()
        )
        
        self.l2_5 = nn.Sequential(
            dconv_bn_relu(dim * 16, dim * 8),        # (N, 512, 8, 8)  
            dconv_bn_relu(dim * 8, dim * 4),         # (N, 256, 16, 16)
            dconv_bn_relu(dim * 4, dim * 2),         # (N, 128, 32, 32)
            nn.ConvTranspose2d(dim, 3, 3, 2, padding=1, output_padding=1), 
            nn.Tanh()                                # (N, 3, 64, 64)
        )
        self.apply(weights_init)
        
	def forward(self, x):
        out = self.l1(x)                        # (N, 1024 * 4 * 4)
        out = out.view(out.size(0), -1, 4, 4)   # (N, 1024, 4, 4)
        out = self.l2_5(out)                    # (N, 3, 64, 64)
        return out

# 判别器
class Discriminator(nn.Module):
    """
        input (N, 3, 64, 64)
        output (N, )
        """
    def __init__(self, in_dim, dim=64):
        super(Discriminator, self).__init__()
        def conv_bn_lrelu(in_dim, out_dim):
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 3, 2, 1),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2)
            )
        self.ls = nn.Sequential(
            nn.Conv2d(in_dim, dim * 2, 3, 2, 1), nn.LeakyReLU(0.2), # (N, 128, 32, 32)
            conv_bn_lrelu(dim * 2, dim * 4),                        # (N, 256, 16, 16)
            conv_bn_lrelu(dim * 4, dim * 8),                        # (N, 512, 8, 8)
            conv_bn_lrelu(dim * 8, dim * 16),                       # (N, 1024, 4, 4)
            nn.Conv2d(dim * 16, 1, 4),                              # (N, 1)
        )
        self.apply(weights_init)

    def forward(self, x):
        y = self.ls(x)
        y = y.view(-1)
        return y

这个部分与前面笔记HW6-GAN生成器和判别器的模型基本一样,有关数据处理可以看这个笔记,下面我们主要看WGAN的训练过程,可以看前面模型的训练过程并进行对比。

模型训练

准备好model和optimizer

# 设置超参数
batch_size = 64
z_dim = 100
lr = 0.0001
n_epoch = 10
cliping = 0.02
critic_iter = 1

# model
G = Generator(in_dim=z_dim).cuda()
D = Discriminator(in_dim=3).cuda()
G.train()
D.train()

# optimizer,注意这里用的是RMSprop
opt_D = torch.optim.RMSprop(D.parameters(), lr=lr)
opt_G = torch.optim.RMSprop(G.parameters(), lr=lr)

# 用于后向传播更新模型参数
one = torch.FloatTensor([1])
mone = one * -1

开始训练

for epoch in range(n_epoch):
    for i, data in enumerate(dataloader):
        imgs = data
        imgs = imgs.cuda()

        bs = imgs.size(0)     # 这里不用batch_size是为了防止最后一个并不是

        """Train D"""
        # Requires grad, Generator requires_grad = False
        for p in D.parameters():
            p.requires_grad = True

        for d_iter in range(critic_iter):
            D.zero_grad()
            # Clamp parameters
            for p in D.parameters():
                p.data.clamp_(-cliping, cliping)

            z = Variable(torch.randn(bs, z_dim)).cuda()
            r_imgs = Variable(imgs).cuda()
            f_imgs = G(z)

            # Train with real images
            d_loss_real = D(r_imgs).mean().view(1)
            d_loss_real.backward(mone.cuda())

            # Train with fake images
            d_loss_fake = D(f_imgs).mean().view(1)
            d_loss_fake.backward(one.cuda())

            d_loss = d_loss_fake - d_loss_real
            Wasserstein_D = d_loss_real - d_loss_fake
            opt_D.step()

        """ Train G"""
        for p in D.parameters():
            p.requires_grad = False  # to avoid computation

        G.zero_grad()
        z = Variable(torch.randn(bs, z_dim)).cuda()
        f_imgs = G(z)

        g_loss = D(f_imgs).mean().view(1)
        g_loss.backward(mone.cuda())
        g_cost = -g_loss
        opt_G.step()

        # 打印当前模型训练的状态
        print(f'\rEpoch [{epoch + 1}/{n_epoch}] {i + 1}/{len(dataloader)} '
              f'Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}', end='')

下面展示一张我训练迭代十次后出来的照片,整体看起来还是不错的,这里我没有使用原始DCGAN的模型,而是用了它的简化版,即前面笔记HW6-GAN中的模型,只是将判别器最后一层的sigmoid去掉,这个视觉上比HW6-GAN第10次迭代的效果要好些,但其中还是有一些瑕疵,个人觉得是我迭代的次数太少,也可能是代码问题。


文章作者: 不才叶某
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 不才叶某 !
评论
  目录