本篇再详细讲一下如何用训练好的GAN网络生成和挑选图片。
本例中的数据集是Pretty Face,链接如下:
https://www.kaggle.com/datasets/yewtsing/pretty-face
数据集中共有3千多张图片,绝大多数是亚洲美女的人脸图片,也有一些人头发有点秃或者脸上有色斑,还有极少数男人面孔(为了便于训练,建议删掉)。
假设生成网络和判别网络已经训练好,并已保存。
# Persist only the learned parameters (state_dict) of the trained generator
# and discriminator; the network classes must be re-created before loading.
torch.save(netG.state_dict(), "Pretty_face_128x128_netG.pth")
torch.save(netD.state_dict(), "Pretty_face_128x128_netD.pth")
首先是复制一下网络参数常数和网络结构,也可以从训练程序import。
代码语言:javascript复制
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.utils as vutils
class Generator(nn.Module):
    """DCGAN generator that turns a latent vector into a 128x128 image.

    The network is a stack of transposed convolutions: the first one expands
    the 1x1 latent input to 4x4, and each of the following five doubles the
    spatial size (4 -> 8 -> 16 -> 32 -> 64 -> 128). The final ``Tanh`` keeps
    the output in the range [-1, 1].

    Args:
        latent_dim: Number of channels of the latent input ``Z``. ``None``
            (the default) uses the module-level ``nz``.
        base_width: Base number of generator feature maps. ``None`` (the
            default) uses the module-level ``ngf``.
        image_channels: Number of channels of the generated image. ``None``
            (the default) uses the module-level ``nc``.
    """

    def __init__(self, latent_dim=None, base_width=None, image_channels=None):
        super(Generator, self).__init__()
        # Fall back to the script-wide hyper-parameters only when no explicit
        # value is given, so ``Generator()`` behaves exactly as before.
        latent_dim = nz if latent_dim is None else latent_dim
        base_width = ngf if base_width is None else base_width
        image_channels = nc if image_channels is None else image_channels

        # Output size of ConvTranspose2d (dilation == 1):
        #   Hout = (Hin - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0]
        #   Wout = (Win - 1) * stride[1] - 2 * padding[1] + kernel_size[1] + output_padding[1]
        self.main = nn.Sequential(
            # input is Z: (latent_dim) x 1 x 1, i.e. Hin == Win == 1
            nn.ConvTranspose2d(in_channels=latent_dim, out_channels=base_width * 8,
                               kernel_size=(4, 4), stride=(1, 1), padding=(0, 0),
                               output_padding=(0, 0), bias=False),
            nn.BatchNorm2d(base_width * 8),
            nn.ReLU(True),
            # state size. (base_width*8) x 4 x 4
            nn.ConvTranspose2d(in_channels=base_width * 8, out_channels=base_width * 4,
                               kernel_size=(4, 4), stride=(2, 2), padding=(1, 1),
                               output_padding=(0, 0), bias=False),
            nn.BatchNorm2d(base_width * 4),
            nn.ReLU(True),
            # state size. (base_width*4) x 8 x 8
            nn.ConvTranspose2d(base_width * 4, base_width * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_width * 2),
            nn.ReLU(True),
            # state size. (base_width*2) x 16 x 16
            nn.ConvTranspose2d(base_width * 2, base_width, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_width),
            nn.ReLU(True),
            # state size. (base_width) x 32 x 32
            nn.ConvTranspose2d(base_width, base_width, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_width),
            nn.ReLU(True),
            # state size. (base_width) x 64 x 64
            nn.ConvTranspose2d(base_width, image_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (image_channels) x 128 x 128
        )

    def forward(self, input_):
        """Return the generated images for a batch of latent vectors."""
        return self.main(input_)
class Discriminator(nn.Module):
    """DCGAN discriminator that scores a 128x128 image as real or fake.

    Five stride-2 convolutions halve the spatial size each time
    (128 -> 64 -> 32 -> 16 -> 8 -> 4); a final 4x4 convolution reduces the
    result to a single value, which ``Sigmoid`` maps into the range 0..1.

    Args:
        image_channels: Number of channels of the input image. ``None`` (the
            default) uses the module-level ``nc``.
        base_width: Base number of discriminator feature maps. ``None`` (the
            default) uses the module-level ``ndf``.
    """

    def __init__(self, image_channels=None, base_width=None):
        super(Discriminator, self).__init__()
        # Fall back to the script-wide hyper-parameters only when no explicit
        # value is given, so ``Discriminator()`` behaves exactly as before.
        image_channels = nc if image_channels is None else image_channels
        base_width = ndf if base_width is None else base_width

        self.main = nn.Sequential(
            # input is (image_channels) x 128 x 128
            nn.Conv2d(image_channels, base_width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (base_width) x 64 x 64
            nn.Conv2d(base_width, base_width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (base_width) x 32 x 32
            nn.Conv2d(base_width, base_width * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_width * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (base_width*2) x 16 x 16
            nn.Conv2d(base_width * 2, base_width * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_width * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (base_width*4) x 8 x 8
            nn.Conv2d(base_width * 4, base_width * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_width * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (base_width*8) x 4 x 4
            nn.Conv2d(base_width * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            # state size. 1 x 1 x 1
        )

    def forward(self, input_):
        """Return the realness score for each image in the batch."""
        return self.main(input_)
# Label values for real / fake samples, copied from the training script;
# nothing in the generation code shown here reads them.
real_label = 1.
fake_label = 0.

# Number of GPUs available. Use 0 for CPU mode.
# NOTE(review): the device selection below does not consult this value.
ngpu = 1

# Decide which device we want to run on: the first CUDA device when one is
# available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Number of channels of the latent vector Z fed to the generator.
nz = 100
# Size of feature maps in generator.
ngf = 64
# Size of feature maps in discriminator.
ndf = 64
# Number of channels of the generated / judged images.
nc = 3
接着导入训练好的生成网络和判别网络
# Re-create both networks on the selected device and restore their trained
# parameters. map_location=device remaps the stored tensors, so a checkpoint
# saved on a GPU also loads on a CPU-only machine.
# NOTE(review): neither network is switched to eval() here, matching the
# original script — confirm that running BatchNorm in training mode is intended.
netG = Generator().to(device)
netG.load_state_dict(torch.load("Pretty_face_128x128_netG.pth", map_location=device))
netD = Discriminator().to(device)
netD.load_state_dict(torch.load("Pretty_face_128x128_netD.pth", map_location=device))
设置随机种子,这步不是必须的。
# Set random seed for reproducibility.
# A fresh seed is drawn on every run so each run gives new results; replace
# the right-hand side with a fixed integer to reproduce a particular run.
manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
代码语言:javascript复制
最后循环生成伪造的图片,并利用判别网络挑选稍微真一点的图片保存。
# Forge images one at a time and keep only those the discriminator rates as
# reasonably realistic.
NUM_SAMPLES = 1000        # forge only 1000 images in total
SCORE_THRESHOLD = 0.2     # minimum discriminator score for an image to be kept
OUTPUT_DIR = "Pretty_128x128"

# plt.imsave does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Pure inference: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    # The counter tracks generated (not saved) images, so the loop always
    # terminates even if no image ever passes the threshold.
    for i in range(1, NUM_SAMPLES + 1):
        # The generator turns a 1 x nz x 1 x 1 latent tensor into one forged image.
        noise = torch.randn(1, nz, 1, 1, device=device)
        fake = netG(noise)

        # The discriminator scores the forged image (0.0 ~ 1.0).
        score = netD(fake).item()
        if score <= SCORE_THRESHOLD:
            continue

        # Only images with a slightly higher score are written to disk.
        print(i, score)
        # Copy the tensor to the CPU; normalize=True rescales it for saving.
        grid = vutils.make_grid(fake.cpu(), padding=2, normalize=True)
        # Channels-first -> channels-last, the layout matplotlib expects.
        img = np.transpose(grid.numpy(), (1, 2, 0))
        # os.path.join avoids the "\f" form-feed escape of a hand-written path.
        plt.imsave(os.path.join(OUTPUT_DIR, f"fake{i}.png"), img)
程序生成和挑选出来的伪造图如下,大部分还是很像美女图片。
有不少伪造的图片脸上有色斑,甚至有秃头,这些是从训练集带来的特征,如果不喜欢这些特征可以在训练之前从数据集中删除相应的图片。