用训练好的生成对抗网络生成和挑选图片

2022-11-18 13:59:33 浏览数 (3)

本篇再详细讲一下如何用训练好的GAN网络生成和挑选图片。

本例中的数据集是Pretty Face,链接如下:

https://www.kaggle.com/datasets/yewtsing/pretty-face

数据集中共有3千多张图片,绝大多数是亚洲美女的人脸图片,也有一些人头发有点秃或者脸上有色斑,还有极少数男性面孔(为了便于训练,建议将其删掉)。

假设生成网络和判别网络已经训练好,并已保存。

代码语言:javascript复制
# Save only the generator's state_dict (not the whole module object) to disk.
torch.save(netG.state_dict(), "Pretty_face_128x128_netG.pth")

# Save the discriminator's state_dict the same way.
torch.save(netD.state_dict(), "Pretty_face_128x128_netD.pth")

首先把网络的超参数常量和网络结构复制过来,也可以直接从训练程序中 import。

代码语言:javascript复制
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.utils as vutils



class Generator(nn.Module):
    """Generator: upsamples a latent tensor into an image with transposed convolutions.

    The layer sizes are taken from the module-level constants ``nz`` (latent
    channels), ``ngf`` (base feature-map width) and ``nc`` (output channels)
    at instantiation time.
    """

    def __init__(self):
        super().__init__()
        # With the default dilation of 1, ConvTranspose2d maps each spatial size as
        #   Hout = (Hin - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0]
        #   Wout = (Win - 1) * stride[1] - 2 * padding[1] + kernel_size[1] + output_padding[1]
        # so (kernel 4, stride 1, padding 0) turns 1 x 1 into 4 x 4, and every
        # (kernel 4, stride 2, padding 1) layer doubles the height and width.
        layers = [
            # input is Z with Hin == Win == 1 -> state size (ngf*8) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
        ]

        # (in_channels, out_channels) of the size-doubling stages:
        # 4 x 4 -> 8 x 8 -> 16 x 16 -> 32 x 32 -> 64 x 64
        doubling_stages = (
            (ngf * 8, ngf * 4),
            (ngf * 4, ngf * 2),
            (ngf * 2, ngf),
            (ngf, ngf),
        )
        for c_in, c_out in doubling_stages:
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Output stage: (ngf) x 64 x 64 -> (nc) x 128 x 128, squashed by tanh.
        layers += [
            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]

        # The attribute name and the layer order define the state_dict keys
        # ("main.0.weight", ...), so both must stay as they are.
        self.main = nn.Sequential(*layers)

    def forward(self, input_):
        """Run ``input_`` through the whole layer stack and return the result."""
        return self.main(input_)


class Discriminator(nn.Module):
    """Discriminator: downsamples an image with strided convolutions to one sigmoid score.

    The layer sizes are taken from the module-level constants ``nc`` (input
    image channels) and ``ndf`` (base feature-map width) at instantiation time.
    """

    def __init__(self):
        super().__init__()

        def halve(c_in, c_out):
            # kernel 4, stride 2, padding 1: halves the height and width
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)

        # input is (nc) x 128 x 128; the first two stages have no batch norm
        layers = [
            halve(nc, ndf),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 64 x 64
            halve(ndf, ndf),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
        ]

        # (in_channels, out_channels) of the batch-normalised stages:
        # 32 x 32 -> 16 x 16 -> 8 x 8 -> 4 x 4
        for c_in, c_out in ((ndf, ndf * 2), (ndf * 2, ndf * 4), (ndf * 4, ndf * 8)):
            layers += [
                halve(c_in, c_out),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # state size. (ndf*8) x 4 x 4 -> a single value per image, mapped into (0, 1)
        layers += [
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]

        # The attribute name and the layer order define the state_dict keys
        # ("main.0.weight", ...), so both must stay as they are.
        self.main = nn.Sequential(*layers)

    def forward(self, input_):
        """Run ``input_`` through the whole layer stack and return the result."""
        return self.main(input_)



# Label values for real and generated samples (not referenced by the code
# shown in this file; presumably kept for parity with the training script).
real_label = 1.
fake_label = 0.
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Number of channels of the latent tensor fed to the generator.
nz = 100
# Base feature-map width of the generator.
ngf = 64
# Base feature-map width of the discriminator.
ndf = 64
# Number of image channels produced by the generator and read by the
# discriminator (presumably RGB).
nc = 3

接着导入训练好的生成网络和判别网络

代码语言:javascript复制
# Rebuild both networks on the chosen device and restore the trained weights.
# map_location=device lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine as well.
# NOTE: neither network is switched to eval() here, so BatchNorm layers keep
# using per-batch statistics during generation.
netG = Generator().to(device)
netG.load_state_dict(torch.load("Pretty_face_128x128_netG.pth", map_location=device))

netD = Discriminator().to(device)
netD.load_state_dict(torch.load("Pretty_face_128x128_netD.pth", map_location=device))

设置随机种子,这步不是必须的。

代码语言:javascript复制
# Set random seed for reproducibility.
# A new seed is drawn on every run; it is printed so that a run can be
# repeated by replacing the line below with that fixed integer.
manualSeed = random.randint(1, 10000)  # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
代码语言:javascript复制

最后循环生成伪造的图片,并利用判别网络挑选稍微真一点的图片保存。

代码语言:javascript复制
# Generation settings.
OUTPUT_DIR = "Pretty_128x128"  # directory that receives the selected images
NUM_IMAGES = 1000              # stop once this many images have been saved
SCORE_THRESHOLD = 0.2          # keep an image only if netD scores it above this

# plt.imsave does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

i = 0  # number of images saved so far; also used in the file name

# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    while i < NUM_IMAGES:
        # Sample a fresh 1 x nz x 1 x 1 latent tensor and generate one fake image.
        noise = torch.randn(1, nz, 1, 1, device=device)
        fake0 = netG(noise)

        # The discriminator scores the fake image (0.0 ~ 1.0).
        netD_output = netD(fake0).item()

        # Only images with a somewhat higher score are written to disk.
        if netD_output > SCORE_THRESHOLD:
            print(i, netD_output)
            fake = fake0.cpu()  # copy the tensor from the GPU to the CPU
            # normalize=True rescales the image to [0, 1] for saving.
            img0 = vutils.make_grid(fake, padding=2, normalize=True)
            # channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib
            img = np.transpose(img0.numpy(), (1, 2, 0))
            # os.path.join avoids the "\f" form-feed escape that a literal
            # backslash path would introduce.
            plt.imsave(os.path.join(OUTPUT_DIR, f"fake{i}.png"), img)
            i += 1

程序生成和挑选出来的伪造图如下,大部分还是很像美女图片

有不少伪造的图片脸上有色斑,甚至有秃头,这些是从训练集带来的特征,如果不喜欢这些特征可以在训练之前从数据集中删除相应的图片。

0 人点赞