GAN的基本原理
GAN的基本原理其实非常简单,它包含两个网络,G网络(Generator)和D网络(Discriminator)。G网络的目标是尽量生成真实的图片去欺骗判别网络D,D网络的目标是尽量把G网络生成的图片和真实的图片分别开来。
最理想的结束状态是,G网络可以生成足以“以假乱真”的图片,而D网络,它难以判定G生成的图片究竟是不是真实的。
图片来源【1】
先看以下枯燥的数学语言描述下GAN的核心原理:
上述公式中:x表示真实图片,z表示输入G网络的随机噪声,而G(z)表示G网络生成的图片;D(x)表示D网络判断真实图片是否真实的概率(因为x就是真实的,所以对于D来说,这个值越接近1越好)。
D(G(z))是D网络判断G生成的图片是否真实的概率,G应该希望自己生成的图片“越接近真实越好”,也就是说,G希望D(G(z))尽可能得大,这时V(D, G)会变小,因此对于G来说就是求最小的G(min_G)。
D的能力越强,D(x)应该越大,D(G(x))应该越小。这时V(D,G)会变大。因此,对于D来说是求最大D(max_D)。
下面实现一个DCGAN生成二次元图像的例子。先在我的渣渣笔记本上的训练效果。
笔记本训练比较慢,所以只用了1000张图片作为训练输入数据,训练了50个epoch,不过可以看出已经有初步的效果了。
GAN二次元头像数据集
Tensorflow的官网Demo中使用的MNIST数据集,这里我们换一个数据集kaggle——Anime Faces,里面有21551张动漫头像的图片。数据链接如下:
kaggle——Anime Faces
部分图片如下:
加载数据集(Dataset)
引入必要的Python头文件。
代码语言:javascript复制import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
加载数据,构造Tensorflow数据集,同时将图片的像素数值缩放到[-1, 1]之间。
代码语言:javascript复制def load_data():
all_images = []
# max_count = 1000
# count = 0
for dirname, _, filenames in os.walk('/GAN/archive/data/'):
for filename in filenames:
image = imageio.imread(os.path.join(dirname, filename))
all_images.append(image)
# count = count 1
# if count > max_count:
# break
all_images = np.array(all_images)
all_images = (all_images - 127.5) / 127.5
return all_images
train_images = load_data()
BUFFER_SIZE = 3000
BATCH_SIZE = 10
# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
先看下数据集(Dataset)中数据。
定义模型(Model)
Generator Model
Generator使用tf.keras.layers.Conv2DTran spose进行上采样(upsampling)将随机噪声(Random Noise)生成64x64x3的图像数据。
代码语言:javascript复制def make_generator_model():
model = tf.keras.Sequential()
model.add(layers.Dense(4 * 4 * 1024, use_bias=False, input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Reshape((4, 4, 1024)))
assert model.output_shape == (None, 4, 4, 1024) # 注意:batch size 没有限制
model.add(layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 8, 8, 512)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 16, 16, 256)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 32, 32, 128)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
assert model.output_shape == (None, 64, 64, 3)
return model
使用未经训练的Generator模型生成一张图像看看效果:
代码语言:javascript复制generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
pred_img = (generated_image[0, :, :, :] 1.0) / 2.0
plt.imshow(pred_img)
plt.axis('off')
plt.show()
未训练的Generator生成图像如下:
Discriminator Model
Discriminator是一个基于CNN的分类器(classifier)。
代码语言:javascript复制def make_discriminator_model():
model = tf.keras.Sequential()
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same',
input_shape=[64, 64, 3]))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Conv2D(512, (5, 5), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Conv2D(1024, (5, 5), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Flatten())
model.add(layers.Dense(1))
return model
我们使用未经过训练的discriminator来判断生成的图像真假。
代码语言:javascript复制discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
输出:
代码语言:javascript复制tf.Tensor([[0.00011664]], shape=(1, 1), dtype=float32)
定义Loss函数和优化器
代码语言:javascript复制# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
Discriminator loss
Discriminator Loss用来衡量discriminator能够区分图像真假的能力。
代码语言:javascript复制def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss fake_loss
return total_loss
Generator Loss
Generator Loss用来衡量Generator欺骗discriminator的能力。
代码语言:javascript复制def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
Discriminator和Generator由于两个不同独立网络,所以定义了两个不同的Optimizer。
代码语言:javascript复制generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
定义Train Loop
代码语言:javascript复制EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16
seed = tf.random.normal([num_examples_to_generate, noise_dim])
Training Loop中Generator将随机数(random seed)生成图像,Discriminator用来区分生成图像的真假。
代码语言:javascript复制# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
代码语言:javascript复制def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
train_step(image_batch)
# Produce images for the GIF as we go
display.clear_output(wait=True)
generate_and_save_images(generator,
epoch 1,
seed)
# Save the model every 15 epochs
if (epoch 1) % 15 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print ('Time for epoch {} is {} sec'.format(epoch 1, time.time()-start))
# Generate after the final epoch
display.clear_output(wait=True)
generate_and_save_images(generator,
epochs,
seed)
保存和生成图像
代码语言:javascript复制def generate_and_save_images(model, epoch, test_input):
# Notice `training` is set to False.
# This is so all layers run in inference mode (batchnorm).
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(4,4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i 1)
plt.imshow(predictions[i, :, :, 0] * 127.5 127.5, cmap='gray')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
训练模型
调用train()方法同时训练generator和discriminator。训练开始时,generator生成的图片看起来像是随机噪声,随着训练的进行,生成的图像越来越真实。
代码语言:javascript复制train(train_dataset, EPOCHS)
epoch 1
epoch 10
epoch 20
epoch 30
epoch 40
epoch 50
生成GIF图片
最后把训练过程中保存的图片合并起来,生成一副gif图片,这样可以直接的看GAN网络的训练过程。
代码语言:javascript复制anim_file = 'dcgan.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
filenames = glob.glob('image*.png')
filenames = sorted(filenames)
for filename in filenames:
image = imageio.imread(filename)
writer.append_data(image)
image = imageio.imread(filename)
writer.append_data(image)
效果如下: