[Paper Reproduction] Conditional Generative Adversarial Nets (CGAN)


GAN Fundamentals

For the background theory, see: [Paper Reproduction] Generative Adversarial Nets (GAN fundamentals)

2.1 Source Paper

Authors: Mehdi Mirza, Simon Osindero
Date: 6 Nov 2014
Paper: https://arxiv.org/pdf/1411.1784.pdf
Data: https://github.com/MrHeadbang/machineLearning/blob/main/mnist.zip
Code: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cgan/cgan.py

Abstract: Generative Adversarial Nets were recently introduced as a novel way to train generative models. In this work we introduce the conditional version of generative adversarial nets, which can be constructed by simply feeding the data, y, we wish to condition on to both the generator and discriminator. We show that this model can generate MNIST digits conditioned on class labels. We also illustrate how this model could be used to learn a multi-modal model, and provide preliminary examples of an application to image tagging in which we demonstrate how this approach can generate descriptive tags which are not part of training labels.

2.2 Algorithm Overview

  CGAN (Conditional Generative Adversarial Network) is a variant of the original GAN in which both the generator and the discriminator receive extra information y as a condition; the condition can be class labels or data from another modality. The extra information y is fed into both the discriminator and the generator as part of their input layers (see the architecture diagram in the original paper).

  Like the original GAN, CGAN is built on multilayer perceptrons. In the original GAN, the discriminator's input is a training sample x and the generator's input is noise z; in CGAN, both the generator and the discriminator take an additional input y, which is the extra conditioning information.

  • The noise z and the condition y are fed into the generator together, combined in a joint hidden representation, and mapped to the data space through a nonlinear function.
  • The data x and the condition y are fed into the discriminator together, which then estimates the probability that x came from the real training data.
  • The two-player minimax game becomes:
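
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x \mid y)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]$$

This is Eq. (2) of the paper: the same objective as the original GAN, with both D and G conditioned on y.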

2.3 Handwritten-Digit Generation with CGAN

2.3.1 Network Architecture

The original GAN is unsupervised, as is the DCGAN from the earlier lab session: its output is completely random, so a network trained on faces gives no control over which face it ends up generating. CGAN, by contrast, is a supervised GAN. On MNIST, with the digit class label as the condition, it generates the digit image that corresponds to a given label.

  This experiment uses the MNIST handwritten-digit dataset. The generator's input is a 100-dimensional noise vector (sampled with torch.randn in the code below, i.e., from a standard normal distribution), conditioned on the class label (one-hot encoded). Through a final sigmoid it produces a 784-dimensional (28x28) single-channel image, so each image has shape [1, 28, 28]. The discriminator takes the 784-dimensional image together with the class label (one-hot encoded) and outputs the probability that the sample comes from the training set.
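
The article does not reprint the network definitions, so the following is a minimal MLP sketch consistent with the description above; the hidden-layer widths and the explicit one-hot encoding of the label are assumptions for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    """Maps (z, y) to a 784-dim image; hidden sizes are illustrative."""
    def __init__(self, z_dim=100, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.model = nn.Sequential(
            nn.Linear(z_dim + num_classes, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 784), nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def forward(self, z, labels):
        y = F.one_hot(labels, self.num_classes).float()  # condition as one-hot
        return self.model(torch.cat([z, y], dim=1)).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    """Scores (x, y) with the probability that x is a real sample."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.model = nn.Sequential(
            nn.Linear(784 + num_classes, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid(),    # probability of "real"
        )

    def forward(self, images, labels):
        y = F.one_hot(labels, self.num_classes).float()
        x = torch.cat([images.view(images.size(0), -1), y], dim=1)
        return self.model(x).squeeze(1)         # shape: [batch_size]
```

With this sketch, `generator(z, labels)` and `discriminator(images, labels)` accept LongTensor labels directly, matching how the training code below calls them.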

2.3.2 Training Procedure

  CGAN uses binary cross-entropy (BCELoss) as its loss function, with the Adam optimizer:

```python
# Define the Binary Cross Entropy Loss criterion for the GAN
criterion = nn.BCELoss()

# Set up optimizers for the discriminator and generator models
# Use Adam optimizer for updating the discriminator's parameters
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
# Use Adam optimizer for updating the generator's parameters
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
```
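
For reference, for a predicted probability p and a target t ∈ {0, 1}, BCELoss computes

$$\ell(p, t) = -\left[\, t \log p + (1 - t)\log(1 - p) \,\right]$$

so a target of 1 trains toward "real" and a target of 0 toward "fake", which is exactly how the three losses below use it.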

  CGAN training involves three losses: the discriminator's loss on real samples, the discriminator's loss on fake samples (these two are summed into d_loss), and the generator's loss, which scores the generated samples against the "real" target.

```python
for epoch in range(num_epochs):
    print('Starting epoch {}...'.format(epoch), end=' ')

    # Iterate through the data loader
    for i, (images, labels) in enumerate(data_loader):
        step = epoch * len(data_loader) + i + 1
        real_images = Variable(images).to(device)
        labels = Variable(labels).to(device)
        generator.train()

        d_loss = 0
        # Perform multiple discriminator training steps, accumulating the
        # loss so the logged value below is the average over n_critic steps
        for _ in range(n_critic):
            d_loss += discriminator_train_step(len(real_images), discriminator,
                                               generator, d_optimizer, criterion,
                                               real_images, labels,
                                               device)

        # Perform a single generator training step
        g_loss = generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion, device)

        # Write the losses to TensorBoard
        writer.add_scalars('scalars', {'g_loss': g_loss, 'd_loss': (d_loss / n_critic)}, step)
```

For TensorBoard usage, see: [Deep Learning Lab] TensorBoard tutorial (SCALARS, IMAGES, TIME SERIES)
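
The loop above references several objects the excerpt does not define. A minimal setup sketch follows; the values of batch_size, num_epochs, and n_critic are assumptions for illustration, not the article's chosen settings:

```python
import numpy as np
import torch
from torch.autograd import Variable  # used by the train-step functions below
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64    # assumed value
num_epochs = 50    # assumed value
n_critic = 1       # discriminator updates per generator update

# MNIST images as [1, 28, 28] tensors in [0, 1], matching the sigmoid output
dataset = datasets.MNIST(root='./data', train=True, download=True,
                         transform=transforms.ToTensor())
# drop_last=True keeps every batch at exactly batch_size, so len(real_images)
# in the discriminator step always matches the generator step's batch_size
data_loader = DataLoader(dataset, batch_size=batch_size,
                         shuffle=True, drop_last=True)
writer = SummaryWriter()  # logs to ./runs by default
```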

1. The discriminator's loss (discriminator_train_step)
  • Loss on real images
```python
# Train the discriminator with real images
real_validity = discriminator(real_images, labels)
# Calculate loss on real images; discriminator's goal: classify real images as real (1)
real_loss = criterion(real_validity, Variable(torch.ones(batch_size)).to(device))
```

The inputs here are genuine images from the MNIST dataset, so the target is real (1).

  • Loss on fake images
```python
# Train the discriminator with fake images
z = Variable(torch.randn(batch_size, 100)).to(device)
fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).to(device)
fake_images = generator(z, fake_labels)
fake_validity = discriminator(fake_images, fake_labels)
# Calculate loss on fake images; discriminator's goal: classify fake images as fake (0)
fake_loss = criterion(fake_validity, Variable(torch.zeros(batch_size)).to(device))
```

Here the inputs are fake images produced by G, with target fake (0), so the discriminator learns to tell them apart from real data.

  • The discriminator's total loss
```python
# Total discriminator loss is the sum of losses on real and fake images
d_loss = real_loss + fake_loss

# Backpropagation: compute gradients and update the discriminator's weights
d_loss.backward()
```

Putting it all together:

```python
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, labels, device):
    # Zero out the gradients from the previous iteration
    d_optimizer.zero_grad()

    # Train the discriminator with real images
    real_validity = discriminator(real_images, labels)
    # Calculate loss on real images; discriminator's goal: classify real images as real (1)
    real_loss = criterion(real_validity, Variable(torch.ones(batch_size)).to(device))

    # Train the discriminator with fake images
    z = Variable(torch.randn(batch_size, 100)).to(device)
    fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).to(device)
    fake_images = generator(z, fake_labels)
    fake_validity = discriminator(fake_images, fake_labels)
    # Calculate loss on fake images; discriminator's goal: classify fake images as fake (0)
    fake_loss = criterion(fake_validity, Variable(torch.zeros(batch_size)).to(device))

    # Total discriminator loss is the sum of losses on real and fake images
    d_loss = real_loss + fake_loss

    # Backpropagation: compute gradients and update the discriminator's weights
    d_loss.backward()
    d_optimizer.step()

    # Return the discriminator's loss as a Python float
    return d_loss.item()
```
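
One common refinement that this function omits: detach the generator's output before scoring it, so the discriminator's backward pass does not build gradients through the generator. This would replace the fake_validity line above:

```python
# Refinement (not in the original code): block gradient flow into G,
# since only the discriminator's parameters are updated in this step
fake_validity = discriminator(fake_images.detach(), fake_labels)
```
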
2. The generator's loss (generator_train_step)

  The generator's job is to fool the discriminator by producing realistic images. How does it fool the discriminator? By computing a "real" loss, but on images generated by G: the generator is rewarded when the discriminator classifies its output as real.

```python
def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion, device):
    # Zero out the gradients from the previous iteration
    g_optimizer.zero_grad()

    # Generate random noise vector z
    z = Variable(torch.randn(batch_size, 100)).to(device)

    # Generate random labels for the fake images
    fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).to(device)

    # Generate fake images using the generator
    fake_images = generator(z, fake_labels)

    # Get the discriminator's prediction on the generated fake images
    validity = discriminator(fake_images, fake_labels)

    # Calculate the generator's loss using the discriminator's prediction
    # Generator's goal: make the discriminator classify generated images as real (1)
    g_loss = criterion(validity, Variable(torch.ones(batch_size)).to(device))

    # Backpropagation: compute gradients and update the generator's weights
    g_loss.backward()
    g_optimizer.step()

    # Return the generator's loss as a Python float
    return g_loss.item()
```
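
Once training converges, the label input is what makes CGAN useful: fixing y selects which digit the generator draws. A small sampling sketch, assuming the trained generator from above:

```python
# Generate one image per digit class 0-9 with the trained generator
generator.eval()
with torch.no_grad():
    z = torch.randn(10, 100).to(device)       # one noise vector per class
    labels = torch.arange(10).to(device)      # conditions: digits 0..9
    samples = generator(z, labels)            # shape: [10, 1, 28, 28]
```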

2.4 Experimental Analysis

2.4.1 Hyperparameter Tuning

For details, see: [Paper Reproduction] CGAN Handwritten-Digit Generation — Hyperparameter Tuning

1. batch size
2. epochs
3. Adam: learning rate
4. Adam: weight_decay
5. n_critic
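
These five knobs map directly onto code-level parameters. The candidate values below are illustrative assumptions for a tuning grid, not the settings chosen in the linked write-up:

```python
# Hypothetical tuning grid for the hyperparameters listed above
search_space = {
    'batch_size':   [32, 64, 128],
    'num_epochs':   [30, 50, 100],
    'lr':           [1e-3, 1e-4, 3e-5],   # Adam learning rate
    'weight_decay': [0.0, 1e-4, 1e-2],    # Adam weight decay
    'n_critic':     [1, 2, 5],            # D steps per G step
}
```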

2.4.2 Model Improvements

For details, see: [Paper Reproduction] CGAN Handwritten-Digit Generation — Model Improvements

1. Hyperparameter optimization
2. Layer-wise normalization
3. Loss-function improvements
4. Activation-function choice
5. Optimizer improvements
6. Distribution of the noise z
7. Other ideas

2.4.3 Model Testing

The final tested configuration combines Batch Normalization, the PReLU activation function, and the AdamW optimizer.
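
A sketch of what that configuration could look like in code; the layer sizes are illustrative assumptions, not the article's exact architecture:

```python
import torch
import torch.nn as nn

# BatchNorm1d after each hidden linear layer, PReLU activations, AdamW optimizer
improved_generator = nn.Sequential(
    nn.Linear(110, 256),       # 100-dim noise + 10-dim one-hot label
    nn.BatchNorm1d(256),
    nn.PReLU(),
    nn.Linear(256, 512),
    nn.BatchNorm1d(512),
    nn.PReLU(),
    nn.Linear(512, 784),
    nn.Sigmoid(),
)
optimizer = torch.optim.AdamW(improved_generator.parameters(),
                              lr=1e-4, weight_decay=1e-2)
```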
