TensorFlow 2 和 Keras 高级深度学习:6~10

2023-04-26 19:24:26 浏览数 (3)

原文:Advanced Deep Learning with TensorFlow 2 and Keras 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自【ApacheCN 深度学习 译文集】,采用译后编辑(MTPE)流程来尽可能提升效率。

六、纠缠表示 GAN

正如我们已经探索的那样,GAN 可以通过学习数据分布来产生有意义的输出。 但是,无法控制所生成输出的属性。 GAN 的一些变体,例如条件 GANCGAN)和辅助分类器 GANACGAN),如前两章所讨论的,都可以训练生成器,该生成器可以合成特定的输出。 例如,CGAN 和 ACGAN 都可以诱导生成器生成特定的 MNIST 数字。 这可以通过同时使用 100 维噪声代码和相应的一号热标签作为输入来实现。 但是,除了单热标签外,我们没有其他方法可以控制生成的输出的属性。

有关 CGAN 和 ACGAN 的评论,请参阅“第 4 章”,“生成对抗网络(GANs)”和“第 5 章”,“改进的 GANs”。

在本章中,我们将介绍使我们能够修改生成器输出的 GAN 的变体。 在 MNIST 数据集的上下文中,除了要生成的数字外,我们可能会发现我们想要控制书写样式。 这可能涉及所需数字的倾斜度或宽度。 换句话说,GAN 也可以学习纠缠的潜在代码或表示形式,我们可以使用它们来改变生成器输出的属性。 解开的代码或表示形式是张量,可以在不影响其他属性的情况下更改输出数据的特定特征或属性。

在本章的第一部分中,我们将讨论《InfoGAN:通过最大化生成对抗网络的信息进行可解释的表示学习》[1],这是 GAN 的扩展。 InfoGAN 通过最大化输入代码和输出观察值之间的互信息来以无监督的方式学习解缠结的表示形式。 在 MNIST 数据集上,InfoGAN 从数字数据集中解开了写作风格。

在本章的以下部分中,我们还将讨论《栈式生成对抗网络或 StackedGAN》[2],这是 GAN 的另一种扩展。

StackedGAN 使用预训练的编码器或分类器,以帮助解开潜在代码。 StackedGAN 可以看作是一堆模型,每个模型都由编码器和 GAN 组成。 通过使用相应编码器的输入和输出数据,以对抗性方式训练每个 GAN。

总之,本章的目的是介绍:

  • 纠缠表示的概念
  • InfoGAN 和 StackedGAN 的原理
  • 使用tf.keras实现 InfoGAN 和 StackedGAN

让我们从讨论纠缠的表示开始。

1. 纠缠表示

最初的 GAN 能够产生有意义的输出,但是缺点是它的属性无法控制。 例如,如果我们训练 GAN 来学习名人面孔的分布,则生成器将产生名人形象的新图像。 但是,没有任何方法可以影响生成器有关所需脸部的特定属性。 例如,我们无法向生成器询问女性名人的脸,该女性名人是黑发,白皙的肤色,棕色的眼睛,微笑着。 这样做的根本原因是因为我们使用的 100 维噪声代码纠缠了生成器输出的所有显着属性。 我们可以回想一下,在tf.keras中,100-dim代码是由均匀噪声分布的随机采样生成的:

代码语言:javascript复制
 # generate fake images from noise using generator 
        # generate noise using uniform distribution
        noise = np.random.uniform(-1.0,
                                  1.0,
                                  size=[batch_size, latent_size])
        # generate fake images
        fake_images = generator.predict(noise) 

如果我们能够修改原始 GAN,以便将表示形式分为纠缠的和解缠的可解释的潜在代码向量,则我们将能够告诉生成器合成什么。

“图 6.1.1”向我们展示了一个带纠缠代码的 GAN,以及它的纠缠和解缠表示的混合形式。 在假设的名人脸生成的情况下,使用解开的代码,我们可以指出我们希望生成的脸的性别,发型,面部表情,肤色和肤色。 仍然需要n–dim纠缠代码来表示我们尚未纠缠的所有其他面部属性,例如面部形状,面部毛发,眼镜等,仅是三个示例。 纠缠和解纠缠的代码向量的连接用作生成器的新输入。 级联代码的总维不一定是 100:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hwnvaGo2-1681704311642)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_01.png)]

图 6.1.1:带有纠缠码的 GAN 及其随纠缠码和解缠码的变化。 此示例在名人脸生成的背景下显示

查看上图中的,似乎可以以与原始 GAN 相同的方式优化具有解缠表示的 GAN。 这是因为生成器的输出可以表示为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yrVASRlJ-1681704311643)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_001.png)] (Equation 6.1.1)

代码z = (z, c)包含两个元素:

  • 类似于 GANsz或噪声向量的不可压缩纠缠噪声代码。
  • 潜在代码c[1]c[2],…,c[L], 代表数据分配的可解译的纠缠码。 所有潜在代码共同表示为c

为简单起见,假定所有潜在代码都是独立的:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mIZLPVOn-1681704311644)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_002.png)] (Equation 6.1.2)

生成器函数x = g(z, c) = g(z)带有不可压缩的噪声代码和潜在代码。 从生成器的角度来看,优化z = (z, c)与优化z相同。

当提出解决方案时,生成器网络将仅忽略解纠结代码所施加的约束。

生成器学习分布p_g(x | c) = p_g(x)。 这实际上将打乱分散表示的目的。

InfoGAN 的关键思想是强制 GAN 不要忽略潜在代码c。 这是通过最大化cg(z, c)之间的相互信息来完成的。 在下一节中,我们将公式化 InfoGAN 的损失函数。

InfoGAN

为了加强对代码的纠缠,InfoGAN 提出了一种针对原始损失函数的正则化函数,该函数可最大化潜在代码cg(z, c)之间的互信息:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5CZVWnIT-1681704311644)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_007.png)] (Equation 6.1.3)

正则化器在生成用于合成伪图像的函数时,会强制生成器考虑潜在代码。 在信息论领域,潜码cg(z, c)之间的互信息定义为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7SifPWKe-1681704311645)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_009.png)] (Equation 6.1.4)

其中H(c)是潜码c的熵,H(c | g(z | c))是观察生成器的输出后c的条件熵, g(z, c)。 熵是对随机变量或事件的不确定性的度量。 例如,在东方升起之类的信息具有较低的熵,而在彩票中赢得大奖具有较高的熵。 可以在“第 13 章”,“使用互信息的无监督学习”中找到有关互信息的更详细讨论。

在“公式 6.1.4”中,最大化互信息意味着在观察生成的输出时,将H(c | g(z | c))最小化或减小潜码中的不确定性。 这是有道理的,因为例如在 MNIST 数据集中,如果 GAN 看到生成器 8 看到了数字 8,则生成器对合成数字 8 变得更有信心。

但是,H(c | g(z | c))很难估计,因为它需要后验P(c | g(z | c)) = P(c | x)的知识,这是我们无法获得的。 为简单起见,我们将使用常规字母x表示数据分布。

解决方法是通过使用辅助分布Q(c | x)估计后验来估计互信息的下界。 InfoGAN 估计相互信息的下限为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ihjOQGnw-1681704311646)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_016.png)] (Equation 6.1.5)

在 InfoGAN 中,H(c)被假定为常数。 因此,使相互信息最大化是使期望最大化的问题。 生成器必须确信已生成具有特定属性的输出。 我们应注意,此期望的最大值为零。 因此,相互信息的下限的最大值为H(c)。 在 InfoGAN 中,离散隐码的Q(c | x)可以由softmax非线性表示。 期望是tf.keras中的负categorical_crossentropy损失。

对于一维连续代码,期望是cx的双整数。 这是由于期望从纠缠的代码分布和生成器分布中采样。 估计期望值的一种方法是通过假设样本是连续数据的良好度量。 因此,损失估计为c log Q(c | x)。 在“第 13 章”,“使用互信息的无监督学习”中,我们将提供对互信息的更精确估计。

为了完成 InfoGAN 的网络,我们应该有Q(c | x)的实现。 为了简单起见,网络 Q 是一个附加到判别器第二到最后一层的辅助网络。 因此,这对原始 GAN 的训练影响很小。

“图 6.1.2”显示了 InfoGAN 网络图:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pABb7kwi-1681704311646)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_02.png)]

图 6.1.2 网络图显示 InfoGAN 中的判别器和生成器训练

“表 6.1.1”显示了与 GAN 相比 InfoGAN 的损失函数:

网络

损失函数

编号

GAN

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sPbLJ0rI-1681704311646)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_019.png)]

4.1.1

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bXghkCF8-1681704311646)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_020.png)]

4.1.5

InfoGAN

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HOaBEf1W-1681704311647)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_021.png)]

6.1.1

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SIHd6m2H-1681704311647)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_022.png)]

6.1.2

对于连续代码,InfoGAN 建议使用λ < 1的值。 在我们的示例中,我们设置λ = 0.5。 对于离散代码,InfoGAN 建议使用λ = 1。

表 6.1.1:GAN 和 InfoGAN 的损失函数之间的比较

InfoGAN 的损失函数与 GAN 的区别是附加项-λI(c; g(z, c)),其中λ是一个小的正常数。 最小化 InfoGAN 的损失函数可以将原始 GAN 的损失最小化,并将互信息最大化I(c; g(z, c))

如果将其应用于 MNIST 数据集,InfoGAN 可以学习解开的离散码和连续码,以修改生成器输出属性。 例如,像 CGAN 和 ACGAN 一样,将使用10-dim一键标签形式的离散代码来指定要生成的数字。 但是,我们可以添加两个连续的代码,一个用于控制书写样式的角度,另一个用于调整笔划宽度。“图 6.1.3”显示了 InfoGAN 中 MNIST 数字的代码。 我们保留较小尺寸的纠缠代码以表示所有其他属性:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9tVQ87he-1681704311647)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_03.png)]

图 6.1.3:MNIST 数据集中 GAN 和 InfoGAN 的代码

在讨论了 InfoGAN 背后的一些概念之后,让我们看一下tf.keras中的 InfoGAN 实现。

在 Keras 中实现 InfoGAN

为了在 MNIST 数据集上实现 InfoGAN,需要对 ACGAN 的基本代码进行一些更改。 如“列表 6.1.1”中突出显示的那样,生成器将纠缠的(z噪声代码)和解纠结的代码(单标签和连续代码)连接起来作为输入:

代码语言:javascript复制
inputs = [inputs, labels]   codes 

generatordiscriminator的构建器函数也在lib文件夹的gan.py中实现。

完整的代码可在 GitHub 上获得。

“列表 6.1.1”:infogan-mnist-6.1.1.py

突出显示了特定于 InfoGAN 的行:

代码语言:javascript复制
def generator(inputs,
              image_size,
              activation='sigmoid',
              labels=None,
              codes=None):
    """Build a Generator Model 
代码语言:javascript复制
 Stack of BN-ReLU-Conv2DTranpose to generate fake images.
    Output activation is sigmoid instead of tanh in [1].
    Sigmoid converges easily. 
代码语言:javascript复制
 Arguments:
        inputs (Layer): Input layer of the generator (the z-vector)
        image_size (int): Target size of one side 
            (assuming square image)
        activation (string): Name of output activation layer
        labels (tensor): Input labels
        codes (list): 2-dim disentangled codes for InfoGAN 
代码语言:javascript复制
 Returns:
        Model: Generator Model
    """
    image_resize = image_size // 4
    # network parameters
    kernel_size = 5
    layer_filters = [128, 64, 32, 1] 
代码语言:javascript复制
 if labels is not None:
        if codes is None:
            # ACGAN labels
            # concatenate z noise vector and one-hot labels
            inputs = [inputs, labels]
        else:
            # infoGAN codes
            # concatenate z noise vector, 
            # one-hot labels and codes 1 & 2
            inputs = [inputs, labels]   codes
        x = concatenate(inputs, axis=1)
    elif codes is not None:
        # generator 0 of StackedGAN
        inputs = [inputs, codes]
        x = concatenate(inputs, axis=1)
    else:
        # default input is just 100-dim noise (z-code)
        x = inputs 
代码语言:javascript复制
 x = Dense(image_resize * image_resize * layer_filters[0])(x)
    x = Reshape((image_resize, image_resize, layer_filters[0]))(x) 
代码语言:javascript复制
 for filters in layer_filters:
        # first two convolution layers use strides = 2
        # the last two use strides = 1
        if filters > layer_filters[-2]:
            strides = 2
        else:
            strides = 1
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            strides=strides,
                            padding='same')(x) 
代码语言:javascript复制
 if activation is not None:
        x = Activation(activation)(x) 
代码语言:javascript复制
 # generator output is the synthesized image x
    return Model(inputs, x, name='generator') 

“列表 6.1.2”显示了具有原始默认 GAN 输出的判别器和 Q 网络。 高亮显示了三个辅助输出,它们对应于离散代码(用于单热标签)softmax预测的和给定输入 MNIST 数字图像的连续代码概率。

“列表 6.1.2”:infogan-mnist-6.1.1.py

突出显示了特定于 InfoGAN 的行:

代码语言:javascript复制
def discriminator(inputs,
                  activation='sigmoid',
                  num_labels=None,
                  num_codes=None):
    """Build a Discriminator Model 
代码语言:javascript复制
 Stack of LeakyReLU-Conv2D to discriminate real from fake
    The network does not converge with BN so it is not used here
    unlike in [1] 
代码语言:javascript复制
 Arguments:
        inputs (Layer): Input layer of the discriminator (the image)
        activation (string): Name of output activation layer
        num_labels (int): Dimension of one-hot labels for ACGAN & InfoGAN
        num_codes (int): num_codes-dim Q network as output 
                    if StackedGAN or 2 Q networks if InfoGAN 
代码语言:javascript复制
 Returns:
        Model: Discriminator Model
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256] 
代码语言:javascript复制
 x = inputs
    for filters in layer_filters:
        # first 3 convolution layers use strides = 2
        # last one uses strides = 1
        if filters == layer_filters[-1]:
            strides = 1
        else:
            strides = 2
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding='same')(x) 
代码语言:javascript复制
 x = Flatten()(x)
    # default output is probability that the image is real
    outputs = Dense(1)(x)
    if activation is not None:
        print(activation)
        outputs = Activation(activation)(outputs) 
代码语言:javascript复制
 if num_labels:
        # ACGAN and InfoGAN have 2nd output
        # 2nd output is 10-dim one-hot vector of label
        layer = Dense(layer_filters[-2])(x)
        labels = Dense(num_labels)(layer)
        labels = Activation('softmax', name='label')(labels)
        if num_codes is None:
            outputs = [outputs, labels]
        else:
            # InfoGAN have 3rd and 4th outputs
            # 3rd output is 1-dim continous Q of 1st c given x
            code1 = Dense(1)(layer)
            code1 = Activation('sigmoid', name='code1')(code1)
            # 4th output is 1-dim continuous Q of 2nd c given x
            code2 = Dense(1)(layer)
            code2 = Activation('sigmoid', name='code2')(code2) 
代码语言:javascript复制
 outputs = [outputs, labels, code1, code2]
    elif num_codes is not None:
        # StackedGAN Q0 output
        # z0_recon is reconstruction of z0 normal distribution
        z0_recon =  Dense(num_codes)(x)
        z0_recon = Activation('tanh', name='z0')(z0_recon)
        outputs = [outputs, z0_recon] 
代码语言:javascript复制
 return Model(inputs, outputs, name='discriminator') 

“图 6.1.4”显示了tf.keras中的 InfoGAN 模型:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ccSWJ7KB-1681704311648)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_04.png)]

图 6.1.4:InfoGAN Keras 模型

建立判别器和对抗模型还需要进行许多更改。 更改取决于所使用的损失函数。 原始的判别器损失函数binary_crossentropy,用于离散码的categorical_crossentropy和每个连续码的mi_loss函数构成了整体损失函数。 除mi_loss函数的权重为 0.5(对应于连续代码的λ = 0.5)外,每个损失函数的权重均为 1.0。

“列表 6.1.3”突出显示了所做的更改。 但是,我们应该注意,通过使用构造器函数,判别器被实例化为:

代码语言:javascript复制
 # call discriminator builder with 4 outputs:
    # source, label, and 2 codes
    discriminator = gan.discriminator(inputs,
                                      num_labels=num_labels,
                                      num_codes=2) 

生成器通过以下方式创建:

代码语言:javascript复制
 # call generator with inputs, 
    # labels and codes as total inputs to generator
    generator = gan.generator(inputs,
                              image_size,
                              labels=labels,
                              codes=[code1, code2]) 

“列表 6.1.3”:infogan-mnist-6.1.1.py

以下代码演示了互信息损失函数以及建立和训练 InfoGAN 判别器和对抗网络的过程:

代码语言:javascript复制
def mi_loss(c, q_of_c_given_x):
    """ Mutual information, Equation 5 in [2],
        assuming H(c) is constant
    """
    # mi_loss = -c * log(Q(c|x))
    return K.mean(-K.sum(K.log(q_of_c_given_x   K.epsilon()) * c,
                               axis=1)) 
代码语言:javascript复制
def build_and_train_models(latent_size=100):
    """Load the dataset, build InfoGAN discriminator,
    generator, and adversarial models.
    Call the InfoGAN train routine.
    """ 
代码语言:javascript复制
 # load MNIST dataset
    (x_train, y_train), (_, _) = mnist.load_data() 
代码语言:javascript复制
 # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255 
代码语言:javascript复制
 # train labels
    num_labels = len(np.unique(y_train))
    y_train = to_categorical(y_train) 
代码语言:javascript复制
 model_name = "infogan_mnist"
    # network parameters
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, )
    code_shape = (1, ) 
代码语言:javascript复制
 # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    # call discriminator builder with 4 outputs: 
    # source, label, and 2 codes
    discriminator = gan.discriminator(inputs,
                                      num_labels=num_labels,
                                      num_codes=2)
    # [1] uses Adam, but discriminator converges easily with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    # loss functions: 1) probability image is real
    # (binary crossentropy)
    # 2) categorical cross entropy image label,
    # 3) and 4) mutual information loss
    loss = ['binary_crossentropy',
            'categorical_crossentropy',
            mi_loss,
            mi_loss]
    # lamda or mi_loss weight is 0.5
    loss_weights = [1.0, 1.0, 0.5, 0.5]
    discriminator.compile(loss=loss,
                          loss_weights=loss_weights,
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary() 
代码语言:javascript复制
 # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    labels = Input(shape=label_shape, name='labels')
    code1 = Input(shape=code_shape, name="code1")
    code2 = Input(shape=code_shape, name="code2")
    # call generator with inputs, 
    # labels and codes as total inputs to generator
    generator = gan.generator(inputs,
                              image_size,
                              labels=labels,
                              codes=[code1, code2])
    generator.summary() 
代码语言:javascript复制
 # build adversarial model = generator   discriminator
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    discriminator.trainable = False
    # total inputs = noise code, labels, and codes
    inputs = [inputs, labels, code1, code2]
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    # same loss as discriminator
    adversarial.compile(loss=loss,
                        loss_weights=loss_weights,
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary() 
代码语言:javascript复制
 # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    data = (x_train, y_train)
    params = (batch_size,
              latent_size,
              train_steps,
              num_labels,
              model_name)
    train(models, data, params) 

就训练而言,我们可以看到 InfoGAN 与 ACGAN 类似,除了我们需要为连续代码提供cc是从正态分布中提取的,标准差为 0.5,平均值为 0.0。 我们将对伪数据使用随机采样的标签,对实际数据使用数据集的类标签来表示离散的潜在代码。

“列表 6.1.4”突出显示了对训练函数所做的更改。 与以前的所有 GAN 相似,判别器和生成器(通过对抗性训练)被交替训练。 在对抗训练期间,判别器的权重被冻结。

通过使用gan.py plot_images()函数,样本生成器输出图像每 500 个间隔步被保存一次。

“列表 6.1.4”:infogan-mnist-6.1.1.py

代码语言:javascript复制
def train(models, data, params):
    """Train the Discriminator and Adversarial networks 
代码语言:javascript复制
 Alternately train discriminator and adversarial networks by batch.
    Discriminator is trained first with real and fake images,
    corresponding one-hot labels and continuous codes.
    Adversarial is trained next with fake images pretending 
    to be real, corresponding one-hot labels and continous codes.
    Generate sample images per save_interval. 
代码语言:javascript复制
 # Arguments
        models (Models): Generator, Discriminator, Adversarial models
        data (tuple): x_train, y_train data
        params (tuple): Network parameters
    """
    # the GAN models
    generator, discriminator, adversarial = models
    # images and their one-hot labels
    x_train, y_train = data
    # network parameters
    batch_size, latent_size, train_steps, num_labels, model_name = 
            params
    # the generator image is saved every 500 steps
    save_interval = 500
    # noise vector to see how the generator output 
    # evolves during training
    noise_input = np.random.uniform(-1.0,
                                    1.0,
                                    size=[16, latent_size])
    # random class labels and codes
    noise_label = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    noise_code1 = np.random.normal(scale=0.5, size=[16, 1])
    noise_code2 = np.random.normal(scale=0.5, size=[16, 1])
    # number of elements in train dataset
    train_size = x_train.shape[0]
    print(model_name,
          "Labels for generated images: ",
          np.argmax(noise_label, axis=1)) 
代码语言:javascript复制
 for i in range(train_steps):
        # train the discriminator for 1 batch
        # 1 batch of real (label=1.0) and fake images (label=0.0)
        # randomly pick real images and 
        # corresponding labels from dataset 
        rand_indexes = np.random.randint(0,
                                         train_size,
                                         size=batch_size)
        real_images = x_train[rand_indexes]
        real_labels = y_train[rand_indexes]
        # random codes for real images
        real_code1 = np.random.normal(scale=0.5,
                                      size=[batch_size, 1])
        real_code2 = np.random.normal(scale=0.5,
                                      size=[batch_size, 1])
        # generate fake images, labels and codes
        noise = np.random.uniform(-1.0,
                                  1.0,
                                  size=[batch_size, latent_size])
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels,
                                                          batch_size)]
        fake_code1 = np.random.normal(scale=0.5,
                                      size=[batch_size, 1])
        fake_code2 = np.random.normal(scale=0.5,
                                      size=[batch_size, 1])
        inputs = [noise, fake_labels, fake_code1, fake_code2]
        fake_images = generator.predict(inputs)
        # real   fake images = 1 batch of train data
        x = np.concatenate((real_images, fake_images))
        labels = np.concatenate((real_labels, fake_labels))
        codes1 = np.concatenate((real_code1, fake_code1))
        codes2 = np.concatenate((real_code2, fake_code2))
        # label real and fake images
        # real images label is 1.0
        y = np.ones([2 * batch_size, 1])
        # fake images label is 0.0
        y[batch_size:, :] = 0
        # train discriminator network, 
        # log the loss and label accuracy
        outputs = [y, labels, codes1, codes2]
        # metrics = ['loss', 'activation_1_loss', 'label_loss',
        # 'code1_loss', 'code2_loss', 'activation_1_acc',
        # 'label_acc', 'code1_acc', 'code2_acc']
        # from discriminator.metrics_names
        metrics = discriminator.train_on_batch(x, outputs)
        fmt = "%d: [discriminator loss: %f, label_acc: %f]"
        log = fmt % (i, metrics[0], metrics[6])
        # train the adversarial network for 1 batch
        # 1 batch of fake images with label=1.0 and
        # corresponding one-hot label or class   random codes
        # since the discriminator weights are frozen 
        # in adversarial network only the generator is trained
        # generate fake images, labels and codes
        noise = np.random.uniform(-1.0,
                                  1.0,
                                  size=[batch_size, latent_size])
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels,
                                                          batch_size)]
        fake_code1 = np.random.normal(scale=0.5,
                                      size=[batch_size, 1])
        fake_code2 = np.random.normal(scale=0.5,
                                      size=[batch_size, 1])
        # label fake images as real
        y = np.ones([batch_size, 1])
        # train the adversarial network 
        # note that unlike in discriminator training,
        # we do not save the fake images in a variable
        # the fake images go to the discriminator
        # input of the adversarial for classification
        # log the loss and label accuracy
        inputs = [noise, fake_labels, fake_code1, fake_code2]
        outputs = [y, fake_labels, fake_code1, fake_code2]
        metrics  = adversarial.train_on_batch(inputs, outputs)
        fmt = "%s [adversarial loss: %f, label_acc: %f]"
        log = fmt % (log, metrics[0], metrics[6])
        print(log)
        if (i   1) % save_interval == 0:
            # plot generator images on a periodic basis
            gan.plot_images(generator,
                            noise_input=noise_input,
                            noise_label=noise_label,
                            noise_codes=[noise_code1, noise_code2],
                            show=False,
                            step=(i   1),
                            model_name=model_name)
    # save the model after training the generator
    # the trained generator can be reloaded for
    # future MNIST digit generation
    generator.save(model_name   ".h5") 

给定 InfoGAN 的tf.keras实现,下一个部分介绍具有解缠结属性的生成器 MNIST 输出。

InfoGAN 的生成器输出

与以前提供给我们的所有 GAN 相似,我们已经对 InfoGAN 进行了 40,000 步的训练。 训练完成后,我们可以运行 InfoGAN 生成器,以使用infogan_mnist.h5文件中保存的模型生成新输出。 进行以下验证:

通过将离散标签从 0 更改为 9,可生成数字 0 至 9。 两个连续代码都设置为零。 结果显示在“图 6.1.5”中。 我们可以看到,InfoGAN 离散代码可以控制生成器产生的数字:

代码语言:javascript复制
python3 infogan-mnist-6.1.1.py --generator=infogan_mnist.h5
--digit=0 --code1=0 --code2=0 

代码语言:javascript复制
python3 infogan-mnist-6.1.1.py --generator=infogan_mnist.h5
--digit=9 --code1=0 --code2=0 

在“图 6.1.5”中,我们可以看到 InfoGAN 生成的图像:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mFyg67LO-1681704311648)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_05.png)]

图 6.1.5:当离散代码从 0 变为 9 时,InfoGAN 生成的图像都被设置为零。

检查第一个连续代码的效果,以了解哪个属性已受到影响。 我们将 0 到 9 的第一个连续代码从 -2.0 更改为 2.0。 第二个连续代码设置为 0.0。 “图 6.1.6”显示了第一个连续代码控制数字的粗细:

代码语言:javascript复制
python3 infogan-mnist-6.1.1.py --generator=infogan_mnist.h5
--digit=0 --code1=0 --code2=0 --p1 

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KS2VfpD5-1681704311648)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_06.png)]

图 6.1.6:InfoGAN 作为第一个连续代码将 0 到 9 的数字从-2.0 更改为 2.0。第二个连续代码设置为零。 第一个连续代码控制数字的粗细

与上一步的类似,但更多地关注第二个连续代码。“图 6.1.7”显示第二个连续代码控制书写样式的旋转角度(倾斜):

代码语言:javascript复制
python3 infogan-mnist-6.1.1.py --generator=infogan_mnist.h5
--digit=0 --code1=0 --code2=0 --p2 

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-isY2Z7U3-1681704311648)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_07.png)]

图 6.1.7:InfoGAN 生成的图像作为第二个连续代码从 0 到 9 的数字从 -2.0 变为 2.0。第一个连续代码设置为零。 第二个连续代码控制书写样式的旋转角度(倾斜)

从这些验证结果中,我们可以看到,除了生成 MNIST 外观数字的能力之外,InfoGAN 还扩展了条件 GAN(如 CGAN 和 ACGAN)的功能。 网络自动学习了两个可以控制生成器输出的特定属性的任意代码。 有趣的是,如果我们将连续代码的数量增加到 2 以上,可以控制哪些附加属性,可以通过将“列表 6.1.1”的突出显示行中的代码扩展到列表 6.1.4 来实现。

本节中的结果表明,可以通过最大化代码和数据分布之间的互信息来纠缠生成器输出的属性。 在以下部分中,介绍了一种不同的解缠结方法。 StackedGAN 的想法是在特征级别注入代码。

2. StackedGAN

与 InfoGAN 一样,StackedGAN 提出了一种用于分解潜在表示的方法,以调节生成器输出。 但是,StackedGAN 使用不同的方法来解决此问题。 与其学习如何调节噪声以产生所需的输出,不如将 StackedGAN 分解为 GAN 栈。 每个 GAN 均以通常的区分对手的方式进行独立训练,并带有自己的潜在代码。

“图 6.2.1”向我们展示了 StackedGAN 在假设名人脸生成的背景下如何工作,假设已经训练了编码器网络对名人脸进行分类:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HwlmPKSS-1681704311649)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_08.png)]

图 6.2.1:在名人脸生成的背景下 StackedGAN 的基本思想。 假设有一个假设的深层编码器网络可以对名人脸进行分类,那么 StackedGAN 可以简单地反转编码器的过程

编码器网络是由一堆简单的编码器组成的,Encoder[i],其中i = 0 … n-1对应n个特征。 每个编码器都提取某些面部特征。 例如,Encoder[0]可能是发型特征的编码器,Feature[1]。 所有简单的编码器都有助于使整个编码器执行正确的预测。

StackedGAN 背后的想法是,如果我们想构建一个可生成假名人面孔的 GAN,则只需将编码器反转即可。 StackedGAN 由一堆更简单的 GAN 组成,GAN[i],其中i = 0 … n-1n个特征相对应。 每个GAN[i]学会反转其相应编码器Encoder[i]的过程。 例如,GAN[0]从假发型特征生成假名人脸,这是Encoder[0]处理的逆过程。

每个GAN[i]使用潜码z[i],以调节其生成器输出。 例如,潜在代码z[0]可以将发型从卷曲更改为波浪形。 GAN 的栈也可以用作合成假名人面孔的对象,从而完成整个编码器的逆过程。 每个GAN[i]z[i]的潜在代码都可以用来更改假名人面孔的特定属性。

有了 StackedGAN 的工作原理的关键思想,让我们继续下一节,看看如何在tf.keras中实现它。

Keras 中 StackedGAN 的实现

StackedGAN 的详细网络模型可以在“图 6.2.2”中看到。 为简洁起见,每个栈仅显示两个编码器 GAN。 该图最初可能看起来很复杂,但这只是一个编码器 GAN 的重复,这意味着如果我们了解如何训练一个编码器 GAN,其余的将使用相同的概念。

在本节中,我们假设 StackedGAN 是为 MNIST 数字生成而设计的。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wHlq7lV8-1681704311649)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_09.png)]

图 6.2.2:StackedGAN 包含编码器和 GAN 的栈。 对编码器进行预训练以执行分类。 Generator[1]G[1]学会合成特征f[1f],假标签y[f]和潜在代码z[1f]Generator[0]G[0]均使用这两个伪特征f[1f]生成伪图像和潜在代码z[0f]

StackedGAN 以编码器开头。 它可能是训练有素的分类器,可以预测正确的标签。 可以将中间特征向量f[1r]用于 GAN 训练。 对于 MNIST,我们可以使用基于 CNN 的分类器,类似于在“第 1 章”,“Keras 高级深度学习”中讨论的分类器。

“图 6.2.3”显示了编码器及其在tf.keras中的网络模型实现:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hi8p71TU-1681704311649)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_10.png)]

图 6.2.3:StackedGAN 中的编码器是一个基于 CNN 的简单分类器

“列表 6.2.1”显示了上图的tf.keras代码。 它与“第 1 章”,“Keras 高级深度学习”中的基于 CNN 的分类器相似,不同之处在于,我们使用Dense层来提取256-dim 特征。 有两个输出模型,Encoder[0]Encoder[1]。 两者都将用于训练 StackedGAN。

“列表 6.2.1”:stackedgan-mnist-6.2.1.py

代码语言:javascript复制
def build_encoder(inputs, num_labels=10, feature1_dim=256):
    """ Build the Classifier (Encoder) Model sub networks 
代码语言:javascript复制
 Two sub networks: 
    1) Encoder0: Image to feature1 (intermediate latent feature)
    2) Encoder1: feature1 to labels 
代码语言:javascript复制
 # Arguments
        inputs (Layers): x - images, feature1 - 
            feature1 layer output
        num_labels (int): number of class labels
        feature1_dim (int): feature1 dimensionality 
代码语言:javascript复制
 # Returns
        enc0, enc1 (Models): Description below 
    """
    kernel_size = 3
    filters = 64 
代码语言:javascript复制
 x, feature1 = inputs
    # Encoder0 or enc0
    y = Conv2D(filters=filters,
               kernel_size=kernel_size,
               padding='same',
               activation='relu')(x)
    y = MaxPooling2D()(y)
    y = Conv2D(filters=filters,
               kernel_size=kernel_size,
               padding='same',
               activation='relu')(y)
    y = MaxPooling2D()(y)
    y = Flatten()(y)
    feature1_output = Dense(feature1_dim, activation='relu')(y) 
代码语言:javascript复制
 # Encoder0 or enc0: image (x or feature0) to feature1 
    enc0 = Model(inputs=x, outputs=feature1_output, name="encoder0") 
代码语言:javascript复制
 # Encoder1 or enc1
    y = Dense(num_labels)(feature1)
    labels = Activation('softmax')(y)
    # Encoder1 or enc1: feature1 to class labels (feature2)
    enc1 = Model(inputs=feature1, outputs=labels, name="encoder1") 
代码语言:javascript复制
 # return both enc0 and enc1
    return enc0, enc1 

Encoder[0]输出f[1r]是我们想要的256维特征向量生成器 1 学习合成。 可以将用作Encoder[0]E[0]的辅助输出。 训练整个编码器以对 MNIST 数字进行分类,即x[r]。 正确的标签y[r]Encoder[1]E[1]。 在此过程中,学习了的中间特征集f[1r],可用于Generator[0]训练。 当针对该编码器训练 GAN 时,下标r用于强调和区分真实数据与伪数据。

假设编码器输入(x[r])中间特征(f[1r])和标签(y[r]),每个 GAN 都采用通常的区分-对抗方式进行训练。 损失函数由“表 6.2.1”中的“公式 6.2.1”至“公式 6.2.5”给出。“公式 6.2.1”和“公式 6.2.2”是通用 GAN 的常见损失函数。 StackedGAN 具有两个附加损失函数,即有条件

网络

损失函数

编号

GAN

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PoMQh9qC-1681704311649)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_030.png)]

4.1.1

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MpyIoFGy-1681704311650)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_031.png)]

4.1.5

栈式

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LXe8rb0W-1681704311650)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_032.png)]

6.2.1

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yzzweb4Z-1681704311650)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_033.png)]

6.2.2

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Cl62on7I-1681704311651)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_034.png)]

6.2.3

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YHxImlpO-1681704311651)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_035.png)]

6.2.4

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mgHf1uTI-1681704311651)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_036.png)]

6.2.5

其中λ1, λ2, λ3是权重,i是编码器和 GAN ID

表 6.2.1:GAN 和 StackedGAN 的损失函数之间的比较。 ~p_data表示从相应的编码器数据(输入,特征或输出)采样

条件“公式 6.2.3”中的损失函数L_i^(G_cond)确保生成器不会忽略输入f[i 1], 当从输入噪声代码z[i]合成输出f[i]时。 编码器Encoder[i]必须能够通过反转生成器的过程Generator[i]来恢复生成器输入。 通过L2或欧几里德距离(均方误差MSE))来测量生成器输入和使用编码器恢复的输入之间的差异。

“图 6.2.4”显示了L_0^(G_cond)计算所涉及的网络元素:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZDUnOIVb-1681704311651)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_11.png)]

图 6.2.4:图 6.2.3 的简化版本,仅显示L_0^(G_cond)计算中涉及的网络元素

但是,条件损失函数引入了一个新问题。 生成器忽略输入噪声代码z[i],仅依赖f[i 1]。 熵损失函数“公式 6.2.4”中的L_0^(G_ent)确保生成器不会忽略噪声代码z[i]Q 网络从生成器的输出中恢复噪声代码。 恢复的噪声和输入噪声之间的差异也通过L2或欧几里德距离(MSE)进行测量。

“图 6.2.5”显示了L_0^(G_ent)计算中涉及的网络元素:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZPqJfqRX-1681704311651)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_12.png)]

图 6.2.5:图 6.2.3 的简单版本仅向我们显示了L_0^(G_ent)计算中涉及的网络元素

最后的损失函数类似于通常的 GAN 损失。 它包括判别器损失L_i^(D)和生成器(通过对抗性)损失L_i^(G_adv)。“图 6.2.6”显示了 GAN 损失所涉及的元素。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sJDxjSuE-1681704311652)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_13.png)]

图 6.2.6:图 6.2.3 的简化版本,仅显示了L_i^(D)L_0^(G_adv)计算中涉及的网络元素

在“公式 6.2.5”中,三个生成器损失函数的加权和为最终生成器损失函数。 在我们将要介绍的 Keras 代码中,除的熵损失设置为 10.0 之外,所有权重都设置为 1.0。 在“公式 6.2.1”至“公式 6.2.5”中,i是指编码器和 GAN 组 ID 或级别。 在原始论文中,首先对网络进行独立训练,然后进行联合训练。 在独立训练期间,编码器将首先进行训练。 在联合训练期间,将使用真实数据和虚假数据。

tf.keras中 StackedGAN 生成器和判别器的实现只需进行少量更改即可提供辅助点来访问中间特征。“图 6.2.7”显示了生成器tf.keras模型。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3q8iuc9k-1681704311652)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_14.png)]

图 6.2.7:Keras 中的 StackedGAN 生成器模型

“列表 6.2.2”说明了构建与Generator[0]Generator[1]相对应的两个生成器(gen0gen1)的函数。 gen1生成器由三层Dense层组成,标签为和噪声代码z[1f]作为输入。 第三层生成伪造的f[1f]特征。 gen0生成器类似于我们介绍的其他 GAN 生成器,可以使用gan.py中的生成器生成器实例化:

代码语言:javascript复制
# gen0: feature1   z0 to feature0 (image)
gen0 = gan.generator(feature1, image_size, codes=z0) 

gen0输入为f[1]特征,并且噪声代码为z[0]。 输出是生成的伪图像x[f]

“列表 6.2.2”:stackedgan-mnist-6.2.1.py

代码语言:javascript复制
def build_generator(latent_codes, image_size, feature1_dim=256):
    """Build Generator Model sub networks 
代码语言:javascript复制
 Two sub networks: 1) Class and noise to feature1 
        (intermediate feature)
        2) feature1 to image 
代码语言:javascript复制
 # Arguments
        latent_codes (Layers): dicrete code (labels),
            noise and feature1 features
        image_size (int): Target size of one side
            (assuming square image)
        feature1_dim (int): feature1 dimensionality 
代码语言:javascript复制
 # Returns
        gen0, gen1 (Models): Description below
    """ 
代码语言:javascript复制
 # Latent codes and network parameters
    labels, z0, z1, feature1 = latent_codes
    # image_resize = image_size // 4
    # kernel_size = 5
    # layer_filters = [128, 64, 32, 1] 
代码语言:javascript复制
 # gen1 inputs
    inputs = [labels, z1]      # 10   50 = 62-dim
    x = concatenate(inputs, axis=1)
    x = Dense(512, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dense(512, activation='relu')(x)
    x = BatchNormalization()(x)
    fake_feature1 = Dense(feature1_dim, activation='relu')(x)
    # gen1: classes and noise (feature2   z1) to feature1
    gen1 = Model(inputs, fake_feature1, name='gen1') 
代码语言:javascript复制
 # gen0: feature1   z0 to feature0 (image)
    gen0 = gan.generator(feature1, image_size, codes=z0) 
代码语言:javascript复制
 return gen0, gen1 

“图 6.2.8”显示了判别器tf.keras模型:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-e3dKP9Nt-1681704311652)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_15.png)]

图 6.2.8:Keras 中的 StackedGAN 判别器模型

我们提供函数来构建Discriminator[0]Discriminator[1]dis0dis1)。 dis0判别器类似于 GAN 判别器,除了特征向量输入和辅助网络Q[0],其恢复z[0]gan.py中的构造器函数用于创建dis0

代码语言:javascript复制
dis0 = gan.discriminator(inputs, num_codes=z_dim) 

dis1判别器由三层 MLP 组成,如清单 6.2.3 所示。 最后一层将区分为真假f[1]Q[1]网络共享dis1的前两层。 其第三层回收z[1]

“列表 6.2.3”:stackedgan-mnist-6.2.1.py

代码语言:javascript复制
def build_discriminator(inputs, z_dim=50):
    """Build Discriminator 1 Model 
代码语言:javascript复制
 Classifies feature1 (features) as real/fake image and recovers
    the input noise or latent code (by minimizing entropy loss) 
代码语言:javascript复制
 # Arguments
        inputs (Layer): feature1
        z_dim (int): noise dimensionality 
代码语言:javascript复制
 # Returns
        dis1 (Model): feature1 as real/fake and recovered latent code
    """ 
代码语言:javascript复制
 # input is 256-dim feature1
    x = Dense(256, activation='relu')(inputs)
    x = Dense(256, activation='relu')(x) 
代码语言:javascript复制
 # first output is probability that feature1 is real
    f1_source = Dense(1)(x)
    f1_source = Activation('sigmoid',
                           name='feature1_source')(f1_source) 
代码语言:javascript复制
 # z1 reonstruction (Q1 network)
    z1_recon = Dense(z_dim)(x)
    z1_recon = Activation('tanh', name='z1')(z1_recon) 
代码语言:javascript复制
 discriminator_outputs = [f1_source, z1_recon]
    dis1 = Model(inputs, discriminator_outputs, name='dis1')
    return dis1 

有了所有可用的构建器函数,StackedGAN 就会在“列表 6.2.4”中进行组装。 在训练 StackedGAN 之前,对编码器进行了预训练。 请注意,我们已经在对抗模型训练中纳入了三个生成器损失函数(对抗,条件和熵)。Q网络与判别器模型共享一些公共层。 因此,其损失函数也被纳入判别器模型训练中。

“列表 6.2.4”:stackedgan-mnist-6.2.1.py

代码语言:javascript复制
def build_and_train_models():
    """Load the dataset, build StackedGAN discriminator,
    generator, and adversarial models.
    Call the StackedGAN train routine.
    """
    # load MNIST dataset
    (x_train, y_train), (x_test, y_test) = mnist.load_data() 
代码语言:javascript复制
 # reshape and normalize images
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255 
代码语言:javascript复制
 x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
    x_test = x_test.astype('float32') / 255 
代码语言:javascript复制
 # number of labels
    num_labels = len(np.unique(y_train))
    # to one-hot vector
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test) 
代码语言:javascript复制
 model_name = "stackedgan_mnist"
    # network parameters
    batch_size = 64
    train_steps = 10000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, )
    z_dim = 50
    z_shape = (z_dim, )
    feature1_dim = 256
    feature1_shape = (feature1_dim, ) 
代码语言:javascript复制
 # build discriminator 0 and Q network 0 models
    inputs = Input(shape=input_shape, name='discriminator0_input')
    dis0 = gan.discriminator(inputs, num_codes=z_dim)
    # [1] uses Adam, but discriminator converges easily with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    # loss fuctions: 1) probability image is real (adversarial0 loss)
    # 2) MSE z0 recon loss (Q0 network loss or entropy0 loss)
    loss = ['binary_crossentropy', 'mse']
    loss_weights = [1.0, 10.0]
    dis0.compile(loss=loss,
                 loss_weights=loss_weights,
                 optimizer=optimizer,
                 metrics=['accuracy'])
    dis0.summary() # image discriminator, z0 estimator 
代码语言:javascript复制
 # build discriminator 1 and Q network 1 models
    input_shape = (feature1_dim, )
    inputs = Input(shape=input_shape, name='discriminator1_input')
    dis1 = build_discriminator(inputs, z_dim=z_dim )
    # loss fuctions: 1) probability feature1 is real 
    # (adversarial1 loss)
    # 2) MSE z1 recon loss (Q1 network loss or entropy1 loss)
    loss = ['binary_crossentropy', 'mse']
    loss_weights = [1.0, 1.0]
    dis1.compile(loss=loss,
                 loss_weights=loss_weights,
                 optimizer=optimizer,
                 metrics=['accuracy'])
    dis1.summary() # feature1 discriminator, z1 estimator 
代码语言:javascript复制
 # build generator models
    feature1 = Input(shape=feature1_shape, name='feature1_input')
    labels = Input(shape=label_shape, name='labels')
    z1 = Input(shape=z_shape, name="z1_input")
    z0 = Input(shape=z_shape, name="z0_input")
    latent_codes = (labels, z0, z1, feature1)
    gen0, gen1 = build_generator(latent_codes, image_size)
    gen0.summary() # image generator
    gen1.summary() # feature1 generator 
代码语言:javascript复制
 # build encoder models
    input_shape = (image_size, image_size, 1)
    inputs = Input(shape=input_shape, name='encoder_input')
    enc0, enc1 = build_encoder((inputs, feature1), num_labels)
    enc0.summary() # image to feature1 encoder
    enc1.summary() # feature1 to labels encoder (classifier)
    encoder = Model(inputs, enc1(enc0(inputs)))
    encoder.summary() # image to labels encoder (classifier) 
代码语言:javascript复制
 data = (x_train, y_train), (x_test, y_test)
    train_encoder(encoder, data, model_name=model_name) 
代码语言:javascript复制
 # build adversarial0 model =
    # generator0   discriminator0   encoder0
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    # encoder0 weights frozen
    enc0.trainable = False
    # discriminator0 weights frozen
    dis0.trainable = False
    gen0_inputs = [feature1, z0]
    gen0_outputs = gen0(gen0_inputs)
    adv0_outputs = dis0(gen0_outputs)   [enc0(gen0_outputs)]
    # feature1   z0 to prob feature1 is 
    # real   z0 recon   feature0/image recon
    adv0 = Model(gen0_inputs, adv0_outputs, name="adv0")
    # loss functions: 1) prob feature1 is real (adversarial0 loss)
    # 2) Q network 0 loss (entropy0 loss)
    # 3) conditional0 loss
    loss = ['binary_crossentropy', 'mse', 'mse']
    loss_weights = [1.0, 10.0, 1.0]
    adv0.compile(loss=loss,
                 loss_weights=loss_weights,
                 optimizer=optimizer,
                 metrics=['accuracy'])
    adv0.summary() 
代码语言:javascript复制
 # build adversarial1 model = 
    # generator1   discriminator1   encoder1
    # encoder1 weights frozen
    enc1.trainable = False
    # discriminator1 weights frozen
    dis1.trainable = False
    gen1_inputs = [labels, z1]
    gen1_outputs = gen1(gen1_inputs)
    adv1_outputs = dis1(gen1_outputs)   [enc1(gen1_outputs)]
    # labels   z1 to prob labels are real   z1 recon   feature1 recon
    adv1 = Model(gen1_inputs, adv1_outputs, name="adv1")
    # loss functions: 1) prob labels are real (adversarial1 loss)
    # 2) Q network 1 loss (entropy1 loss)
    # 3) conditional1 loss (classifier error)
    loss_weights = [1.0, 1.0, 1.0]
    loss = ['binary_crossentropy',
            'mse',
            'categorical_crossentropy']
    adv1.compile(loss=loss,
                 loss_weights=loss_weights,
                 optimizer=optimizer,
                 metrics=['accuracy'])
    adv1.summary() 
代码语言:javascript复制
 # train discriminator and adversarial networks
    models = (enc0, enc1, gen0, gen1, dis0, dis1, adv0, adv1)
    params = (batch_size, train_steps, num_labels, z_dim, model_name)
    train(models, data, params) 

最后,训练函数与典型的 GAN 训练相似,不同之处在于我们一次只训练一个 GAN(即GAN[0]然后是GAN[0])。 代码显示在“列表 6.2.5”中。 值得注意的是,训练顺序为:

  1. Discriminator[1]Q[1]网络通过最小化判别器和熵损失
  2. Discriminator[0]Q[0]网络通过最小化判别器和熵损失
  3. Adversarial[1]网络通过最小化对抗性,熵和条件损失
  4. Adversarial[0]网络通过最小化对抗性,熵和条件损失

“列表 6.2.5”:stackedgan-mnist-6.2.1.py

代码语言:javascript复制
def train(models, data, params):
    """Train the discriminator and adversarial Networks 
代码语言:javascript复制
 Alternately train discriminator and adversarial networks by batch.
    Discriminator is trained first with real and fake images,
    corresponding one-hot labels and latent codes.
    Adversarial is trained next with fake images pretending
    to be real, corresponding one-hot labels and latent codes.
    Generate sample images per save_interval. 
代码语言:javascript复制
 # Arguments
        models (Models): Encoder, Generator, Discriminator,
            Adversarial models
        data (tuple): x_train, y_train data
        params (tuple): Network parameters 
代码语言:javascript复制
 """
    # the StackedGAN and Encoder models
    enc0, enc1, gen0, gen1, dis0, dis1, adv0, adv1 = models
    # network parameters
    batch_size, train_steps, num_labels, z_dim, model_name = params
    # train dataset
    (x_train, y_train), (_, _) = data
    # the generator image is saved every 500 steps
    save_interval = 500 
代码语言:javascript复制
 # label and noise codes for generator testing
    z0 = np.random.normal(scale=0.5, size=[16, z_dim])
    z1 = np.random.normal(scale=0.5, size=[16, z_dim])
    noise_class = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    noise_params = [noise_class, z0, z1]
    # number of elements in train dataset
    train_size = x_train.shape[0]
    print(model_name,
          "Labels for generated images: ",
          np.argmax(noise_class, axis=1)) 
代码语言:javascript复制
 for i in range(train_steps):
        # train the discriminator1 for 1 batch
        # 1 batch of real (label=1.0) and fake feature1 (label=0.0)
        # randomly pick real images from dataset
        rand_indexes = np.random.randint(0,
                                         train_size,
                                         size=batch_size)
        real_images = x_train[rand_indexes]
        # real feature1 from encoder0 output
        real_feature1 = enc0.predict(real_images)
        # generate random 50-dim z1 latent code
        real_z1 = np.random.normal(scale=0.5,
                                   size=[batch_size, z_dim])
        # real labels from dataset
        real_labels = y_train[rand_indexes] 
代码语言:javascript复制
 # generate fake feature1 using generator1 from
        # real labels and 50-dim z1 latent code
        fake_z1 = np.random.normal(scale=0.5,
                                   size=[batch_size, z_dim])
        fake_feature1 = gen1.predict([real_labels, fake_z1]) 
代码语言:javascript复制
 # real   fake data
        feature1 = np.concatenate((real_feature1, fake_feature1))
        z1 = np.concatenate((fake_z1, fake_z1)) 
代码语言:javascript复制
 # label 1st half as real and 2nd half as fake
        y = np.ones([2 * batch_size, 1])
        y[batch_size:, :] = 0 
代码语言:javascript复制
 # train discriminator1 to classify feature1 as
        # real/fake and recover
        # latent code (z1). real = from encoder1,
        # fake = from genenerator1
        # joint training using discriminator part of
        # advserial1 loss and entropy1 loss
        metrics = dis1.train_on_batch(feature1, [y, z1])
        # log the overall loss only
        log = "%d: [dis1_loss: %f]" % (i, metrics[0]) 
代码语言:javascript复制
 # train the discriminator0 for 1 batch
        # 1 batch of real (label=1.0) and fake images (label=0.0)
        # generate random 50-dim z0 latent code
        fake_z0 = np.random.normal(scale=0.5, size=[batch_size, z_dim])
        # generate fake images from real feature1 and fake z0
        fake_images = gen0.predict([real_feature1, fake_z0])
        # real   fake data
        x = np.concatenate((real_images, fake_images))
        z0 = np.concatenate((fake_z0, fake_z0))
        # train discriminator0 to classify image 
        # as real/fake and recover latent code (z0)
        # joint training using discriminator part of advserial0 loss
        # and entropy0 loss
        metrics = dis0.train_on_batch(x, [y, z0])
        # log the overall loss only (use dis0.metrics_names)
        log = "%s [dis0_loss: %f]" % (log, metrics[0]) 
代码语言:javascript复制
 # adversarial training 
        # generate fake z1, labels
        fake_z1 = np.random.normal(scale=0.5,
                                   size=[batch_size, z_dim])
        # input to generator1 is sampling fr real labels and
        # 50-dim z1 latent code
        gen1_inputs = [real_labels, fake_z1] 
代码语言:javascript复制
 # label fake feature1 as real
        y = np.ones([batch_size, 1]) 
代码语言:javascript复制
 # train generator1 (thru adversarial) by fooling i
        # the discriminator
        # and approximating encoder1 feature1 generator
        # joint training: adversarial1, entropy1, conditional1
        metrics = adv1.train_on_batch(gen1_inputs,
                                      [y, fake_z1, real_labels])
        fmt = "%s [adv1_loss: %f, enc1_acc: %f]"
        # log the overall loss and classification accuracy
        log = fmt % (log, metrics[0], metrics[6]) 
代码语言:javascript复制
 # input to generator0 is real feature1 and
        # 50-dim z0 latent code
        fake_z0 = np.random.normal(scale=0.5,
                                   size=[batch_size, z_dim])
        gen0_inputs = [real_feature1, fake_z0] 
代码语言:javascript复制
 # train generator0 (thru adversarial) by fooling
        # the discriminator and approximating encoder1 imag 
        # source generator joint training:
        # adversarial0, entropy0, conditional0
        metrics = adv0.train_on_batch(gen0_inputs,
                                      [y, fake_z0, real_feature1])
        # log the overall loss only
        log = "%s [adv0_loss: %f]" % (log, metrics[0]) 
代码语言:javascript复制
 print(log)
        if (i   1) % save_interval == 0:
            generators = (gen0, gen1)
            plot_images(generators,
                        noise_params=noise_params,
                        show=False,
                        step=(i   1),
                        model_name=model_name) 
代码语言:javascript复制
 # save the modelis after training generator0 & 1
    # the trained generator can be reloaded for
    # future MNIST digit generation
    gen1.save(model_name   "-gen1.h5")
    gen0.save(model_name   "-gen0.h5") 

tf.keras中 StackedGAN 的代码实现现已完成。 训练后,可以评估生成器的输出以检查合成 MNIST 数字的某些属性是否可以以与我们在 InfoGAN 中所做的类似的方式进行控制。

StackedGAN 的生成器输出

在对 StackedGAN 进行 10,000 步训练之后,Generator[0]Generator[1]模型被保存在文件中。 Generator[0]Generator[1]堆叠在一起可以合成以标签和噪声代码z[0]z[1]为条件的伪造图像。

StackedGAN 生成器可以通过以下方式进行定性验证:

从两个噪声代码z[0]z[1]的离散标签从 0 变到 9,从正态分布中采样,均值为 0.5,标准差为 1.0。 结果显示在“图 6.2.9”中。 我们可以看到 StackedGAN 离散代码可以控制生成器生成的数字:

代码语言:javascript复制
python3 stackedgan-mnist-6.2.1.py --generator0=stackedgan_mnist-gen0.h5 --generator1=stackedgan_mnist-gen1.h5 --digit=0 

代码语言:javascript复制
python3 stackedgan-mnist-6.2.1.py --generator0=stackedgan_mnist-gen0.h5 --generator1=stackedgan_mnist-gen1.h5 --digit=9 

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KUD8zxRt-1681704311652)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_16.png)]

图 6.2.9:当离散代码从 0 变为 9 时,StackedGAN 生成的图像。z0z1均从正态分布中采样,平均值为 0,标准差为 0.5。

如下所示,将第一噪声码z[0]从 -4.0 到 4.0 的恒定向量变为从 0 到 9 的数字。 第二噪声代码z[1]被设置为零向量。 “图 6.2.10”显示第一个噪声代码控制数字的粗细。 例如,对于数字 8:

代码语言:javascript复制
python3 stackedgan-mnist-6.2.1.py --generator0=stackedgan_mnist-gen0.h5 --generator1=stackedgan_mnist-gen1.h5 --z0=0 --z1=0 --p0 --digit=8 

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-A6w2h8Wp-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_17.png)]

图 6.2.10:使用 StackedGAN 作为第一个噪声代码z0生成的图像,对于数字 0 到 9,其向量从 -4.0 到 4.0 不变。z0似乎控制着每个数字的粗细。

如下所示,对于数字 0 到 9,从 -1.0 到 1.0 的恒定向量变化第二噪声代码z[1]。 将第一噪声代码z[0]设置为零向量。“图 6.2.11”显示第二个噪声代码控制旋转(倾斜),并在一定程度上控制手指的粗细。 例如,对于数字 8:

代码语言:javascript复制
python3 stackedgan-mnist-6.2.1.py --generator0=stackedgan_mnist-gen0.h5 --generator1=stackedgan_mnist-gen1.h5 --z0=0 --z1=0 --p1 --digit=8 

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-363w44bd-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_18.png)]

图 6.2.11:由 StackedGAN 生成的图像作为第二个噪声代码z1从 0 到 9 的恒定向量 -1.0 到 1.0 变化。z1似乎控制着每个数字的旋转(倾斜)和笔划粗细

“图 6.2.9”至“图 6.2.11”证明 StackedGAN 提供了对生成器输出属性的附加控制。 控件和属性为(标签,哪个数字),(z0,数字粗细)和(z1,数字倾斜度)。 从此示例中,我们可以控制其他可能的实验,例如:

  • 从当前数量 2 增加栈中的元素数量
  • 像在 InfoGAN 中一样,减小代码z[0]z[1]的尺寸

“图 6.2.12”显示了 InfoGAN 和 StackedGAN 的潜在代码之间的区别:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-S88zFsW8-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_19.png)]

图 6.2.12:不同 GAN 的潜在表示

解开代码的基本思想是对损失函数施加约束,以使仅特定属性受代码影响。 从结构上讲,与 StackedGAN 相比,InfoGAN 更易于实现。 InfoGAN 的训练速度也更快。

4. 总结

在本章中,我们讨论了如何解开 GAN 的潜在表示。 在本章的前面,我们讨论了 InfoGAN 如何最大化互信息以迫使生成器学习解纠缠的潜向量。 在 MNIST 数据集示例中,InfoGAN 使用三种表示形式和一个噪声代码作为输入。 噪声以纠缠的形式表示其余的属性。 StackedGAN 以不同的方式处理该问题。 它使用一堆编码器 GAN 来学习如何合成伪造的特征和图像。 首先对编码器进行训练,以提供特征数据集。 然后,对编码器 GAN 进行联合训练,以学习如何使用噪声代码控制生成器输出的属性。

在下一章中,我们将着手一种新型的 GAN,它能够在另一个域中生成新数据。 例如,给定马的图像,GAN 可以将其自动转换为斑马的图像。 这种 GAN 的有趣特征是无需监督即可对其进行训练,并且不需要成对的样本数据。

5. 参考

  1. Xi Chen et al.: InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets. Advances in Neural Information Processing Systems, 2016 (http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf).
  2. Xun Huang et al. Stacked Generative Adversarial Networks. IEEE Conference on Computer Vision and Pattern Recognition (CVPR). Vol. 2, 2017 (http://openaccess.thecvf.com/content_cvpr_2017/papers/Huang_Stacked_Generative_Adversarial_CVPR_2017_paper.pdf).

七、跨域 GAN

在计算机视觉,计算机图形学和图像处理中,许多任务涉及将图像从一种形式转换为另一种形式。 灰度图像的着色,将卫星图像转换为地图,将一位艺术家的艺术品风格更改为另一位艺术家,将夜间图像转换为白天,将夏季照片转换为冬天只是几个例子。 这些任务被称为跨域迁移,将成为本章的重点。 源域中的图像将迁移到目标域,从而生成新的转换图像。

跨域迁移在现实世界中具有许多实际应用。 例如,在自动驾驶研究中,收集公路现场驾驶数据既费时又昂贵。 为了在该示例中覆盖尽可能多的场景变化,将在不同的天气条件,季节和时间中遍历道路,从而为我们提供了大量不同的数据。 使用跨域迁移,可以通过转换现有图像来生成看起来真实的新合成场景。 例如,我们可能只需要在夏天从一个区域收集道路场景,在冬天从另一地方收集道路场景。 然后,我们可以将夏季图像转换为冬季,并将冬季图像转换为夏季。 在这种情况下,它将必须完成的任务数量减少了一半。

现实的合成图像的生成是 GAN 擅长的领域。 因此,跨域翻译是 GAN 的应用之一。 在本章中,我们将重点介绍一种流行的跨域 GAN 算法,称为 CycleGAN [2]。 与其他跨域迁移算法(例如 pix2pix [3])不同,CycleGAN 不需要对齐的训练图像即可工作。 在对齐的图像中,训练数据应该是由源图像及其对应的目标图像组成的一对图像; 例如,卫星图像和从该图像得出的相应地图。

CycleGAN 仅需要卫星数据图像和地图。 这些地图可以来自其他卫星数据,而不必事先从训练数据中生成。

在本章中,我们将探讨以下内容:

  • CycleGAN 的原理,包括其在tf.keras中的实现
  • CycleGAN 的示例应用,包括使用 CIFAR10 数据集对灰度图像进行着色和应用于 MNIST 数字和街景门牌号码(SVHN) [1]数据集的样式迁移

让我们开始讨论 CycleGAN 背后的原理。

1. CycleGAN 的原理

将图像从一个域转换到另一个域是计算机视觉,计算机图形学和图像处理中的常见任务。“图 7.1.1”显示了边缘检测,这是常见的图像转换任务:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DKJSNQev-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_01.png)]

图 7.1.1:对齐图像对的示例:使用 Canny 边缘检测器的左,原始图像和右,变换后的图像。 原始照片是作者拍摄的。

在此示例中,我们可以将真实照片(左)视为源域中的图像,将边缘检测的照片(右)视为目标域中的样本。 还有许多其他具有实际应用的跨域翻译过程,例如:

  • 卫星图像到地图
  • 脸部图像到表情符号,漫画或动画
  • 身体图像到头像
  • 灰度照片的着色
  • 医学扫描到真实照片
  • 真实照片到画家的绘画

在不同领域中还有许多其他示例。 例如,在计算机视觉和图像处理中,我们可以通过发明一种从源图像中提取特征并将其转换为目标图像的算法来执行翻译。 坎尼边缘算子就是这种算法的一个例子。 但是,在很多情况下,翻译对于手工工程师而言非常复杂,因此几乎不可能找到合适的算法。 源域分布和目标域分布都是高维且复杂的。

解决图像翻译问题的一种方法是使用深度学习技术。 如果我们具有来自源域和目标域的足够大的数据集,则可以训练神经网络对转换进行建模。 由于必须在给定源图像的情况下自动生成目标域中的图像,因此它们必须看起来像是来自目标域的真实样本。 GAN 是适合此类跨域任务的网络。 pix2pix [3]算法是跨域算法的示例。

pix2pix 算法与条件 GANCGAN)[4]相似,我们在“第 4 章”,“生成对抗网络(GAN)”。 我们可以回想起在 CGAN 中,除了z噪声输入之外,诸如单热向量之类的条件会限制生成器的输出。 例如,在 MNIST 数字中,如果我们希望生成器输出数字 8,则条件为单热向量[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]。 在 pix2pix 中,条件是要翻译的图像。 生成器的输出是翻译后的图像。 通过优化 CGAN 损失来训练 pix2pix 算法。 为了使生成的图像中的模糊最小化,还包括 L1 损失。

类似于 pix2pix 的神经网络的主要缺点是训练输入和输出图像必须对齐。“图 7.1.1”是对齐的图像对的示例。 样本目标图像是从源生成的。 在大多数情况下,对齐的图像对不可用或无法从源图像生成,也不昂贵,或者我们不知道如何从给定的源图像生成目标图像。 我们拥有的是来自源域和目标域的样本数据。“图 7.1.2”是来自同一向日葵主题上源域(真实照片)和目标域(范高的艺术风格)的数据示例。 源图像和目标图像不一定对齐。

与 pix2pix 不同,CycleGAN 会学习图像翻译,只要源数据和目标数据之间有足够的数量和差异即可。 无需对齐。 CycleGAN 学习源和目标分布,以及如何从给定的样本数据中将源分布转换为目标分布。 无需监督。 在“图 7.1.2”的上下文中,我们只需要数千张真实向日葵的照片和数千张梵高向日葵画的照片。 在训练了 CycleGAN 之后,我们可以将向日葵的照片转换成梵高的画作:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wFIkXXou-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_02.png)]

图 7.1.2:未对齐的图像对示例:左侧为菲律宾大学沿着大学大道的真实向日葵照片,右侧为伦敦国家美术馆的梵高的向日葵, 英国。 原始照片由作者拍摄。

下一个问题是:我们如何建立可以从未配对数据中学习的模型? 在下一部分中,我们将构建一个使用正向和反向循环 GAN 的 CycleGAN,以及一个循环一致性检查,以消除对配对输入数据的需求。

CycleGAN 模型

“图 7.1.3”显示了 CycleGAN 的网络模型:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kiIRicws-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_03.png)]

图 7.1.3:CycleGAN 模型包含四个网络:生成器G,生成器F,判别器D[y]和判别器D[x]

让我们逐个讨论“图 7.1.3”。 让我们首先关注上层网络,即转发周期 GAN。 如下图“图 7.1.4”所示,正向循环 CycleGAN 的目标是学习以下函数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gsjmUDa2-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_001.png)] (Equation 7.1.1)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NOQ9T982-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_04.png)]

图 7.1.4:伪造y的 CycleGAN 生成器G

“公式 7.1.1”只是假目标数据y'的生成器G。 它将数据从源域x转换为目标域y

要训​​练生成器,我们必须构建 GAN。 这是正向循环 GAN,如图“图 7.1.5”所示。 该图表明,它类似于“第 4 章”,“生成对抗网络(GANs)”中的典型 GAN,由生成器G和判别器D[y]组成,它可以以相同的对抗方式进行训练。通过仅利用源域中的可用实际图像x和目标域中的实际图像y,进行无监督学习。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SC3RgcGJ-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_05.png)]

图 7.1.5:CycleGAN 正向循环 GAN

与常规 GAN 不同,CycleGAN 施加了周期一致性约束,如图“图 7.1.6”所示。 前向循环一致性网络可确保可以从伪造的目标数据中重建真实的源数据:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CBQIl6q7-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_004.png)] (Equation 7.1.2)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QlWNFqm4-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_06.png)]

图 7.1.6:CycleGAN 循环一致性检查

通过最小化正向循环一致性 L1 损失来完成:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ekYw53po-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_005.png)] (Equation 7.1.3)

周期一致性损失使用 L1平均绝对误差MAE),因为与 L2均方误差MSE)相比,它通常导致较少的模糊图像重建。

循环一致性检查表明,尽管我们已将源数据x转换为域y,但x的原始特征仍应保留在y中并且可恢复。 网络F只是我们将从反向循环 GAN 借用的另一个生成器,如下所述。

CycleGAN 是对称的。 如图“图 7.1.7”所示,后向循环 GAN 与前向循环 GAN 相同,但将源数据x和目标数据y的作用逆转。 现在,源数据为y,目标数据为x。 生成器GF的作用也相反。F现在是生成器,而G恢复输入。 在正向循环 GAN 中,生成器F是用于恢复源数据的网络,而G是生成器。

Backward Cycle GAN 生成器的目标是合成:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8l4O360i-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_006.png)] (Equation 7.1.2)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fLcG2jW4-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_07.png)]

图 7.1.7:CycleGAN 向后循环 GAN

这可以通过对抗性训练反向循环 GAN 来完成。 目的是让生成器F学习如何欺骗判别器D[x]

此外,还具有类似的向后循环一致性,以恢复原始源y

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mvyIpDFe-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_008.png)] (Equation 7.1.4)

这是通过最小化后向循环一致性 L1 损失来完成的:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rxcsZxBP-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_009.png)] (Equation 7.1.5)

总而言之,CycleGAN 的最终目标是使生成器G学习如何合成伪造的目标数据y',该伪造的目标数据y'会在正向循环中欺骗识别器D[y]。 由于网络是对称的,因此 CycleGAN 还希望生成器F学习如何合成伪造的源数据x',该伪造的源数据可以使判别器D[x]在反向循环中蒙蔽。 考虑到这一点,我们现在可以将所有损失函数放在一起。

让我们从 GAN 部分开始。 受到最小二乘 GAN(LSGAN) [5]更好的感知质量的启发,如“第 5 章”,“改进的 GAN” 中所述,CycleGAN 还使用 MSE 作为判别器和生成器损失。 回想一下,LSGAN 与原始 GAN 之间的差异需要使用 MSE 损失,而不是二进制交叉熵损失。

CycleGAN 将生成器-标识符损失函数表示为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YumlqfLU-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_014.png)] (Equation 7.1.6)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vcjHdscP-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_015.png)] (Equation 7.1.7)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-D1jqJXMo-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_016.png)] (Equation 7.1.8)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-oXpqJ2Ij-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_017.png)] (Equation 7.1.9)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fIfSu1cM-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_018.png)] (Equation 7.1.10)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VK9QUtYW-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_019.png)] (Equation 7.1.11)

损失函数的第二组是周期一致性损失,可以通过汇总前向和后向 GAN 的贡献来得出:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1dxsG5PQ-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_020.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-apnJZhCA-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_021.png)] (Equation 7.1.12)

CycleGAN 的总损失为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3TdRzKpk-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_022.png)] (Equation 7.1.13)

CycleGAN 建议使用以下权重值λ1 = 1.0λ2 = 10.0,以更加重视循环一致性检查。

训练策略类似于原始 GAN。 “算法 7.1.1”总结了 CycleGAN 训练过程。

“算法 7.1.1”:CycleGAN 训练

n训练步骤重复上述步骤:

  1. 通过使用真实的源数据和目标数据训练前向循环判别器,将L_forward_GAN^(D)降至最低。 实际目标数据的小批量y标记为 1.0。 伪造的目标数据y' = G(x)的小批量标记为 0.0。
  2. 通过使用真实的源数据和目标数据训练反向循环判别器,将L_backward_GAN^(D)最小化。 实际源数据的小批量x标记为 1.0。 一小部分伪造的源数据x' = F(y)被标记为 0.0。
  3. 通过训练对抗网络中的前向周期和后向周期生成器,将L_GAN^(D)L_cyc最小化。 伪造目标数据的一个小批量y' = G(x)被标记为 1.0。 一小部分伪造的源数据x' = F(y)被标记为 1.0。 判别器的权重被冻结。

在神经样式迁移问题中,颜色组合可能无法成功地从源图像迁移到伪造目标图像。 此问题显示在“图 7.1.8”中:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TtNbGPMZ-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_08.png)]

图 7.1.8:在样式迁移过程中,颜色组合可能无法成功迁移。 为了解决此问题,将恒等损失添加到总损失函数中

为了解决这个问题,CycleGAN 建议包括正向和反向循环身份损失函数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-akghAJMo-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_033.png)] (Equation 7.1.14)

CycleGAN 的总损失变为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-abCvM3he-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_034.png)] (Equation 7.1.15)

其中λ3 = 0.5。 在对抗训练中,身份损失也得到了优化。“图 7.1.9”重点介绍了实现身份正则器的 CycleGAN 辅助网络:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jn1AhAKg-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_09.png)]

图 7.1.9:具有身份正则化网络的 CycleGAN 模型,图像左侧突出显示

在下一个部分,我们将在tf.keras中实现 CycleGAN。

使用 Keras 实现 CycleGAN

我们来解决,这是 CycleGAN 可以解决的简单问题。 在“第 3 章”,“自编码器”中,我们使用了自编码器为 CIFAR10 数据集中的灰度图像着色。 我们可以记得,CIFAR10 数据集包含 50,000 个训练过的数据项和 10,000 个测试数据样本,这些样本属于 10 个类别的32 x 32 RGB 图像。 我们可以使用rgb2gray(RGB)将所有彩色图像转换为灰度图像,如“第 3 章”,“自编码器”中所述。

接下来,我们可以将灰度训练图像用作源域图像,将原始彩色图像用作目标域图像。 值得注意的是,尽管数据集是对齐的,但我们 CycleGAN 的输入是彩色图像的随机样本和灰度图像的随机样本。 因此,我们的 CycleGAN 将看不到训练数据对齐。 训练后,我们将使用测试的灰度图像来观察 CycleGAN 的表现。

如前几节所述,要实现 CycleGAN,我们需要构建两个生成器和两个判别器。 CycleGAN 的生成器学习源输入分布的潜在表示,并将该表示转换为目标输出分布。 这正是自编码器的功能。 但是,类似于“第 3 章”,“自编码器”中讨论的典型自编码器,使用的编码器会对输入进行下采样,直到瓶颈层为止,此时解码器中的处理过程相反。

由于在编码器和解码器层之间共享许多低级特征,因此该结构不适用于某些图像转换问题。 例如,在着色问题中,灰度图像的形式,结构和边缘与彩色图像中的相同。 为了解决这个问题,CycleGAN 生成器使用 U-Net [7]结构,如图“图 7.1.10”所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YCKH97o8-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_10.png)]

图 7.1.10:在 Keras 中实现正向循环生成器G。 产生器是包括编码器和解码器的 U 网络[7]。

在 U-Net 结构中,编码器层的输出e[ni]与解码器层的输出d[i],其中n = 4是编码器/解码器的层数,i = 1, 2, 3是共享信息的层号。

我们应该注意,尽管该示例使用n = 4,但输入/输出尺寸较大的问题可能需要更深的编码器/解码器层。 通过 U-Net 结构,可以在编码器和解码器之间自由迁移特征级别的信息。

编码器层由Instance Normalization(IN)-LeakyReLU-Conv2D组成,而解码器层由IN-ReLU-Conv2D组成。 编码器/解码器层的实现如清单 7.1.1 所示,而生成器的实现如列表 7.1.2 所示。

完整的代码可在 GitHub 上找到。

实例规范化IN)是每个数据(即 IN 是图像或每个特征的 BN)。 在样式迁移中,重要的是标准化每个样本而不是每个批量的对比度。 IN 等于,相当于对比度归一化。 同时,BN 打破了对比度标准化。

记住在使用 IN 之前先安装tensorflow-addons

代码语言:javascript复制
$ pip install tensorflow-addons 

“列表 7.1.1”:cyclegan-7.1.1.py

代码语言:javascript复制
def encoder_layer(inputs,
                  filters=16,
                  kernel_size=3,
                  strides=2,
                  activation='relu',
                  instance_norm=True):
    """Builds a generic encoder layer made of Conv2D-IN-LeakyReLU
    IN is optional, LeakyReLU may be replaced by ReLU
    """ 
代码语言:javascript复制
 conv = Conv2D(filters=filters,
                  kernel_size=kernel_size,
                  strides=strides,
                  padding='same') 
代码语言:javascript复制
 x = inputs
    if instance_norm:
        x = InstanceNormalization(axis=3)(x)
    if activation == 'relu':
        x = Activation('relu')(x)
    else:
        x = LeakyReLU(alpha=0.2)(x)
    x = conv(x)
    return x 
代码语言:javascript复制
def decoder_layer(inputs,
                  paired_inputs,
                  filters=16,
                  kernel_size=3,
                  strides=2,
                  activation='relu',
                  instance_norm=True):
    """Builds a generic decoder layer made of Conv2D-IN-LeakyReLU
    IN is optional, LeakyReLU may be replaced by ReLU
    Arguments: (partial)
    inputs (tensor): the decoder layer input
    paired_inputs (tensor): the encoder layer output 
          provided by U-Net skip connection &
          concatenated to inputs.
    """ 
代码语言:javascript复制
 conv = Conv2DTranspose(filters=filters,
                           kernel_size=kernel_size,
                           strides=strides,
                           padding='same') 
代码语言:javascript复制
 x = inputs
    if instance_norm:
        x = InstanceNormalization(axis=3)(x)
    if activation == 'relu':
        x = Activation('relu')(x)
    else:
        x = LeakyReLU(alpha=0.2)(x)
    x = conv(x)
    x = concatenate([x, paired_inputs])
    return x 

将移至生成器实现中:

“列表 7.1.2”:cyclegan-7.1.1.py

Keras 中的生成器实现:

代码语言:javascript复制
def build_generator(input_shape,
                    output_shape=None,
                    kernel_size=3,
                    name=None):
    """The generator is a U-Network made of a 4-layer encoder
    and a 4-layer decoder. Layer n-i is connected to layer i. 
代码语言:javascript复制
 Arguments:
    input_shape (tuple): input shape
    output_shape (tuple): output shape
    kernel_size (int): kernel size of encoder & decoder layers
    name (string): name assigned to generator model 
代码语言:javascript复制
 Returns:
    generator (Model):
    """ 
代码语言:javascript复制
 inputs = Input(shape=input_shape)
    channels = int(output_shape[-1])
    e1 = encoder_layer(inputs,
                       32,
                       kernel_size=kernel_size,
                       activation='leaky_relu',
                       strides=1)
    e2 = encoder_layer(e1,
                       64,
                       activation='leaky_relu',
                       kernel_size=kernel_size)
    e3 = encoder_layer(e2,
                       128,
                       activation='leaky_relu',
                       kernel_size=kernel_size)
    e4 = encoder_layer(e3,
                       256,
                       activation='leaky_relu',
                       kernel_size=kernel_size) 
代码语言:javascript复制
 d1 = decoder_layer(e4,
                       e3,
                       128,
                       kernel_size=kernel_size)
    d2 = decoder_layer(d1,
                       e2,
                       64,
                       kernel_size=kernel_size)
    d3 = decoder_layer(d2,
                       e1,
                       32,
                       kernel_size=kernel_size)
    outputs = Conv2DTranspose(channels,
                              kernel_size=kernel_size,
                              strides=1,
                              activation='sigmoid',
                              padding='same')(d3) 
代码语言:javascript复制
 generator = Model(inputs, outputs, name=name) 
代码语言:javascript复制
 return generator 

CycleGAN 的判别器类似于原始 GAN 判别器。 输入图像被下采样数次(在此示例中为 3 次)。 最后一层是Dense(1)层,它预测输入为实数的可能性。 除了不使用 IN 之外,每个层都类似于生成器的编码器层。 然而,在大图像中,用一个数字将图像计算为真实图像或伪图像会导致参数效率低下,并导致生成器的图像质量较差。

解决方案是使用 PatchGAN [6],该方法将图像划分为补丁网格,并使用标量值网格来预测补丁是真实概率。“图 7.1.11”显示了原始 GAN 判别器和2 x 2 PatchGAN 判别器之间的比较:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-89c04LPV-1681704311659)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_11.png)]

图 7.1.11:GAN 与 PatchGAN 判别器的比较

在此示例中,面片不重叠且在其边界处相遇。 但是,通常,补丁可能会重叠。

我们应该注意,PatchGAN 并没有在 CycleGAN 中引入一种新型的 GAN。 为了提高生成的图像质量,如果使用2 x 2 PatchGAN,则没有四个输出可以区分,而没有一个输出可以区分。 损失函数没有变化。 从直觉上讲,这是有道理的,因为如果图像的每个面片或部分看起来都是真实的,则整个图像看起来会更加真实。

“图 7.1.12”显示了tf.keras中实现的判别器网络。 下图显示了判别器确定输入图像或色块为彩色 CIFAR10 图像的可能性:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aQ9IgfTU-1681704311659)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_12.png)]

图 7.1.12:目标标识符D[y]tf.keras中的实现。 PatchGAN 判别器显示在右侧

由于输出图像只有32 x 32 RGB 时较小,因此表示该图像是真实的单个标量就足够了。 但是,当使用 PatchGAN 时,我们也会评估结果。“列表 7.1.3”显示了判别器的函数构建器:

“列表 7.1.3”:cyclegan-7.1.1.py

tf.keras中的判别器实现:

代码语言:javascript复制
def build_discriminator(input_shape,
                        kernel_size=3,
                        patchgan=True,
                        name=None):
    """The discriminator is a 4-layer encoder that outputs either
    a 1-dim or a n x n-dim patch of probability that input is real 
代码语言:javascript复制
 Arguments:
    input_shape (tuple): input shape
    kernel_size (int): kernel size of decoder layers
    patchgan (bool): whether the output is a patch 
        or just a 1-dim
    name (string): name assigned to discriminator model 
代码语言:javascript复制
 Returns:
    discriminator (Model):
    """ 
代码语言:javascript复制
 inputs = Input(shape=input_shape)
    x = encoder_layer(inputs,
                      32,
                      kernel_size=kernel_size,
                      activation='leaky_relu',
                      instance_norm=False)
    x = encoder_layer(x,
                      64,
                      kernel_size=kernel_size,
                      activation='leaky_relu',
                      instance_norm=False)
    x = encoder_layer(x,
                      128,
                      kernel_size=kernel_size,
                      activation='leaky_relu',
                      instance_norm=False)
    x = encoder_layer(x,
                      256,
                      kernel_size=kernel_size,
                      strides=1,
                      activation='leaky_relu',
                      instance_norm=False) 
代码语言:javascript复制
 # if patchgan=True use nxn-dim output of probability
    # else use 1-dim output of probability
    if patchgan:
        x = LeakyReLU(alpha=0.2)(x)
        outputs = Conv2D(1,
                         kernel_size=kernel_size,
                         strides=2,
                         padding='same')(x)
    else:
        x = Flatten()(x)
        x = Dense(1)(x)
        outputs = Activation('linear')(x) 
代码语言:javascript复制
 discriminator = Model(inputs, outputs, name=name) 
代码语言:javascript复制
 return discriminator 

使用生成器和判别器生成器,我们现在可以构建 CycleGAN。“列表 7.1.4”显示了构建器函数。 与上一节中的讨论一致,实例化了两个生成器g_source = Fg_target = G以及两个判别器d_source = D[x]d_target = D[y]。 正向循环为x' = F(G(x)) = reco_source = g_source(g_target(source_input))。反向循环为y' = G(F(y)) = reco_target = g_target(g_source (target_input))

对抗模型的输入是源数据和目标数据,而输出是D[x]D[y]的输出以及重构的输入x'y'。 在本示例中,由于由于灰度图像和彩色图像中通道数之间的差异,因此未使用身份网络。 对于 GAN 和循环一致性损失,我们分别使用建议的λ1 = 1.0λ2 = 10.0损失权重。 与前几章中的 GAN 相似,我们使用 RMSprop 作为判别器的优化器,其学习率为2e-4,衰减率为6e-8。 对抗的学习率和衰退率是判别器的一半。

“列表 7.1.4”:cyclegan-7.1.1.py

tf.keras中的 CycleGAN 构建器:

代码语言:javascript复制
def build_cyclegan(shapes,
                   source_name='source',
                   target_name='target',
                   kernel_size=3,
                   patchgan=False,
                   identity=False
                   ):
    """Build the CycleGAN 
代码语言:javascript复制
 1) Build target and source discriminators
    2) Build target and source generators
    3) Build the adversarial network 
代码语言:javascript复制
 Arguments:
    shapes (tuple): source and target shapes
    source_name (string): string to be appended on dis/gen models
    target_name (string): string to be appended on dis/gen models
    kernel_size (int): kernel size for the encoder/decoder
        or dis/gen models
    patchgan (bool): whether to use patchgan on discriminator
    identity (bool): whether to use identity loss 
代码语言:javascript复制
 Returns:
    (list): 2 generator, 2 discriminator, 
        and 1 adversarial models 
    """ 
代码语言:javascript复制
 source_shape, target_shape = shapes
    lr = 2e-4
    decay = 6e-8
    gt_name = "gen_"   target_name
    gs_name = "gen_"   source_name
    dt_name = "dis_"   target_name
    ds_name = "dis_"   source_name 
代码语言:javascript复制
 # build target and source generators
    g_target = build_generator(source_shape,
                               target_shape,
                               kernel_size=kernel_size,
                               name=gt_name)
    g_source = build_generator(target_shape,
                               source_shape,
                               kernel_size=kernel_size,
                               name=gs_name)
    print('---- TARGET GENERATOR ----')
    g_target.summary()
    print('---- SOURCE GENERATOR ----')
    g_source.summary() 
代码语言:javascript复制
 # build target and source discriminators
    d_target = build_discriminator(target_shape,
                                   patchgan=patchgan,
                                   kernel_size=kernel_size,
                                   name=dt_name)
    d_source = build_discriminator(source_shape,
                                   patchgan=patchgan,
                                   kernel_size=kernel_size,
                                   name=ds_name)
    print('---- TARGET DISCRIMINATOR ----')
    d_target.summary()
    print('---- SOURCE DISCRIMINATOR ----')
    d_source.summary() 
代码语言:javascript复制
 optimizer = RMSprop(lr=lr, decay=decay)
    d_target.compile(loss='mse',
                     optimizer=optimizer,
                     metrics=['accuracy'])
    d_source.compile(loss='mse',
                     optimizer=optimizer,
                     metrics=['accuracy']) 
代码语言:javascript复制
 d_target.trainable = False
    d_source.trainable = False 
代码语言:javascript复制
 # build the computational graph for the adversarial model
    # forward cycle network and target discriminator
    source_input = Input(shape=source_shape)
    fake_target = g_target(source_input)
    preal_target = d_target(fake_target)
    reco_source = g_source(fake_target) 
代码语言:javascript复制
 # backward cycle network and source discriminator
    target_input = Input(shape=target_shape)
    fake_source = g_source(target_input)
    preal_source = d_source(fake_source)
    reco_target = g_target(fake_source) 
代码语言:javascript复制
 # if we use identity loss, add 2 extra loss terms
    # and outputs
    if identity:
        iden_source = g_source(source_input)
        iden_target = g_target(target_input)
        loss = ['mse', 'mse', 'mae', 'mae', 'mae', 'mae']
        loss_weights = [1., 1., 10., 10., 0.5, 0.5]
        inputs = [source_input, target_input]
        outputs = [preal_source,
                   preal_target,
                   reco_source,
                   reco_target,
                   iden_source,
                   iden_target]
    else:
        loss = ['mse', 'mse', 'mae', 'mae']
        loss_weights = [1., 1., 10., 10.]
        inputs = [source_input, target_input]
        outputs = [preal_source,
                   preal_target,
                   reco_source,
                   reco_target] 
代码语言:javascript复制
 # build adversarial model
    adv = Model(inputs, outputs, name='adversarial')
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    adv.compile(loss=loss,
                loss_weights=loss_weights,
                optimizer=optimizer,
                metrics=['accuracy'])
    print('---- ADVERSARIAL NETWORK ----')
    adv.summary() 
代码语言:javascript复制
 return g_source, g_target, d_source, d_target, adv 

我们遵循训练过程,我们可以从上一节中的“算法 7.1.1”中调用。“列表 7.1.5”显示了 CycleGAN 训练。 此训练与原始 GAN 之间的次要区别是有两个要优化的判别器。 但是,只有一种对抗模型需要优化。 对于每 2,000 步,生成器将保存预测的源图像和目标图像。 我们将的批量大小设为 32。我们也尝试了 1 的批量大小,但是输出质量几乎相同,并且需要花费更长的时间进行训练(批量为每个图像 43 ms,在 NVIDIA GTX 1060 上批量大小为 32 时,最大大小为每个图像 1 vs 3.6 ms)

“列表 7.1.5”:cyclegan-7.1.1.py

tf.keras中的 CycleGAN 训练例程:

代码语言:javascript复制
def train_cyclegan(models,
                   data,
                   params,
                   test_params,
                   test_generator):
    """ Trains the CycleGAN. 

    1) Train the target discriminator
    2) Train the source discriminator
    3) Train the forward and backward cyles of 
        adversarial networks 
代码语言:javascript复制
 Arguments:
    models (Models): Source/Target Discriminator/Generator,
        Adversarial Model
    data (tuple): source and target training data
    params (tuple): network parameters
    test_params (tuple): test parameters
    test_generator (function): used for generating 
        predicted target and source images
    """ 
代码语言:javascript复制
 # the models
    g_source, g_target, d_source, d_target, adv = models
    # network parameters
    batch_size, train_steps, patch, model_name = params
    # train dataset
    source_data, target_data, test_source_data, test_target_data
            = data 
代码语言:javascript复制
 titles, dirs = test_params 
代码语言:javascript复制
 # the generator image is saved every 2000 steps
    save_interval = 2000
    target_size = target_data.shape[0]
    source_size = source_data.shape[0] 
代码语言:javascript复制
 # whether to use patchgan or not
    if patch > 1:
        d_patch = (patch, patch, 1)
        valid = np.ones((batch_size,)   d_patch)
        fake = np.zeros((batch_size,)   d_patch)
    else:
        valid = np.ones([batch_size, 1])
        fake = np.zeros([batch_size, 1]) 
代码语言:javascript复制
 valid_fake = np.concatenate((valid, fake))
    start_time = datetime.datetime.now() 
代码语言:javascript复制
 for step in range(train_steps):
        # sample a batch of real target data
        rand_indexes = np.random.randint(0,
                                         target_size,
                                         size=batch_size)
        real_target = target_data[rand_indexes] 
代码语言:javascript复制
 # sample a batch of real source data
        rand_indexes = np.random.randint(0,
                                         source_size,
                                         size=batch_size)
        real_source = source_data[rand_indexes]
        # generate a batch of fake target data fr real source data
        fake_target = g_target.predict(real_source) 
代码语言:javascript复制
 # combine real and fake into one batch
        x = np.concatenate((real_target, fake_target))
        # train the target discriminator using fake/real data
        metrics = d_target.train_on_batch(x, valid_fake)
        log = "%d: [d_target loss: %f]" % (step, metrics[0]) 
代码语言:javascript复制
 # generate a batch of fake source data fr real target data
        fake_source = g_source.predict(real_target)
        x = np.concatenate((real_source, fake_source))
        # train the source discriminator using fake/real data
        metrics = d_source.train_on_batch(x, valid_fake)
        log = "%s [d_source loss: %f]" % (log, metrics[0]) 
代码语言:javascript复制
 # train the adversarial network using forward and backward
        # cycles. the generated fake source and target 
        # data attempts to trick the discriminators
        x = [real_source, real_target]
        y = [valid, valid, real_source, real_target]
        metrics = adv.train_on_batch(x, y)
        elapsed_time = datetime.datetime.now() - start_time
        fmt = "%s [adv loss: %f] [time: %s]"
        log = fmt % (log, metrics[0], elapsed_time)
        print(log)
        if (step   1) % save_interval == 0:
            test_generator((g_source, g_target),
                           (test_source_data, test_target_data),
                           step=step 1,
                           titles=titles,
                           dirs=dirs,
                           show=False) 
代码语言:javascript复制
 # save the models after training the generators
    g_source.save(model_name   "-g_source.h5")
    g_target.save(model_name   "-g_target.h5") 

最后,在使用 CycleGAN 构建和训练函数之前,我们必须执行一些数据准备。 模块cifar10_utils.pyother_ utils.py加载CIFAR10训练和测试数据。 有关这两个文件的详细信息,请参考源代码。 加载后,将训练图像和测试图像转换为灰度,以生成源数据和测试源数据。

“列表 7.1.6”显示了 CycleGAN 如何用于构建和训练用于灰度图像着色的生成器网络(g_target)。 由于 CycleGAN 是对称的,因此我们还构建并训练了第二个生成器网络(g_source),该网络可以将颜色转换为灰度。 训练了两个 CycleGAN 着色网络。 第一种使用标量输出类似于原始 GAN 的判别器,第二种使用2 x 2 PatchGAN。

“列表 7.1.6”:cyclegan-7.1.1.py

CycleGAN 用于着色:

代码语言:javascript复制
def graycifar10_cross_colorcifar10(g_models=None):
    """Build and train a CycleGAN that can do
        grayscale <--> color cifar10 images
    """ 
代码语言:javascript复制
 model_name = 'cyclegan_cifar10'
    batch_size = 32
    train_steps = 100000
    patchgan = True
    kernel_size = 3
    postfix = ('%dp' % kernel_size) 
            if patchgan else ('%d' % kernel_size) 
代码语言:javascript复制
 data, shapes = cifar10_utils.load_data()
    source_data, _, test_source_data, test_target_data = data
    titles = ('CIFAR10 predicted source images.',
              'CIFAR10 predicted target images.',
              'CIFAR10 reconstructed source images.',
              'CIFAR10 reconstructed target images.')
    dirs = ('cifar10_source-%s' % postfix, 
            'cifar10_target-%s' % postfix) 
代码语言:javascript复制
 # generate predicted target(color) and source(gray) images
    if g_models is not None:
        g_source, g_target = g_models
        other_utils.test_generator((g_source, g_target),
                                   (test_source_data, 
                                           test_target_data),
                                   step=0,
                                   titles=titles,
                                   dirs=dirs,
                                   show=True)
        return 
代码语言:javascript复制
 # build the cyclegan for cifar10 colorization
    models = build_cyclegan(shapes,
                            "gray-%s" % postfix,
                            "color-%s" % postfix,
                            kernel_size=kernel_size,
                            patchgan=patchgan)
    # patch size is divided by 2^n since we downscaled the input
    # in the discriminator by 2^n (ie. we use strides=2 n times)
    patch = int(source_data.shape[1] / 2**4) if patchgan else 1
    params = (batch_size, train_steps, patch, model_name)
    test_params = (titles, dirs)
    # train the cyclegan
    train_cyclegan(models,
                   data,
                   params,
                   test_params,
                   other_utils.test_generator) 

在的下一部分中,我们将检查 CycleGAN 的生成器输出以进行着色。

CycleGAN 的生成器输出

“图 7.1.13”显示 CycleGAN 的着色结果。 源图像来自测试数据集:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UdAhE17z-1681704311659)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_13.png)]

图 7.1.13:使用不同技术进行着色。 显示的是基本事实,使用自编码器的着色(第 3 章,自编码器),使用带有原始 GAN 判别器的 CycleGAN 进行着色,以及使用带有 PatchGAN 判别器的 CycleGAN 进行着色。 彩色效果最佳。 原始彩色照片可以在该书的 GitHub 存储库中找到。

为了进行比较,我们使用第 3 章,“自编码器”中描述的普通自编码器显示了地面真实情况和着色结果。 通常,所有彩色图像在感觉上都是可接受的。 总体而言,似乎每种着色技术都有自己的优点和缺点。 所有着色方法与天空和车辆的正确颜色不一致。

例如,平面背景(第三行,第二列)中的天空为白色。 自编码器没错,但是 CycleGAN 认为它是浅棕色或蓝色。

对于第六行第六列,暗海上的船天空阴沉,但自编码器将其涂成蓝色和蓝色,而 CycleGAN 将其涂成蓝色和白色,而没有 PatchGAN。 两种预测在现实世界中都是有意义的。 同时,使用 PatchGAN 对 CycleGAN 的预测与基本事实相似。 在倒数第二行和第二列上,没有方法能够预测汽车的红色。 在动物身上,CycleGAN 的两种口味都具有接近真实情况的颜色。

由于 CycleGAN 是对称的,因此它还能在给定彩色图像的情况下预测灰度图像。“图 7.1.14”显示了两个 CycleGAN 变体执行的颜色到灰度转换。 目标图像来自测试数据集。 除了某些图像的灰度阴影存在细微差异外,这些预测通常是准确的。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9v2eWewg-1681704311659)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_14.png)]

图 7.1.14:颜色(来自图 7.1.9)到 CycleGAN 的灰度转换

要训​​练 CycleGAN 进行着色,命令是:

代码语言:javascript复制
python3 cyclegan-7.1.1.py -c 

读者可以使用带有 PatchGAN 的 CycleGAN 预训练模型来运行图像转换:

代码语言:javascript复制
python3 cyclegan-7.1.1.py --cifar10_g_source=cyclegan_cifar10-g_source.h5
--cifar10_g_target=cyclegan_cifar10-g_target.h5 

在本节中,我们演示了 CycleGAN 在着色上的一种实际应用。 在下一部分中,我们将在更具挑战性的数据集上训练 CycleGAN。 源域 MNIST 与目标域 SVHN 数据集有很大的不同[1]。

MNIST 和 SVHN 数据集上的 CycleGAN

我们现在要解决一个更具挑战性的问题。 假设我们使用 MNIST 灰度数字作为源数据,并且我们想从 SVHN [1]中借鉴样式,这是我们的目标数据。 每个域中的样本数据显示在“图 7.1.15”中:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fkmEFbTV-1681704311660)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_15.png)]

图 7.1.15:两个未对齐数据的不同域。 原始彩色照片可以在该书的 GitHub 存储库中找到。

我们可以重用上一节中讨论的 CycleGAN 的所有构建和训练函数,以执行样式迁移。 唯一的区别是,我们必须添加用于加载 MNIST 和 SVHN 数据的例程。 SVHN 数据集可在这个页面中找到。

我们介绍mnist_svhn_utils.py模块来帮助我们完成此任务。“列表 7.1.7”显示了针对跨域迁移的 CycleGAN 的初始化和训练。

CycleGAN 结构与上一部分相同,不同之处在于我们使用的核大小为 5,因为两个域完全不同。

“列表 7.1.7”:cyclegan-7.1.1.py

CycleGAN 用于 MNIST 和 SVHN 之间的跨域样式迁移:

代码语言:javascript复制
def mnist_cross_svhn(g_models=None):
    """Build and train a CycleGAN that can do mnist <--> svhn
    """ 
代码语言:javascript复制
 model_name = 'cyclegan_mnist_svhn'
    batch_size = 32
    train_steps = 100000
    patchgan = True
    kernel_size = 5
    postfix = ('%dp' % kernel_size) 
            if patchgan else ('%d' % kernel_size) 
代码语言:javascript复制
 data, shapes = mnist_svhn_utils.load_data()
    source_data, _, test_source_data, test_target_data = data
    titles = ('MNIST predicted source images.',
              'SVHN predicted target images.',
              'MNIST reconstructed source images.',
              'SVHN reconstructed target images.')
    dirs = ('mnist_source-%s' 
            % postfix, 'svhn_target-%s' % postfix) 
代码语言:javascript复制
 # generate predicted target(svhn) and source(mnist) images
    if g_models is not None:
        g_source, g_target = g_models
        other_utils.test_generator((g_source, g_target),
                                   (test_source_data, 
                                           test_target_data),
                                   step=0,
                                   titles=titles,
                                   dirs=dirs,
                                   show=True)
        return 
代码语言:javascript复制
 # build the cyclegan for mnist cross svhn
    models = build_cyclegan(shapes,
                            "mnist-%s" % postfix,
                            "svhn-%s" % postfix,
                            kernel_size=kernel_size,
                            patchgan=patchgan)
    # patch size is divided by 2^n since we downscaled the input
    # in the discriminator by 2^n (ie. we use strides=2 n times)
    patch = int(source_data.shape[1] / 2**4) if patchgan else 1
    params = (batch_size, train_steps, patch, model_name)
    test_params = (titles, dirs)
    # train the cyclegan
    train_cyclegan(models,
                   data,
                   params,
                   test_params,
                   other_utils.test_generator) 

将 MNIST 从测试数据集迁移到 SVHN 的结果显示在“图 7.1.16”中。 生成的图像具有样式的 SVHN,但是数字未完全传送。 例如,在第四行上,数字 3、1 和 3 由 CycleGAN 进行样式化。

但是,在第三行中,不带有和带有 PatchGAN 的 CycleGAN 的数字 9、6 和 6 分别设置为 0、6、01、0、65 和 68:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MMIvwZQw-1681704311660)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_16.png)]

图 7.1.16:测试数据从 MNIST 域到 SVHN 的样式迁移。 原始彩色照片可以在该书的 GitHub 存储库中找到。

向后循环的结果为“图 7.1.17”中所示的。 在这种情况下,目标图像来自 SVHN 测试数据集。 生成的图像具有 MNIST 的样式,但是数字没有正确翻译。 例如,在第一行中,对于不带和带有 PatchGAN 的 CycleGAN,数字 5、2 和 210 分别被样式化为 7、7、8、3、3 和 1:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZTSndytm-1681704311660)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_17.png)]

图 7.1.17:测试数据从 SVHN 域到 MNIST 的样式迁移。 原始彩色照片可以在该书的 GitHub 存储库中找到。

在 PatchGAN 的情况下,假设预测的 MNIST 数字被限制为一位,则输出 1 是可以理解的。 有以某种方式正确的预测,例如在第二行中,不使用 PatchGAN 的 CycleGAN 将 SVHN 数字的最后三列 6、3 和 4 转换为 6、3 和 6。 但是,CycleGAN 两种版本的输出始终是个位数且可识别。

从 MNIST 到 SVHN 的转换中出现的问题称为“标签翻转”[8],其中源域中的数字转换为目标域中的另一个数字。 尽管 CycleGAN 的预测是周期一致的,但它们不一定是语义一致的。 在翻译过程中数字的含义会丢失。

为了解决这个问题, Hoffman [8]引入了一种改进的 CycleGAN,称为循环一致性对抗域自适应CyCADA)。 不同之处在于,附加的语义损失项可确保预测不仅周期一致,而且语义一致。

“图 7.1.18”显示 CycleGAN 在正向循环中重建 MNIST 数字。 重建的 MNIST 数字几乎与源 MNIST 数字相同:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gGwbiPvZ-1681704311660)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_18.png)]

图 7.1.18:带有 MNIST 上的 PatchGAN 的 CycleGAN(源)到 SVHN(目标)的前向周期。 重建的源类似于原始源。 原始彩色照片可以在该书的 GitHub 存储库中找到。

“图 7.1.19”显示了 CycleGAN 在向后周期中重构 SVHN 数字的过程:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bvUxt1if-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_19.png)]

图 7.1.19:带有 MNIST 上的 PatchGAN 的 CycleGAN 与 SVHN(目标)的反向循环。 重建的目标与原始目标并不完全相似。 原始彩色照片可以在该书的 GitHub 存储库中找到。

在“图 7.1.3”中,CycleGAN 被描述为具有周期一致性。 换句话说,给定源x,CycleGAN 将正向循环中的源重构为x'。 另外,在给定目标y的情况下,CycleGAN 在反向循环中将目标重构为y'

重建了许多目标图像。 有些数字显然是相同的,例如最后两列(3 和 4)中的第二行,而有些数字却是相同的但是模糊的,例如前两列列中的第一行(5 和 2)。 尽管样式仍像第二行一样,但在前两列(从 33 和 6 到 1 以及无法识别的数字)中,有些数字会转换为另一数字。

要将 MNIST 的 CycleGAN 训练为 SVHN,命令为:

代码语言:javascript复制
python3 cyclegan-7.1.1.py -m 

鼓励读者使用带有 PatchGAN 的 CycleGAN 预训练模型来运行图像翻译:

代码语言:javascript复制
python3 cyclegan-7.1.1.py --mnist_svhn_g_source=cyclegan_mnist_svhn-g_ source.h5 --mnist_svhn_g_target=cyclegan_mnist_svhn-g_target.h5 

到目前为止,我们只看到了 CycleGAN 的两个实际应用。 两者都在小型数据集上进行了演示,以强调可重复性的概念。 如本章前面所述,CycleGAN 还有许多其他实际应用。 我们在这里介绍的 CycleGAN 可以作为分辨率更高的图像转换的基础。

2. 总结

在本章中,我们讨论了 CycleGAN 作为可用于图像翻译的算法。 在 CycleGAN 中,源数据和目标数据不一定要对齐。 我们展示了两个示例,灰度 ↔ 颜色MNIST ↔ SVHN ,尽管 CycleGAN 可以执行许多其他可能的图像转换 。

在下一章中,我们将着手另一种生成模型,即变分自编码器VAE)。 VAE 具有类似的学习目标–如何生成新图像(数据)。 他们专注于学习建模为高斯分布的潜在向量。 我们将以有条件的 VAE 和解开 VAE 中的潜在表示形式来证明 GAN 解决的问题中的其他相似之处。

3. 参考

  1. Yuval Netzer et al.: Reading Digits in Natural Images with Unsupervised Feature Learning. NIPS workshop on deep learning and unsupervised feature learning. Vol. 2011. No. 2. 2011 (https://www-cs.stanford.edu/~twangcat/papers/nips2011_housenumbers.pdf).
  2. Zhu-Jun-Yan et al.: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. 2017 IEEE International Conference on Computer Vision (ICCV). IEEE, 2017 (http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf).
  3. Phillip Isola et al.: Image-to-Image Translation with Conditional Adversarial Networks. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2017 (http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf).
  4. Mehdi Mirza and Simon Osindero. Conditional Generative Adversarial Nets. arXiv preprint arXiv:1411.1784, 2014 (https://arxiv.org/pdf/1411.1784.pdf).
  5. Xudong Mao et al.: Least Squares Generative Adversarial Networks. 2017 IEEE International Conference on Computer Vision (ICCV). IEEE, 2017 (http://openaccess.thecvf.com/content_ICCV_2017/papers/Mao_Least_Squares_Generative_ICCV_2017_paper.pdf).
  6. Chuan Li and Michael Wand. Precomputed Real-Time Texture Synthesis with Markovian Generative Adversarial Networks. European Conference on Computer Vision. Springer, Cham, 2016 (https://arxiv.org/pdf/1604.04382.pdf).
  7. Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-Net: Convolutional Networks for Biomedical Image Segmentation. International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015 (https://arxiv.org/pdf/1505.04597.pdf).
  8. Judy Hoffman et al.: CyCADA: Cycle-Consistent Adversarial Domain Adaptation. arXiv preprint arXiv:1711.03213, 2017 (https://arxiv.org/pdf/1711.03213.pdf).

八、变分自编码器(VAE)

与我们在之前的章节中讨论过的生成对抗网络GAN)类似,变分自编码器VAE)[1] 属于生成模型家族。 VAE 的生成器能够在导航其连续潜在空间的同时产生有意义的输出。 通过潜向量探索解码器输出的可能属性。

在 GAN 中,重点在于如何得出近似输入分布的模型。 VAE 尝试对可解码的连续潜在空间中的输入分布进行建模。 这是 GAN 与 VAE 相比能够生成更真实信号的可能的潜在原因之一。 例如,在图像生成中,GAN 可以生成看起来更逼真的图像,而相比之下,VAE 生成的图像清晰度较差。

在 VAE 中,重点在于潜在代码的变分推理。 因此,VAE 为潜在变量的学习和有效贝叶斯推理提供了合适的框架。 例如,带有解缠结表示的 VAE 可以将潜在代码重用于迁移学习。

在结构上,VAE 与自编码器相似。 它也由编码器(也称为识别或推理模型)和解码器(也称为生成模型)组成。 VAE 和自编码器都试图在学习潜向量的同时重建输入数据。

但是,与自编码器不同,VAE 的潜在空间是连续的,并且解码器本身被用作生成模型。

在前面各章中讨论的 GAN 讨论中,也可以对 VAE 的解码器进行调整。 例如,在 MNIST 数据集中,我们能够指定一个给定的单热向量产生的数字。 这种有条件的 VAE 类别称为 CVAE [2]。 也可以通过在损失函数中包含正则化超参数来解开 VAE 潜向量。 这称为 β-VAE [5]。 例如,在 MNIST 中,我们能够隔离确定每个数字的粗细或倾斜角度的潜向量。 本章的目的是介绍:

  • VAE 的原理
  • 了解重新参数化技巧,有助于在 VAE 优化中使用随机梯度下降
  • 有条件的 VAE(CVAE)和 β-VAE 的原理
  • 了解如何使用tf.keras实现 VAE

我们将从谈论 VAE 的基本原理开始。

1. VAE 原理

在生成模型中,我们经常对使用神经网络来逼近输入的真实分布感兴趣:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rVHL8wEO-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_003.png)] (Equation 8.1.1)

在前面的等式中,θ表示训练期间确定的参数。 例如,在名人面孔数据集的上下文中,这等效于找到可以绘制面孔的分布。 同样,在 MNIST 数据集中,此分布可以生成可识别的手写数字。

在机器学习中,为了执行特定级别的推理,我们有兴趣寻找P[θ](x, z),这是输入x和潜在变量z之间的联合分布。 潜在变量不是数据集的一部分,而是对可从输入中观察到的某些属性进行编码。 在名人面孔的背景下,这些可能是面部表情,发型,头发颜色,性别等。 在 MNIST 数据集中,潜在变量可以表示数字和书写样式。

P[θ](x, z)实际上是输入数据点及其属性的分布。 P[θ](x)可以从边际分布计算得出:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8ycC5rg0-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_010.png)] (Equation 8.1.2)

换句话说,考虑所有可能的属性,我们最终得到描述输入的分布。 在名人面孔中,如果考虑所有面部表情,发型,头发颜色和性别,将恢复描述名人面孔的分布。 在 MNIST 数据集中,如果考虑所有可能的数字,书写风格等,我们以手写数字的分布来结束。

问题在于“公式 8.1.2”很难处理。 该方程式没有解析形式或有效的估计量。 它的参数无法微分。 因此,通过神经网络进行优化是不可行的。

使用贝叶斯定理,我们可以找到“公式 8.1.2”的替代表达式:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Z7xXxJa2-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_011.png)] (Equation 8.1.3)

P(z)z的先验分布。 它不以任何观察为条件。 如果z是离散的,而P[θ](x | z)是高斯分布,则P[θ](x)是高斯的混合。 如果z是连续的,则P[θ](x)是高斯的无限混合。

实际上,如果我们尝试在没有合适的损失函数的情况下建立一个近似P[θ](x | z)的神经网络,它将忽略z得出一个简单的解P[θ](x | z) = P[θ](x)。 因此,“公式 8.1.3”无法为我们提供P[θ](x)的良好估计。 或者,“公式 8.1.2”也可以表示为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IGZWILHU-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_024.png)] (Equation 8.1.4)

但是,P[θ](z | x)也很棘手。 VAE 的目标是在给定输入的情况下,找到一种可预测的分布,该分布易于估计P[θ](z | x),即潜在属性z的条件分布的估计。

变分推理

为了使易于处理,VAE 引入了变化推理模型(编码器):

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-iVjUbONa-1681704311662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_030.png)] (Equation 8.1.5)

Q[φ](z | x)提供了P[θ](z | x)的良好估计。 它既参数化又易于处理。 Q[φ](z | x)可以通过优化参数φ由深度神经网络近似。 通常,Q[φ](z | x)被选择为多元高斯:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jroqc1Kl-1681704311662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_036.png)] (Equation 8.1.6)

均值μ(x)和标准差σ(x)均由编码器神经网络使用输入数据点计算得出。 对角线矩阵表示z的元素是独立的。

在下一节中,我们将求解 VAE 的核心方程。 核心方程式将引导我们找到一种优化算法,该算法将帮助我们确定推理模型的参数。

核心方程

推理模型Q[φ](z | x)从输入x生成潜向量zQ[φ](z | x)似于自编码器模型中的编码器。 另一方面,从潜在代码z重构输入。 P[θ](x | z)的作用类似于自编码器模型中的解码器。 要估计P[θ](x),我们必须确定其与Q[φ](z | x)P[θ](x | z)的关系。

如果Q[φ](z | x)P[θ](z | x)的估计值,则 Kullback-LeiblerKL)的差异决定了这两个条件密度之间的距离:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uHJuQakp-1681704311662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_052.png)] (Equation 8.1.7)

使用贝叶斯定理:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XwjhoFzr-1681704311662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_053.png)] (Equation 8.1.8)

在“公式 8.1.7”中:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zVLcRvhb-1681704311663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_054.png)] (Equation 8.1.9)

由于log P[θ](x)不依赖于z ~ Q,因此可能会超出预期。 重新排列“公式 8.1.9”并认识到:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lDcfbONN-1681704311663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_057.png)],其结果是:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5VjPOOmf-1681704311663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_058.png)] (Equation 8.1.10)

“公式 8.1.10”是 VAE 的核心。 左侧是项P[θ](x),由于Q[φ](z | x)与真实P[θ](z | x)的距离,我们使误差最小化。 我们可以记得,的对数不会更改最大值(或最小值)的位置。 给定提供P[θ](z | x)良好估计的推断模型,D[KL](Q[φ](z | x) || P[θ](z | x)大约为零。

右边的第一项P[θ](x | z)类似于解码器,该解码器从推理模型中抽取样本以重建输入。

第二个项是另一个距离。 这次是在Q[φ](z | x)和先前的P[θ](z)之间。 “公式 8.1.10”的左侧也称为变异下界证据下界ELBO)。 由于 KL 始终为正,因此 ELBO 是log P[θ](x)的下限。 通过优化神经网络的参数φθ来最大化 ELBO 意味着:

  • 在将z中的x属性编码时,D[KL](Q[φ](z | x) || P[θ](z | x) -> 0或推理模型变得更好。
  • 右边的log P[θ](x | z)最大化了“公式 8.1.10”或解码器模型在从潜在向量z重构x方面变得更好。
  • 在下一节中,我们将利用“公式 8.1.10”的结构来确定推理模型(编码器)和解码器的损失函数。

优化

“公式 8.1.10”的右侧具有有关 VAE 的loss函数的两个重要信息。 解码器项E[z~Q] [log P[θ](x | z)]表示生成器从推理模型的输出中提取z个样本,以重建输入。 使最大化是指我们将重构损失L_R降到最低。 如果假设图像(数据)分布为高斯分布,则可以使用 MSE。

如果每个像素(数据)都被认为是伯努利分布,那么损失函数就是二进制互熵。

第二项-D[KL](Q[φ](z | x) || P[θ](z))易于评估。 根据“公式 8.1.6”,Q[φ]是高斯分布。 通常,P[θ](z) = P(z) = N(0, 1)也是平均值为零且标准差等于 1.0 的高斯。 在“公式 8.1.11”中,我们看到 KL 项简化为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pjRfealv-1681704311663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_082.png)] (Equation 8.1.11)

其中Jz的维。 μ[j]σ[j]都是通过推理模型计算的x的函数。 要最大化:-D[KL]σ[j] -> 1μ[j] -> 9P(z) = N(0, 1)的选择源于各向同性单位高斯的性质,在具有适当函数的情况下,它可以变形为任意分布[6]。

根据“公式 8.1.11”,KL 损失L_KL简称为D[KL]

总之,在“公式 8.1.12”中将 VAE loss函数定义为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dQ91SFSk-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_094.png)] (Equation 8.1.12)

在给定编码器和解码器模型的情况下,在我们可以构建和训练 VAE(随机采样块,生成潜在属性)之前,还需要解决一个问题。 在下一节中,我们将讨论此问题以及如何使用重新参数化技巧解决它。

重新参数化技巧

“图 8.1.1”的左侧显示了 VAE 网络。 编码器获取输入x,并估计潜向量z的多元高斯分布的平均值μ和标准差σ。 解码器从潜向量z中提取样本,以将输入重构为x_tilde。 这似乎很简单,直到在反向传播期间发生梯度更新为止:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RHX1ocUI-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_01.png)]

图 8.1.1:带有和不带有重新参数化技巧的 VAE 网络

反向传播梯度将不会通过随机采样块。 尽管具有用于神经网络的随机输入是可以的,但梯度不可能穿过随机层。

解决此问题的方法是将采样处理作为输入,如“图 8.1.1”右侧所示。 然后,将样本计算为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-h61Cc07u-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_101.png)] (Equation 8.1.13)

如果εσ以向量格式表示,则εσ是逐元素乘法。 使用“公式 8.1.13”,看起来好像采样直接来自潜在空间一样。 这项技术被称为重新参数化技巧

现在,在输入端发生采样时,可以使用熟悉的优化算法(例如 SGD,Adam 或 RMSProp)来训练 VAE 网络。

在讨论如何在tf.keras中实现 VAE 之前,让我们首先展示如何测试经过训练的解码器。

解码器测试

在训练了 VAE 网络之后,可以丢弃推理模型,包括加法和乘法运算符。 为了生成新的有意义的输出,请从用于生成ε的高斯分布中抽取样本。“图 8.1.2”向我们展示了解码器的测试设置:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-chhoz0du-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_02.png)]

图 8.1.2:解码器测试设置

通过重新参数化技巧解决了 VAE 上的最后一个问题,我们现在可以在tf.keras中实现和训练变分自编码器。

ALAS 与 Keras

VAE 的结构类似于典型的自编码器。 区别主要在于重新参数化技巧中的高斯随机变量的采样。“列表 8.1.1”显示了使用 MLP 实现的编码器,解码器和 VAE。

此代码也已添加到官方 Keras GitHub 存储库中。

为便于显示潜在代码,将z的维设置为 2。编码器仅是两层 MLP,第二层生成均值和对数方差。 对数方差的使用是为了简化 KL 损失和重新参数化技巧的计算。 编码器的第三个输出是使用重新参数化技巧进行的z采样。 我们应该注意,在采样函数exp(0.5 log σ²) = sqrt(σ²) = σ中,因为σ > 0假定它是高斯分布的标准差。

解码器也是两层 MLP,它采用z的样本来近似输入。 编码器和解码器均使用大小为 512 的中间尺寸。

VAE 网络只是将编码器和解码器连接在一起。 loss函数是重建损失KL 损失的总和。 在默认的 Adam 优化器上,VAE 网络具有良好的效果。 VAE 网络中的参数总数为 807,700。

VAE MLP 的 Keras 代码具有预训练的权重。 要测试,我们需要运行:

代码语言:javascript复制
python3 vae-mlp-mnist-8.1.1.py --weights=vae_mlp_mnist.tf 

完整的代码可以在以下链接中找到。

“列表 8.1.1”:vae-mlp-mnist-8.1.1.py

代码语言:javascript复制
# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# z = z_mean   sqrt(var)*eps
def sampling(args):
    """Reparameterization trick by sampling 
        fr an isotropic unit Gaussian. 
代码语言:javascript复制
 # Arguments:
        args (tensor): mean and log of variance of Q(z|X) 
代码语言:javascript复制
 # Returns:
        z (tensor): sampled latent vector
    """ 
代码语言:javascript复制
 z_mean, z_log_var = args
    # K is the keras backend
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean   K.exp(0.5 * z_log_var) * epsilon 
代码语言:javascript复制
# MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data() 
代码语言:javascript复制
image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255 
代码语言:javascript复制
# network parameters
input_shape = (original_dim, )
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 50 
代码语言:javascript复制
# VAE model = encoder   decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
x = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x) 
代码语言:javascript复制
# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary 
# with the TensorFlow backend
z = Lambda(sampling,
           output_shape=(latent_dim,),
           name='z')([z_mean, z_log_var]) 
代码语言:javascript复制
# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder') 
代码语言:javascript复制
# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
outputs = Dense(original_dim, activation='sigmoid')(x) 
代码语言:javascript复制
# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp') 
代码语言:javascript复制
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load tf model trained weights"
    parser.add_argument("-w", "--weights", help=help_)
    help_ = "Use binary cross entropy instead of mse (default)"
    parser.add_argument("--bce", help=help_, action='store_true')
    args = parser.parse_args()
    models = (encoder, decoder)
    data = (x_test, y_test) 
代码语言:javascript复制
 # VAE loss = mse_loss or xent_loss   kl_loss
    if args.bce:
        reconstruction_loss = binary_crossentropy(inputs,
                                                  outputs)
    else:
        reconstruction_loss = mse(inputs, outputs)

    reconstruction_loss *= original_dim
    kl_loss = 1   z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss   kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam') 

“图 8.1.3”显示了编码器模型,它是一个 MLP,具有两个输出,即潜向量的均值和方差。 lambda 函数实现了重新参数化技巧,将随机潜在代码的采样推送到 VAE 网络之外:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rDkvi3aw-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_03.png)]

图 8.1.3:VAE MLP 的编码器模型

“图 8.1.4”显示了解码器模型。 2 维输入来自 lambda 函数。 输出是重构的 MNIST 数字:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MiTE41SX-1681704311665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_04.png)]

图 8.1.4:VAE MLP 的解码器模型

“图 8.1.5”显示了完整的 VAE 模型。 通过将编码器和解码器模型结合在一起制成:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M5hizOkT-1681704311665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_05.png)]

图 8.1.5:使用 MLP 的 VAE 模型

“图 8.1.6”显示了使用plot_results()在 50 个周期后潜向量的连续空间。 为简单起见,此函数未在此处显示,但可以在vae-mlp-mnist-8.1.1.py的其余代码中找到。 该函数绘制两个图像,即测试数据集标签(“图 8.1.6”)和样本生成的数字(“图 8.1.7”),这两个图像都是z的函数。 这两个图都说明了潜在向量如何确定所生成数字的属性:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Qe9OO9Ve-1681704311665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_06.png)]

图 8.1.6:MNIST 数字标签作为测试数据集(VAE MLP)的潜在向量平均值的函数。 原始图像可以在该书的 GitHub 存储库中找到。

浏览时,连续空格始终会产生与 MNIST 数字相似的输出。 例如,数字 9 的区域接近数字 7 的区域。从中心附近的 9 移动到左下角会将数字变形为 7。从中心向上移动会将生成的数字从 3 更改为 5,最后变为 0.数字的变形在“图 8.1.7”中更明显,这是解释“图 8.1.6”的另一种方式。

在“图 8.1.7”中,显示生成器输出。 显示了潜在空间中数字的分布。 可以观察到所有数字都被表示。 由于中心附近分布密集,因此变化在中间迅速,在平均值较高的区域则缓慢。 我们需要记住,“图 8.1.7”是“图 8.1.6”的反映。 例如,数字 0 在两个图的左上象限中,而数字 1 在右下象限中。

“图 8.1.7”中存在一些无法识别的数字,尤其是在右上象限中。 从“图 8.1.6”可以看出,该区域大部分是空的,并且远离中心:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ykRPrEwz-1681704311665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_07.png)]

图 8.1.7:根据潜在向量平均值(VAE MLP)生成的数字。 为了便于解释,均值的范围类似于图 8.1.6

在本节中,我们演示了如何在 MLP 中实现 VAE。 我们还解释了导航潜在空间的结果。 在的下一部分中,我们将使用 CNN 实现相同的 VAE。

带有 CNN 的 AE

在原始论文《自编码变分贝叶斯》[1]中,使用 MLP 来实现 VAE 网络,这与我们在上一节中介绍的类似。 在本节中,我们将证明使用 CNN 将显着提高所产生数字的质量,并将参数数量显着减少至 134,165。

“列表 8.1.3”显示了编码器,解码器和 VAE 网络。 该代码也被添加到了官方的 Keras GitHub 存储库中。

为简洁起见,不再显示与 MLP VAE 类似的某些代码行。 编码器由两层 CNN 和两层 MLP 组成,以生成潜在代码。 编码器的输出结构与上一节中看到的 MLP 实现类似。 解码器由一层Dense和三层转置的 CNN 组成。

VAE CNN 的 Keras 代码具有预训练的权重。 要测试,我们需要运行:

代码语言:javascript复制
python3 vae-cnn-mnist-8.1.2.py --weights=vae_cnn_mnist.tf 

“列表 8.1.3”:vae-cnn-mnist-8.1.2.py

使用 CNN 层的tf.keras中的 VAE:

代码语言:javascript复制
# network parameters
input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
filters = 16
latent_dim = 2
epochs = 30 
代码语言:javascript复制
# VAE model = encoder   decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
for i in range(2):
    filters *= 2
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,
               padding='same')(x) 
代码语言:javascript复制
# shape info needed to build decoder model
shape = K.int_shape(x) 
代码语言:javascript复制
# generate latent vector Q(z|X)
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x) 
代码语言:javascript复制
# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary 
# with the TensorFlow backend
z = Lambda(sampling,
           output_shape=(latent_dim,),
           name='z')([z_mean, z_log_var]) 
代码语言:javascript复制
# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder') 
代码语言:javascript复制
# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3],
          activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x) 
代码语言:javascript复制
for i in range(2):
    x = Conv2DTranspose(filters=filters,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,
                        padding='same')(x)
    filters //= 2 
代码语言:javascript复制
outputs = Conv2DTranspose(filters=1,
                          kernel_size=kernel_size,
                          activation='sigmoid',
                          padding='same',
                          name='decoder_output')(x) 
代码语言:javascript复制
# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder') 
代码语言:javascript复制
# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae') 

“图 8.1.8”显示了 CNN 编码器模型的两个输出,即潜向量的均值和方差。 lambda 函数实现了重新参数化技巧,将随机潜码的采样推送到 VAE 网络之外:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rTFwOwuA-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_08.png)]

图 8.1.8:VAE CNN 的编码器

“图 8.1.9”显示了 CNN 解码器模型。 2 维输入来自 lambda 函数。 输出是重构的 MNIST 数字:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Wd7Nf6eD-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_09.png)]

图 8.1.9:VAE CNN 的解码器

“图 8.1.10”显示完整的 CNN VAE 模型。 通过将编码器和解码器模型结合在一起制成:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WMfD8AAr-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_10.png)]

图 8.1.10:使用 CNN 的 VAE 模型

对 VAE 进行了 30 个周期的训练。“图 8.1.11”显示了在导航 VAE 的连续潜在空间时数字的分布。 例如,从中间到右边从 2 变为 0:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BIkuhCUR-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_11.png)]

图 8.1.11:MNIST 数字标签作为测试数据集(VAE CNN)的潜在向量平均值的函数。 原始图像可以在该书的 GitHub 存储库中找到。

“图 8.1.12”向我们展示了生成模型的输出。 从质量上讲,与“图 8.1.7”(具有 MLP 实现)相比,模棱两可的位数更少:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5l35MO5C-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_12.png)]

图 8.1.12:根据潜在向量平均值(VAE CNN)生成的数字。 为了便于解释,均值的范围类似于图 8.1.11

前的两节讨论了使用 MLP 或 CNN 的 VAE 的实现。 我们分析了两种实现方式的结果,结果表明 CNN 可以减少参数数量并提高感知质量。 在下一节中,我们将演示如何在 VAE 中实现条件,以便我们可以控制要生成的数字。

2. 条件 VAE(CVAE)

有条件的 VAE [2]与 CGAN 相似。 在 MNIST 数据集的上下文中,如果随机采样潜在空间,则 VAE 无法控制将生成哪个数字。 CVAE 可以通过包含要产生的数字的条件(单标签)来解决此问题。 该条件同时施加在编码器和解码器输入上。

正式地,将“公式 8.1.10”中 VAE 的核心公式修改为包括条件c

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KQcqohNE-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_113.png)] (Equation 8.2.1)

与 VAE 相似,“公式 8.2.1”表示如果要最大化输出条件cP[θ](x | c),则必须最小化两个损失项:

  • 给定潜在向量和条件,解码器的重建损失。
  • 给定潜在向量和条件的编码器之间的 KL 损失以及给定条件的先验分布。 与 VAE 相似,我们通常选择P[θ](x | c) = P(x | c) = N(0, 1)

实现 CVAE 需要对 VAE 的代码进行一些修改。 对于 CVAE,使用 VAE CNN 实现是因为它可以形成一个较小的网络,并产生感知上更好的数字。

“列表 8.2.1”突出显示了针对 MNIST 数字的 VAE 原始代码所做的更改。 编码器输入现在是原始输入图像及其单标签的连接。 解码器输入现在是潜在空间采样与其应生成的图像的一键热标签的组合。 参数总数为 174,437。 与 β-VAE 相关的代码将在本章下一节中讨论。

损失函数没有改变。 但是,在训练,测试和结果绘制过程中会提供单热标签。

“列表 8.2.1”:cvae-cnn-mnist-8.2.1.py

tf.keras中使用 CNN 层的 CVAE。 重点介绍了为支持 CVAE 而进行的更改:

代码语言:javascript复制
# compute the number of labels
num_labels = len(np.unique(y_train)) 
代码语言:javascript复制
# network parameters
input_shape = (image_size, image_size, 1)
label_shape = (num_labels, )
batch_size = 128
kernel_size = 3
filters = 16
latent_dim = 2
epochs = 30 
代码语言:javascript复制
# VAE model = encoder   decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
y_labels = Input(shape=label_shape, name='class_labels')
x = Dense(image_size * image_size)(y_labels)
x = Reshape((image_size, image_size, 1))(x)
x = keras.layers.concatenate([inputs, x])
for i in range(2):
    filters *= 2
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,
               padding='same')(x) 
代码语言:javascript复制
# shape info needed to build decoder model
shape = K.int_shape(x) 
代码语言:javascript复制
# generate latent vector Q(z|X)
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x) 
代码语言:javascript复制
# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary 
# with the TensorFlow backend
z = Lambda(sampling,
           output_shape=(latent_dim,),
           name='z')([z_mean, z_log_var]) 
代码语言:javascript复制
# instantiate encoder model
encoder = Model([inputs, y_labels],
                [z_mean, z_log_var, z],
                name='encoder') 
代码语言:javascript复制
# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = concatenate([latent_inputs, y_labels])
x = Dense(shape[1]*shape[2]*shape[3], activation='relu')(x)
x = Reshape((shape[1], shape[2], shape[3]))(x) 
代码语言:javascript复制
for i in range(2):
    x = Conv2DTranspose(filters=filters,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,
                        padding='same')(x)
    filters //= 2 
代码语言:javascript复制
outputs = Conv2DTranspose(filters=1,
                          kernel_size=kernel_size,
                          activation='sigmoid',
                          padding='same',
                          name='decoder_output')(x) 
代码语言:javascript复制
# instantiate decoder model
decoder = Model([latent_inputs, y_labels],
                outputs,
                name='decoder')
# instantiate vae model
outputs = decoder([encoder([inputs, y_labels])[2], y_labels])
cvae = Model([inputs, y_labels], outputs, name='cvae') 

“图 8.2.1”显示了 CVAE 模型的编码器。 附加输入,即单热向量class_labels形式的条件标签表示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bLvR6Phc-1681704311667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_13.png)]

图 8.2.1:CVAE CNN 中的编码器。 输入现在包括 VAE 输入和条件标签的连接

“图 8.2.2”显示了 CVAE 模型的解码器。 附加输入,即单热向量class_labels形式的条件标签表示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1AJ1jTMJ-1681704311667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_14.png)]

图 8.2.2:CVAE CNN 中的解码器。 输入现在包括 z 采样和条件标签的连接

“图 8.2.3”显示了完整的 CVAE 模型,该模型是编码器和解码器结合在一起的。 附加输入,即单热向量class_labels形式的条件标签:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yFcG8oOh-1681704311667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_15.png)]

图 8.2.3:使用 CNN 的 CVAE 模型。输入现在包含一个 VAE 输入和一个条件标签

在“图 8.2.4”中,每个标记的平均值分布在 30 个周期后显示。 与前面章节中的“图 8.1.6”和“图 8.1.11”不同,每个标签不是集中在一个区域上,而是分布在整个图上。 这是预期的,因为潜在空间中的每个采样都应生成一个特定的数字。 浏览潜在空间会更改该特定数字的属性。 例如,如果指定的数字为 0,则在潜伏空间中导航仍将产生 0,但是诸如倾斜角度,厚度和其他书写样式方面的属性将有所不同。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tB1jiTpc-1681704311667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_16.png)]

图 8.2.4:作为测试数据集(CVAE CNN)的潜在向量平均值的函数的 MNIST 数字标签。 原始图像可以在该书的 GitHub 存储库中找到。

“图 8.2.4”在“图 8.2.5”中更清楚地显示,数字 0 到 5。每个帧都有相同的数字,并且属性在我们浏览时顺畅地变化。 潜在代码:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lKyT6wAx-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_17.png)]

图 8.2.5:根据潜在向量平均值和单热点标签(CVAE CNN)生成的数字 0 至 5。 为了便于解释,均值的范围类似于图 8.2.4。

“图 8.2.6”显示“图 8.2.4”,用于数字 6 至 9:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fr1yGOiU-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_18.png)]

图 8.2.6:根据潜在向量平均值和单热点标签(CVAE CNN)生成的数字 6 至 9。 为了便于解释,均值的范围类似于图 8.2.4。

为了便于比较,潜向量的值范围与“图 8.2.4”中的相同。 使用预训练的权重,可以通过执行以下命令来生成数字(例如 0):

代码语言:javascript复制
python3 cvae-cnn-mnist-8.2.1.py –bce --weights=cvae_cnn_mnist.tf --digit=0 

在“图 8.2.5”和“图 8.2.6”中,可以注意到,每个数字的宽度和圆度(如果适用)随z[0]的变化而变化。 从左到右追踪。 同时,当z[1]从上到下导航时,每个数字的倾斜角度和圆度(如果适用)也会发生变化。 随着我们离开分布中心,数字的图像开始退化。 这是可以预期的,因为潜在空间是一个圆形。

属性中其他明显的变化可能是数字特定的。 例如,数字 1 的水平笔划(手臂)在左上象限中可见。 数字 7 的水平笔划(纵横线)只能在右象限中看到。

在下一节中,我们将发现 CVAE 实际上只是另一种称为 β-VAE 的 VAE 的特例。

3. β-VAE – 具有纠缠的潜在表示形式的 VAE

在“第 6 章”,“非纠缠表示 GAN”中,讨论了潜码非纠缠表示的概念和重要性。 我们可以回想起,一个纠缠的表示是单个潜伏单元对单个生成因子的变化敏感,而相对于其他因子的变化相对不变[3]。 更改潜在代码会导致生成的输出的一个属性发生更改,而其余属性保持不变。

在同一章中,InfoGAN [4]向我们展示了对于 MNIST 数据集,可以控制生成哪个数字以及书写样式的倾斜度和粗细。 观察上一节中的结果,可以注意到,VAE 在本质上使潜向量维解开了一定程度。 例如,查看“图 8.2.6”中的数字 8,从上到下导航z[1]会减小宽度和圆度,同时顺时针旋转数字。 从左至右增加z[0]也会在逆时针旋转数字时减小宽度和圆度。 换句话说,z[1]控制顺时针旋转,而z[0]影响逆时针旋转,并且两者都改变​​宽度和圆度。

在本节中,我们将演示对 VAE 损失函数的简单修改会迫使潜在代码进一步解开纠缠。 修改为正恒重β > 1,用作 KL 损失的调节器:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZFv9spjc-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_121.png)] (Equation 8.3.1)

VAE 的这种变化称为 β-VAE [5]。 β的隐含效果是更严格的标准差。 换句话说,β强制后验分布中的潜码Q[φ](z | x)独立。

实现 β-VAE 很简单。 例如,对于上一个示例中的 CVAE,所需的修改是kl_loss中的额外beta因子:

代码语言:javascript复制
kl_loss = 1   z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5 * beta 

CVAE 是 β-VAE 的特例,其中β = 1。 其他一切都一样。 但是,确定的值需要一些反复试验。 为了潜在的代码独立性,在重构误差和正则化之间必须有一个仔细的平衡。 解缠最大在β = 9附近。 当中β = 9的值时,β-VAE 仅被迫学习一个解纠缠的表示,而忽略另一个潜在维度。

“图 8.3.1”和“图 8.3.2”显示 β-VAE 的潜向量平均值,其中β = 9β = 10

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lrcWwni7-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_19.png)]

图 8.3.1:MNIST 数字标签与测试数据集的潜在向量平均值的函数(β-VAE,β = 9)。 原始图像可以在该书的 GitHub 存储库中找到。

β = 9时,与 CVAE 相比,分布具有较小的标准差。 在β = 10的情况下,仅学习了潜在代码。 分布实际上缩小为一个维度,编码器和解码器忽略了第一潜码z[0]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M0IexeCU-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_20.png)]

图 8.3.2:MNIST 数字标签与测试数据集的潜向量平均值的函数(β-VAE 和β = 10

原始图像可以在该书的 GitHub 存储库中找到。

这些观察结果反映在“图 8.3.3”中。 具有β = 9的 β-VAE 具有两个实际上独立的潜在代码。 z[0]确定书写样式的倾斜度,而z[1]指定数字的宽度和圆度(如果适用)。 对于中β = 10的 β-VAE,z[0]被静音。 z[0]的增加不会显着改变数字。z[1]确定书写样式的倾斜角度和宽度:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-esbGFOdP-1681704311669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_21.png)]

图 8.3.3:根据潜在向量平均值和单热点标签(β-VAE,β = 1, 9, 10)生成的数字 0 至 3。 为了便于解释,均值的范围类似于图 8.3.1。

β-VAE 的tf.keras代码具有预训练的权重。 要使用β = 9生成数字 0 来测试 β-VAE,我们需要运行以下命令:

代码语言:javascript复制
python3 cvae-cnn-mnist-8.2.1.py --beta=9 --bce --weights=beta-cvae_cnn_mnist.tf --digit=0 

总而言之,我们已经证明与 GAN 相比,在 β-VAE 上更容易实现解缠表示学习。 我们所需要做的就是调整单个超参数。

4. 总结

在本章中,我们介绍了 VAE 的原理。 正如我们从 VAE 原理中学到的那样,从两次尝试从潜在空间创建合成输出的角度来看,它们都与 GAN 相似。 但是,可以注意到,与 GAN 相比,VAE 网络更简单,更容易训练。 越来越清楚的是 CVAE 和 β-VAE 在概念上分别类似于条件 GAN 和解缠表示 GAN。

VAE 具有消除潜在向量纠缠的内在机制。 因此,构建 β-VAE 很简单。 但是,我们应该注意,可解释和解开的代码对于构建智能体很重要。

在下一章中,我们将专注于强化学习。 在没有任何先验数据的情况下,智能体通过与周围的世界进行交互来学习。 我们将讨论如何为智能体的正确行为提供奖励,并为错误的行为提供惩罚。

5. 参考

  1. Diederik P. Kingma and Max Welling. Auto-encoding Variational Bayes. arXiv preprint arXiv:1312.6114, 2013 (https://arxiv.org/pdf/1312.6114.pdf).
  2. Kihyuk Sohn, Honglak Lee, and Xinchen Yan. Learning Structured Output Representation Using Deep Conditional Generative Models. Advances in Neural Information Processing Systems, 2015 (http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models.pdf).
  3. Yoshua Bengio, Aaron Courville, and Pascal Vincent. Representation Learning.
  4. A Review and New Perspectives. IEEE transactions on Pattern Analysis and Machine Intelligence 35.8, 2013: 1798-1828 (https://arxiv.org/pdf/1206.5538.pdf).
  5. Xi Chen et al.: Infogan: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets. Advances in Neural Information Processing Systems, 2016 (http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf).
  6. I. Higgins, L. Matthey, A. Pal, C. Burgess, X. Glorot, M. Botvinick, S. Mohamed, and A. Lerchner. -VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. ICLR, 2017 (https://openreview.net/pdf?id=Sy2fzU9gl).
  7. Carl Doersch. Tutorial on variational autoencoders. arXiv preprint arXiv:1606.05908, 2016 (https://arxiv.org/pdf/1606.05908.pdf).

九、深度强化学习

强化学习RL)是智能体程序用于决策的框架。 智能体不一定是软件实体,例如您在视频游戏中可能看到的那样。 相反,它可以体现在诸如机器人或自动驾驶汽车之类的硬件中。 内在的智能体可能是充分理解和利用 RL 的最佳方法,因为物理实体与现实世界进行交互并接收响应。

该智能体位于环境中。 环境具有状态,可以部分或完全观察到。 该智能体具有一组操作,可用于与环境交互。 动作的结果将环境转换为新状态。 执行动作后,会收到相应的标量奖励

智能体的目标是通过学习策略来最大化累积的未来奖励,该策略将决定在特定状态下应采取的行动。

RL 与人类心理学有很强的相似性。 人类通过体验世界来学习。 错误的行为会导致某种形式的惩罚,将来应避免使用,而正确的行为应得到奖励并应予以鼓励。 这种与人类心理学的强相似之处使许多研究人员相信 RL 可以将引向真正的人工智能AI)。

RL 已经存在了几十年。 但是,除了简单的世界模型之外,RL 还在努力扩展规模。 这是,其中深度学习DL)开始发挥作用。 它解决了这个可扩展性问题,从而开启了深度强化学习DRL)的时代。 在本章中,我们的重点是 DRL。 DRL 中值得注意的例子之一是 DeepMind 在智能体上的工作,这些智能体能够在不同的视频游戏上超越最佳的人类表现。

在本章中,我们将讨论 RL 和 DRL。

总之,本章的目的是介绍:

  • RL 的原理
  • RL 技术,Q 学习
  • 高级主题,包括深度 Q 网络DQN)和双重 Q 学习DDQN
  • 关于如何使用tf.keras在 Python 和 DRL 上实现 RL 的说明

让我们从 RL 的基本原理开始。

1. 强化学习原理(RL)

“图 9.1.1”显示了用于描述 RL 的感知动作学习循环。 环境是苏打水可以坐在地板上。 智能体是一个移动机器人,其目标是拾取苏打水。 它观察周围的环境,并通过车载摄像头跟踪汽水罐的位置。 观察结果以一种状态的形式进行了汇总,机器人将使用该状态来决定要采取的动作。 所采取的动作可能与低级控制有关,例如每个车轮的旋转角度/速度,手臂的每个关节的旋转角度/速度以及抓手是打开还是关闭。

可替代地,动作可以是高级控制动作,诸如向前/向后移动机器人,以特定角度转向以及抓取/释放。 将夹持器从汽水中移开的任何动作都会得到负回报。 缩小抓取器位置和苏打之间的缝隙的任何动作都会获得积极的回报。 当机械臂成功捡起汽水罐时,它会收到丰厚的回报。 RL 的目标是学习最佳策略,该策略可帮助机器人决定在给定状态下采取哪种行动以最大化累积的折扣奖励:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bgGcFrjL-1681704311669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_01.png)]

图 9.1.1:RL 中的感知-动作-学习循环

形式上,RL 问题可以描述为 Markov 决策过程MDP)。

为简单起见,我们将假定为确定性环境,在该环境中,给定状态下的某个动作将始终导致已知的下一个状态和奖励。 在本章的后面部分,我们将研究如何考虑随机性。 在时间步t时:

  • 环境处于状态空间S的状态下,状态s[0],该状态可以是离散的也可以是连续的。 起始状态为s[0],而终止状态为s[T]
  • 智能体通过遵循策略π(a[t] | s[t])从操作空间A采取操作,即s[a]A可以是离散的或连续的。
  • 环境使用状态转换动态T(s[t 1] | s[t], a[t])转换为新状态,s[t 1]。 下一个状态仅取决于当前状态和操作。 智能体不知道T
  • 智能体使用奖励函数接收标量奖励,r[t 1] = R(s[t], a[t]),以及r: A x S -> R。 奖励仅取决于当前状态和操作。 智能体不知道R
  • 将来的奖励折扣为γ^k,其中γ ∈ [0, 1]k是未来的时间步长。
  • 地平线H是完成从s[0]s[T]的一集所需的时间步长T

该环境可以是完全或部分可观察的。 后者也称为部分可观察的 MDPPOMDP。 在大多数情况下,完全观察环境是不现实的。 为了提高的可观察性,当前的观测值也考虑了过去的观测值。 状态包括对环境的足够观察,以使策略决定采取哪种措施。 回忆“图 9.1.1”,这可能是汽水罐相对于机器人抓手的三维位置,如机器人摄像头所估计的那样。

每当环境转换到新状态时,智能体都会收到标量奖励r[t 1]。 在“图 9.1.1”中,每当机器人靠近汽水罐时,奖励可能为 1;当机器人离汽水罐更远时,奖励为 -1;当机器人关闭夹具并成功捡起苏打时,奖励为 100。 能够。 智能体的目标是学习一种最佳策略π*,该策略可使所有状态的收益最大化:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-f85hflJU-1681704311669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_012.png)] (Equation 9.1.1)

回报定义为折扣累积奖励R[t] = Σ γ^t r[t k], k = 0, ..., T。 从“公式 9.1.1”可以看出,与通常的γ^k < 1.0相比,与立即获得的奖励相比,未来的奖励权重较低。 在极端情况下,当γ = 0时,仅立即获得奖励很重要。 当γ = 1时,将来的奖励与立即奖励的权重相同。

遵循任意策略π,可以将回报解释为对给定状态值的度量:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-byZLyWAj-1681704311676)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_019.png)] (Equation 9.1.2)

换句话说,RL 问题是智能体的目标,是学习使所有状态s最大化的最优策略V^π

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CTBNheSJ-1681704311676)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_021.png)] (Equation 9.1.3)

最优策略的值函数就是V*。 在“图 9.1.1”中,最佳策略是生成最短动作序列的一种,该动作序列使机器人越来越靠近苏打罐,直到被取走为止。 状态越接近目标状态,其值越高。 可以将导致目标(或最终状态)的事件序列建模为策略的轨迹部署

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-klKvBKtv-1681704311676)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_023.png)] (Equation 9.1.4)

如果 MDP 是偶发的,则当智能体到达终端状态s[T]时,状态将重置为s[0]。 如果T是有限的,则我们的水平范围是有限的。 否则,视野是无限的。 在“图 9.1.1”中,如果 MDP 是情景剧集,则在收集苏打罐后,机器人可能会寻找另一个苏打罐来拾取,并且 RL 问题重发。

因此,RL 的主要目标是找到一种使每个状态的值最大化的策略。 在下一部分中,我们将介绍可用于最大化值函数的策略学习算法。

2. Q 值

如果 RL 问题是找到π*,则智能体如何通过与环境交互来学习?“公式 9.1.3”并未明确指出尝试进行的操作以及计算收益的后续状态。 在 RL 中,使用 Q 值更容易学习π*

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tAIkVX5O-1681704311677)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_026.png)] (Equation 9.2.1)

哪里:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kltRYhYh-1681704311677)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_027.png)] (Equation 9.2.2)

换句话说,不是找到使所有状态的值最大化的策略,而是“公式 9.2.1”寻找使所有状态的质量(Q)值最大化的操作。 在找到 Q 值函数之后,分别由“公式 9.2.2”和“公式 9.1.3”确定V*,因此确定了π*

如果对于每个动作,都可以观察到奖励和下一状态,则可以制定以下迭代或反复试验算法来学习 Q 值:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eH928w5D-1681704311677)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_030.png)] (Equation 9.2.3)

为了简化符号,s'a'分别是下一个状态和动作。 “公式 9.2.3”被称为贝尔曼方程,它是 Q 学习算法的核心。 Q 学习尝试根据当前状态和作用来近似返回值或值的一阶展开(“公式 9.1.2”)。 从对环境动态的零知识中,智能体尝试执行操作a,观察以奖励r和下一个状态s'的形式发生的情况。 max[a'] Q(s', a')选择下一个逻辑动作,该动作将为下一个状态提供最大 Q 值。 有了“公式 9.2.3”中的所有项,该当前状态-动作对的 Q 值就会更新。 迭代地执行更新将最终使智能体能够学习 Q 值函数。

Q 学习是一种脱离策略 RL 算法。 它学习了如何通过不直接从策略中抽取经验来改进策略。 换句话说,Q 值的获取与智能体所使用的基础策略无关。 当 Q 值函数收敛时,才使用“公式 9.2.1”确定最佳策略。

在为提供有关如何使用 Q 学习的示例之前,请注意,智能体必须在不断利用其到目前为止所学知识的同时不断探索其环境。 这是 RL 中的问题之一-在探索开发之间找到适当的平衡。 通常,在学习开始时,动作是随机的(探索)。 随着学习的进行,智能体会利用 Q 值(利用)。 例如,一开始,90% 的动作是随机的,而 10% 的动作则来自 Q 值函数。 在每个剧集的结尾,这逐渐减少。 最终,该动作是 10% 随机的,并且是 Q 值函数的 90%。

在下一节中,我们将给出有关在简单的确定性环境中如何使用 Q 学习的具体示例。

3. Q 学习实例

为了说明 Q 学习算法,我们需要考虑一个简单的确定性环境,如图“图 9.3.1”所示。 环境具有六个状态。

显示允许的过渡的奖励。 在两种情况下,奖励是非零的。 转换为目标G)状态可获得 100 的奖励,同时移至H)状态具有 -100 奖励。 这两个状态是终端状态,从开始状态构成一个剧集的结尾:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4I0G4ROO-1681704311677)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_02.png)]

图 9.3.1:简单确定性世界中的奖励

为了使每个状态的身份正式化,我们使用(行, 列)标识符,如图“图 9.3.2”所示。 由于智能体尚未了解有关其环境的任何信息,因此“图 9.3.2”中所示的 Q 表的初始值为零。 在此示例中,折扣因子γ = 0.9。 回想一下,在当前 Q 值的估计中,折扣因子确定了未来 Q 值的权重,该权重是步数γ^k的函数。 在“公式 9.2.3”中,我们仅考虑近期 Q 值k = 1

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zA9USIJI-1681704311677)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_03.png)]

图 9.3.2:简单确定性环境中的状态和智能体的初始 Q 表

最初,智能体采用的策略是 90% 的时间选择随机操作,并 10% 的时间使用 Q 表。 假设第一个动作是随机选择的,并且指示向右移动。“图 9.3.3”说明了向右移动时状态(0, 0)的新 Q 值的计算。 下一个状态是(0, 1)。 奖励为 0,所有下一个状态的 Q 值的最大值为零。 因此,向右移动的状态(0, 0)的 Q 值保持为 0。

为了轻松跟踪初始状态和下一个状态,我们在环境和 Q 表上使用不同的灰色阴影-初始状态浅灰色,下一个状态灰色。

在为下一个状态选择下一个动作时,候选动作位于较粗的边框中:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U7jqNt18-1681704311678)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_04.png)]

图 9.3.3:假设智能体采取的行动是向右移动,则显示状态(0, 0)的 Q 值的更新

假设下一个随机选择的动作是向下移动。“图 9.3.4”显示状态(0, 1)的 Q 值沿向下方向的移动没有​​变化:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VAzEl4n3-1681704311678)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_05.png)]

图 9.3.4:假设智能体选择的动作是向下移动,则显示状态(0, 1)的 Q 值的更新

在“图 9.3.5”中,智能体的第三个随机动作是向右移动。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YhiSUiyU-1681704311678)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_06.png)]

图 9.3.5:假设智能体选择的动作是向右移动,则显示状态(1, 1)的 Q 值的更新

它遇到了,H状态,并获得了 -100 奖励。 这次,更新不为零。 向右移动时,状态(1, 1)的新 Q 值为 -100。 注意,由于这是终端状态,因此没有下一个状态。 一集刚刚结束,智能体返回到开始状态。

假设智能体仍处于探索模式,如图“图 9.3.6”所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zswzSOpr-1681704311678)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_07.png)]

图 9.3.6:假设智能体选择的动作是向右连续两次移动,则显示状态(0, 1)的 Q 值的更新

为第二集采取的第一步是向右移动。 正如预期的那样,更新为 0。但是,它选择的第二个随机动作也是向右移动。 智能体到达G状态并获得 100 的巨额奖励。 向右移动的状态(0, 1)的 Q 值变为 100。完成第二集,并且智能体返回到启动状态。

在第三集开始时,智能体采取的随机行动是向右移动。 现在,状态(0, 0)的 Q 值将更新为非零值,因为下一个状态的可能动作将最大 Q 值设为 100。“图 9.3.7”显示了所涉及的计算。 下一个状态(0, 1)的 Q 值波动回到较早的状态(0, 0)。 这就像对帮助找到G状态的早期状态表示赞赏。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fhLzt2z4-1681704311678)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_08.png)]

图 9.3.7:假设智能体选择的动作是向右移动,则显示状态(0, 0)的 Q 值的更新

Q 表的进步很大。 实际上,在下一集中,如果由于某种原因该策略决定使用 Q 表而不是随机探索环境,则第一个动作是根据“图 9.3.8”中的计算向右移动。 在 Q 表的第一行中,导致最大 Q 值的动作是向右移动。 对于下一个状态(0, 1),Q 表的第二行表明下一个动作仍然是向右移动。 智能体已成功实现其目标。 该策略指导智能体采取了正确的措施来实现其目标:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3ohfEpte-1681704311679)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_09.png)]

图 9.3.8:在这种情况下,智能体的策略决定利用 Q 表来确定状态(0, 0)(0, 1)的动作。 Q 表建议两个状态都向右移动

如果 Q 学习算法继续无限期运行,则 Q 表将收敛。 收敛的假设是 RL 问题必须是具有有限奖励的确定性 MDP,并且所有状态都将被无限次地访问。

在下一节中,我们将使用 Python 模拟环境。 我们还将展示 Q 学习算法的代码实现。

用 Python 进行 Q 学习

上一节中讨论的环境和 Q 学习可以在 Python 中实现。 由于该策略只是一个简单的表,因此在此时,无需使用tf.keras库。“列表 9.3.1”显示了q-learning-9.3.1.py,它是使用QWorld类实现的简单确定性世界(环境,智能体,操作和 Q 表算法)的实现。 为简洁起见,未显示处理用户界面的函数。

在此示例中,环境动态由self.transition_table表示。 在每个动作中,self.transition_table确定下一个状态。 执行动作的奖励存储在self.reward_table中。 每次通过step()函数执行动作时,都要查阅这两个表。 Q 学习算法由update_q_table()函数实现。 每当智能体需要决定要采取的操作时,它都会调用act()函数。 策略可以使用 Q 表随机抽取或决定。 所选动作是随机的机会百分比存储在self.epsilon变量中,该变量由update_epsilon()函数使用固定的epsilon_decay更新。

在执行“列表 9.3.1”中的代码之前,我们需要运行:

代码语言:javascript复制
sudo pip3 install termcolor 

安装termcolor包。 该包有助于可视化终端上的文本输出。

完整的代码可以在 GitHub 上找到。

“列表 9.3.1”:q-learning-9.3.1.py

具有六个状态的简单确定性 MDP:

代码语言:javascript复制
from collections import deque
import numpy as np
import argparse
import os
import time
from termcolor import colored 
代码语言:javascript复制
class QWorld:
    def __init__(self):
        """Simulated deterministic world made of 6 states.
        Q-Learning by Bellman Equation. 
        """
        # 4 actions
        # 0 - Left, 1 - Down, 2 - Right, 3 - Up
        self.col = 4 
代码语言:javascript复制
 # 6 states
        self.row = 6 
代码语言:javascript复制
 # setup the environment
        self.q_table = np.zeros([self.row, self.col])
        self.init_transition_table()
        self.init_reward_table() 
代码语言:javascript复制
 # discount factor
        self.gamma = 0.9 
代码语言:javascript复制
 # 90% exploration, 10% exploitation
        self.epsilon = 0.9
        # exploration decays by this factor every episode
        self.epsilon_decay = 0.9
        # in the long run, 10% exploration, 90% exploitation
        self.epsilon_min = 0.1 
代码语言:javascript复制
 # reset the environment
        self.reset()
        self.is_explore = True 
代码语言:javascript复制
 def reset(self):
        """start of episode"""
        self.state = 0
        return self.state 
代码语言:javascript复制
 def is_in_win_state(self):
        """agent wins when the goal is reached"""
        return self.state == 2 
代码语言:javascript复制
 def init_reward_table(self):
        """
        0 - Left, 1 - Down, 2 - Right, 3 - Up
        ----------------
        | 0 | 0 | 100  |
        ----------------
        | 0 | 0 | -100 |
        ----------------
        """
        self.reward_table = np.zeros([self.row, self.col])
        self.reward_table[1, 2] = 100.
        self.reward_table[4, 2] = -100. 
代码语言:javascript复制
 def init_transition_table(self):
        """
        0 - Left, 1 - Down, 2 - Right, 3 - Up
        -------------
        | 0 | 1 | 2 |
        -------------
        | 3 | 4 | 5 |
        -------------
        """
        self.transition_table = np.zeros([self.row, self.col],
                                         dtype=int)
        self.transition_table[0, 0] = 0
        self.transition_table[0, 1] = 3
        self.transition_table[0, 2] = 1
        self.transition_table[0, 3] = 0 
代码语言:javascript复制
 self.transition_table[1, 0] = 0
        self.transition_table[1, 1] = 4
        self.transition_table[1, 2] = 2
        self.transition_table[1, 3] = 1 
代码语言:javascript复制
 # terminal Goal state
        self.transition_table[2, 0] = 2
        self.transition_table[2, 1] = 2
        self.transition_table[2, 2] = 2
        self.transition_table[2, 3] = 2 
代码语言:javascript复制
 self.transition_table[3, 0] = 3
        self.transition_table[3, 1] = 3
        self.transition_table[3, 2] = 4
        self.transition_table[3, 3] = 0 
代码语言:javascript复制
 self.transition_table[4, 0] = 3
        self.transition_table[4, 1] = 4
        self.transition_table[4, 2] = 5
        self.transition_table[4, 3] = 1 
代码语言:javascript复制
 # terminal Hole state
        self.transition_table[5, 0] = 5
        self.transition_table[5, 1] = 5
        self.transition_table[5, 2] = 5
        self.transition_table[5, 3] = 5 
代码语言:javascript复制
 def step(self, action):
        """execute the action on the environment
        Argument:
            action (tensor): An action in Action space
        Returns:
            next_state (tensor): next env state
            reward (float): reward received by the agent
            done (Bool): whether the terminal state 
                is reached
        """
        # determine the next_state given state and action
        next_state = self.transition_table[self.state, action]
        # done is True if next_state is Goal or Hole
        done = next_state == 2 or next_state == 5
        # reward given the state and action
        reward = self.reward_table[self.state, action]
        # the enviroment is now in new state
        self.state = next_state
        return next_state, reward, done 
代码语言:javascript复制
 def act(self):
        """determine the next action
            either fr Q Table(exploitation) or
            random(exploration)
        Return:
            action (tensor): action that the agent
                must execute
        """
        # 0 - Left, 1 - Down, 2 - Right, 3 - Up
        # action is from exploration
        if np.random.rand() <= self.epsilon:
            # explore - do random action
            self.is_explore = True
            return np.random.choice(4,1)[0] 
代码语言:javascript复制
 # or action is from exploitation
        # exploit - choose action with max Q-value
        self.is_explore = False
        action = np.argmax(self.q_table[self.state])
        return action 
代码语言:javascript复制
 def update_q_table(self, state, action, reward, next_state):
        """Q-Learning - update the Q Table using Q(s, a)
        Arguments:
            state (tensor) : agent state
            action (tensor): action executed by the agent
            reward (float): reward after executing action 
                for a given state
            next_state (tensor): next state after executing
                action for a given state
        """
        # Q(s, a) = reward   gamma * max_a' Q(s', a')
        q_value = self.gamma * np.amax(self.q_table[next_state])
        q_value  = reward
        self.q_table[state, action] = q_value 
代码语言:javascript复制
 def update_epsilon(self):
        """update Exploration-Exploitation mix"""
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay 

感知动作学习循环在“列表 9.3.2”中进行了说明。 在每个剧集中,环境都会重置为开始状态。 选择要执行的动作并将其应用于环境。 观察奖励下一个状态,并将其用于更新 Q 表。 达到目标状态后,剧集完成(done = True)。

对于此示例,Q 学习运行 100 集或 10 获胜,以先到者为准。 由于在每个剧集中变量的值均降低,因此智能体开始倾向于利用 Q 表来确定在给定状态下要执行的动作。 要查看 Q 学习模拟,我们只需要运行以下命令:

代码语言:javascript复制
python3 q-learning-9.3.1.py 

“列表 9.3.2”:q-learning-9.3.1.py

主要的 Q 学习循环:

代码语言:javascript复制
 # state, action, reward, next state iteration
    for episode in range(episode_count):
        state = q_world.reset()
        done = False
        print_episode(episode, delay=delay)
        while not done:
            action = q_world.act()
            next_state, reward, done = q_world.step(action)
            q_world.update_q_table(state, action, reward, next_state)
            print_status(q_world, done, step, delay=delay)
            state = next_state
            # if episode is done, perform housekeeping
            if done:
                if q_world.is_in_win_state():
                    wins  = 1
                    scores.append(step)
                    if wins > maxwins:
                        print(scores)
                        exit(0)
                # Exploration-Exploitation is updated every episode
                q_world.update_epsilon()
                step = 1
            else:
                step  = 1 

“图 9.3.9”显示了maxwins = 2000(达到2000 x目标状态)和delay = 0时的屏幕截图。 要仅查看最终的 Q 表,请执行:

代码语言:javascript复制
python3 q-learning-9.3.1.py --train 

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cZSoTcOI-1681704311679)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_10.png)]

图 9.3.9:屏幕快照显示智能体在 2,000 次获胜后的 Q 表

Q 表已收敛,并显示了智能体可以在给定状态下采取的逻​​辑操作。 例如,在第一行或状态(0, 0)中,该策略建议向右移动。 第二行的状态(0, 1)也是如此。 第二个动作达到目标状态。 scores变量转储显示,随着智能体从策略获取正确的操作,所采取的最少步骤数减少了。

从“图 9.3.9”,我们可以从“公式 9.2.2”和V*(s) = max[a] Q(s, a)计算每个状态的值。 例如,对于状态(0, 0)V*(s) = max[a](0.0, 72.9, 90.0, 81.0) = 9.0

“图 9.3.10”显示每种状态的值。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6c4nqWBD-1681704311679)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_11.png)]

图 9.3.10:图 9.3.9 和公式 9.2.2 中每个状态的值

这个简单的示例说明了在简单确定性世界中智能体的 Q 学习的所有元素。 在下一节中,我们将介绍考虑随机性所需的轻微修改。

4. 非确定性环境

如果环境不确定,则奖励和行动都是概率性的。 新系统是随机的 MDP。 为了反映不确定性报酬,新的值函数为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xEwh4xsm-1681704311679)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_042.png)] (Equation 9.4.1)

贝尔曼方程修改为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bA9ooaj2-1681704311679)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_043.png)] (Equation 9.4.2)

但是,在本章中,我们将重点介绍确定性环境。 在下一节中,我们将提出一种更通用的 Q 学习算法,称为时差TD)学习。

5. 时差学习

Q 学习是更广义的 TD 学习TD(λ)的特例。 更具体地说,这是单步 TD 学习的特殊情况,TD(0)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jJKccuLk-1681704311680)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_045.png)] (Equation 9.5.1)

其中α是学习率。 注意,当α = 1,“公式 9.5.1”与贝尔曼等式相似。 为简单起见,我们还将“公式 9.5.1”称为 Q 学习或广义 Q 学习。

以前,我们将 Q 学习称为一种非策略性 RL 算法,因为它学习 Q 值函数而没有直接使用它尝试优化的策略。 上策略一步式 TD 学习算法的示例是 SARSA,类似于“公式 9.5.1”:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DEe9pXs0-1681704311680)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_048.png)] (Equation 9.5.2)

主要区别是使用已优化的策略来确定a'。 必须知道项sars'a'(因此名称为 SARSA)才能在每次迭代时更新 Q 值函数。 Q 学习和 SARSA 都在 Q 值迭代中使用现有的估计,该过程称为自举。 在引导过程中,我们从奖励中更新当前的 Q 值估计,并随后更新 Q 值估计。

在提出另一个示例之前,似乎需要合适的 RL 模拟环境。 否则,我们只能对非常简单的问题(如上一个示例)运行 RL 模拟。 幸运的是,OpenAI 创建了 Gym,我们将在下一节中介绍。

在 OpenAI Gym 上进行 Q 学习

OpenAI Gym 是的工具包,用于开发和比较 RL 算法。 它适用于大多数 DL 库,包括tf.keras。 可以通过运行以下命令来安装健身房:

代码语言:javascript复制
sudo pip3 install gym 

该体育馆有多种可以测试 RL 算法的环境,例如玩具文字,经典控件,算法,Atari 和二维/三维机器人。 例如,FrozenLake-v0(“图 9.5.1”)是一个玩具文本环境,类似于在 Python Q 学习示例中使用的简单确定性世界:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1Us8LbYq-1681704311680)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_12.png)]

图 9.5.1:OpenAI Gym 中的 FrozenLake-v0 环境

FrozenLake-v0具有 12 个状态,标记为S的状态为起始状态,F的状态为湖泊的冰冻部分,这是安全的,H为安全状态。 应当避免的空穴状态,G是飞盘所在的目标状态。 转换为目标状态的奖励为 1。 对于所有其他状态,奖励为

FrozenLake-v0中,还有四个可用动作(左,下,右,上),称为动作空间。 但是,与之前的简单确定性世界不同,实际运动方向仅部分取决于所选的动作。 FrozenLake-v0环境有两种变体。 滑和不滑。 不出所料,滑动模式更具挑战性。

应用于FrozenLake-v0的操作将返回观察结果(等效于下一个状态),奖励,完成(无论剧集是否完成)以及调试信息字典。 返回的观察对象捕获环境的可观察属性,称为观察空间。

通用 Q 学习可以应用于FrozenLake-v0环境。“表 9.5.1”显示了湿滑和非湿滑环境的表现改进。 衡量策略表现的一种方法是执行的事件达到目标状态的百分比。 百分比越高,效果越好。 从大约 1.5% 的纯探查(随机操作)的基准来看,该策略可以在非光滑环境中达到约 76% 的目标状态,在光滑环境中可以达到约 71% 的目标状态。 不出所料,很难控制湿滑的环境。

模式

运行

大约百分比的目标

训练非滑动

python3 q-frozenlake-9.5.1.py

26

测试非滑动

python3 q-frozenlake-9.5.1.py -d

76

纯随机动作非滑动

python3 q-frozenlake-9.5.1.py -e

1.5

训练滑动

python3 q-frozenlake-9.5.1.py -s

26

测试滑动

python3 q-frozenlake-9.5.1.py -s -d

71

纯随机动作滑动

python3 q-frozenlake-9.5.1.py -s -e

1.5

表 9.5.1:在 FrozenLake-v0 环境中学习率为 0.5 的广义 Q 学习的基线和表现

由于该代码仅需要一个 Q 表,因此仍可以在 Python 和 NumPy 中实现。“列表 9.5.1”显示了QAgent类的实现。 除了使用 OpenAI Gym 的FrozenLake-v0环境之外,最重要的更改是广义 Q 学习的实现,这由update_q_table()函数中的“公式 9.5.1”定义。

“列表 9.5.1”:q-frozenlake-9.5.1.py

关于 FrozenLake-v0 环境的 Q 学习:

代码语言:javascript复制
from collections import deque
import numpy as np
import argparse
import os
import time
import gym
from gym import wrappers, logger 
代码语言:javascript复制
class QAgent:
    def __init__(self,
                 observation_space,
                 action_space,
                 demo=False,
                 slippery=False,
                 episodes=40000):
        """Q-Learning agent on FrozenLake-v0 environment 
代码语言:javascript复制
 Arguments:
            observation_space (tensor): state space
            action_space (tensor): action space
            demo (Bool): whether for demo or training
            slippery (Bool): 2 versions of FLv0 env
            episodes (int): number of episodes to train
        """ 
代码语言:javascript复制
 self.action_space = action_space
        # number of columns is equal to number of actions
        col = action_space.n
        # number of rows is equal to number of states
        row = observation_space.n
        # build Q Table with row x col dims
        self.q_table = np.zeros([row, col]) 
代码语言:javascript复制
 # discount factor
        self.gamma = 0.9 
代码语言:javascript复制
 # initially 90% exploration, 10% exploitation
        self.epsilon = 0.9
        # iteratively applying decay til 
        # 10% exploration/90% exploitation
        self.epsilon_min = 0.1
        self.epsilon_decay = self.epsilon_min / self.epsilon
        self.epsilon_decay = self.epsilon_decay ** 
                             (1. / float(episodes)) 
代码语言:javascript复制
 # learning rate of Q-Learning
        self.learning_rate = 0.1 
代码语言:javascript复制
 # file where Q Table is saved on/restored fr
        if slippery:
            self.filename = 'q-frozenlake-slippery.npy'
        else:
            self.filename = 'q-frozenlake.npy' 
代码语言:javascript复制
 # demo or train mode 
        self.demo = demo
        # if demo mode, no exploration
        if demo:
            self.epsilon = 0 
代码语言:javascript复制
 def act(self, state, is_explore=False):
        """determine the next action
            if random, choose from random action space
            else use the Q Table
        Arguments:
            state (tensor): agent's current state
            is_explore (Bool): exploration mode or not
        Return:
            action (tensor): action that the agent
                must execute
        """
        # 0 - left, 1 - Down, 2 - Right, 3 - Up
        if is_explore or np.random.rand() < self.epsilon:
            # explore - do random action
            return self.action_space.sample() 
代码语言:javascript复制
 # exploit - choose action with max Q-value
        action = np.argmax(self.q_table[state])
        return action 
代码语言:javascript复制
 def update_q_table(self, state, action, reward, next_state):
        """TD(0) learning (generalized Q-Learning) with learning rate
        Arguments:
            state (tensor): environment state
            action (tensor): action executed by the agent for
                the given state
            reward (float): reward received by the agent for
                executing the action
            next_state (tensor): the environment next state
        """
        # Q(s, a)  = 
        # alpha * (reward   gamma * max_a' Q(s', a') - Q(s, a))
        q_value = self.gamma * np.amax(self.q_table[next_state])
        q_value  = reward
        q_value -= self.q_table[state, action]
        q_value *= self.learning_rate
        q_value  = self.q_table[state, action]
        self.q_table[state, action] = q_value 
代码语言:javascript复制
 def update_epsilon(self):
        """adjust epsilon"""
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay 

“列表 9.5.2”演示了智能体的感知行为学习循环。 在每个剧集中,通过调用env.reset()重置环境。 要执行的动作由agent.act()选择,并由env.step(action)应用于环境。 奖励和下一个状态将被观察并用于更新 Q 表。

在每个动作之后,通过agent.update_q_table()执行 TD 学习。 由于每次调用agent.update_epsilon()时处self.epsilon变量的值都会减少,该智能体开始支持利用 Q 表来确定在给定状态下执行的操作。 达到目标或空洞状态后,剧集完成(done = True)。 对于此示例,TD 学习运行 4,000 集。

“列表 9.5.2”:q-frozenlake-9.5.1.py

FrozenLake-v0环境的 Q 学习循环:

代码语言:javascript复制
 # loop for the specified number of episode
    for episode in range(episodes):
        state = env.reset()
        done = False
        while not done:
            # determine the agent's action given state
            action = agent.act(state, is_explore=args.explore)
            # get observable data
            next_state, reward, done, _ = env.step(action)
            # clear the screen before rendering the environment
            os.system('clear')
            # render the environment for human debugging
            env.render()
            # training of Q Table
            if done:
                # update exploration-exploitation ratio
                # reward > 0 only when Goal is reached
                # otherwise, it is a Hole
                if reward > 0:
                    wins  = 1 
代码语言:javascript复制
 if not args.demo:
                agent.update_q_table(state,
                                     action, 
                                     reward, 
                                     next_state)
                agent.update_epsilon() 
代码语言:javascript复制
 state = next_state
            percent_wins = 100.0 * wins / (episode   1) 

agent对象可以在湿滑或非湿滑模式下运行。 训练后,智能体可以利用 Q 表选择给定任何策略执行的操作,如“表 9.5.1”的测试模式所示。 如“表 9.5.1”所示,使用学习的策略可显着提高性能。 随着体育馆的使用,不再需要中构建环境的许多代码行。 例如,与上一个示例不同,使用 OpenAI Gym,我们不需要创建状态转换表和奖励表。

这将帮助我们专注于构建有效的 RL 算法。 要以慢动作方式运行代码或每个动作延迟 1 秒,请执行以下操作:

代码语言:javascript复制
python3 q-frozenlake-9.5.1.py -d -t=1 

在本节中,我们在更具挑战性的环境中演示了 Q 学习。 我们还介绍了 OpenAI 体育馆。 但是,我们的环境仍然是玩具环境。 如果我们有大量的状态或动作怎么办? 在这种情况下,使用 Q 表不再可行。 在下一节中,我们将使用深度神经网络来学习 Q 表。

6. 深度 Q 网络(DQN)

在小型离散环境中,使用 Q 表执行 Q 学习是很好的选择。 但是,在大多数情况下,当环境具有许多状态或连续时,Q 表是不可行或不实际的。 例如,如果我们观察由四个连续变量组成的状态,则表的大小是无限的。 即使我们尝试将这四个变量离散化为 1,000 个值,表中的总行数也达到了惊人的1000^4 = 1e12。 即使经过训练,该表仍是稀疏的–该表中的大多数单元都是零。

这个问题的解决方案称为 DQN [2],它使用深度神经网络来近似 Q 表,如图“图 9.6.1”所示。 有两种构建 Q 网络的方法:

  • 输入是状态-动作对,预测是 Q 值
  • 输入是状态,预测是每个动作的 Q 值

第一种选择不是最佳的,因为网络被调用的次数等于操作数。 第二种是首选方法。 Q 网络仅被调用一次。

最希望得到的作用就是 Q 值最大的作用。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1XO9ENsE-1681704311680)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_13.png)]

图 9.6.1:深度 Q 网络

训练 Q 网络所需的数据来自智能体的经验:(s[0]a[0]r[1]s[1], s[1]a[1]r[2]s[2],d ..., s[T-1]a[T-1]r[T]s[T])。 每个训练样本都是经验单元s[t]a[t]r[t 1]s[t 1]。 在时间步ts = s[t]的给定状态下,使用类似于前一部分的 Q 学习算法来确定动作a = a[t]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qTIUWymM-1681704311680)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_060.png)] (Equation 9.6.1)

为了简化符号,我们省略了下标和粗体字母的使用。 注意,Q(s, a)是 Q 网络。 严格来说,它是Q(a | s),因为动作已移至预测阶段(换句话说,是输出),如“图 9.6.1”的右侧所示。 Q 值最高的动作是应用于环境以获得奖励r = r[t 1],下一状态s' = s[t 1]和布尔值done的动作,指示下一个状态是否为终端 。 根据关于广义 Q 学习的“公式 9.5.1”,可以通过应用所选的操作来确定 MSE 损失函数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LN8ly63Y-1681704311681)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_065.png)] (Equation 9.6.2)

在前面有关 Q 学习和Q(a | s) -> Q(s, a)的讨论中,所有项都很熟悉。 项max[a'] Q(a' | s') -> max[a'] Q(s', a')。 换句话说,使用 Q 网络,在给定下一个状态的情况下预测每个动作的 Q 值,并从其中获得最大值。 注意,在终端状态下,s'max[a'] Q(a' | s') -> max[a'] Q(s', a') = 0

但是,事实证明训练 Q 网络是不稳定的。 导致不稳定的问题有两个:1)样本之间的相关性高; 2)非平稳目标。 高度相关性是由于采样经验的顺序性质。 DQN 通过创建经验缓冲解决了问题。 训练数据是从该缓冲区中随机采样的。 此过程称为经验回放

非固定目标的问题是由于目标网络Q(s', a')在每小批训练后都会被修改。 目标网络的微小变化会导致策略,数据分布以及当前 Q 值和目标 Q 值之间的相关性发生重大变化。 这可以通过冻结C训练步骤的目标网络的权重来解决。 换句话说,创建了两个相同的 Q 网络。 在每个C训练步骤中,从训练中的 Q 网络复制目标 Q 网络参数。

“算法 9.6.1”中概述了深度 Q 网络算法。

“算法 9.6.1”: DQN 算法

要求:将重播内存D初始化为容量N

要求:使用随机权重θ初始化动作值函数Q

要求:使用权重θ- = 0初始化目标操作值函数Q_target

需要:探索率ε和折扣系数γ

  1. 对于episode = 1, ..., M,执行:
  2. 给定初始状态s
  3. 对于step = 1, ..., T,执行:
  4. 选择动作 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Fvaj6Eh3-1681704311681)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_082.png)]
  5. 执行动作a,观察奖励r,以及下一个状态s'
  6. 将转换(s, a, r, s')存储在D
  7. 更新状态s = s'
  8. 经验回放
  9. D中抽样一小部分经验(s[j], a[j], r[j 1], s[j 1])
  10. [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MLK0eigp-1681704311681)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_090.png)]
  11. (Q_max - Q(s[j], a[j]; θ))²上相对于参数θ执行梯度下降步骤。
  12. 定期更新目标网络
  13. C个步骤,即Q_target = Q,换句话说,设置θ- = θ
  14. end
  15. end

“算法 9.6.1”总结了在具有离散动作空间和连续状态空间的环境上实现 Q 学习所需的所有技术。 在下一节中,我们将演示如何在更具挑战性的 OpenAI Gym 环境中使用 DQN。

Keras 中的 DQN

为了说明 DQN,使用了 OpenAI Gym 的CartPole-v0环境。 CartPole-v0是极点平衡问题。 目的是防止电杆跌落。 环境是二维的。 动作空间由两个离散的动作(左右移动)组成。 但是,状态空间是连续的,并且包含四个变量:

  • 直线位置
  • 线速度
  • 旋转角度
  • 角速度

CartPole-v0环境如图 9.6.1 所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZqzLrler-1681704311681)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_14.png)]

图 9.6.1:CartPole-v0 环境

最初,杆是直立的。 杆保持直立的每个时间步长都提供 1 的奖励。 当极点与垂直方向的夹角超过 15 度或与中心的距离超过 2.4 单位时,剧集结束。 如果在 100 个连续试验中平均奖励为 195.0,则认为CartPole-v0问题已解决:

“列表 9.6.1”向我们展示了CartPole-v0的 DQN 实现。 DQNAgent类表示使用 DQN 的智能体。 创建了两个 Q 网络:

  • “算法 9.6.1”中的 Q 网络或 Q
  • “算法 9.6.1”中的目标 Q 网络或Q_target

两个网络都是 MLP,每个都有 256 个单元的 3 个隐藏层。 这两个网络都是通过build_model()方法创建的。 在经验回放replay()期间训练 Q 网络。 以update_weights()的固定间隔C = 10个训练步骤,将 Q 网络参数复制到目标 Q 网络。 在“算法 9.6.1”中,这实现了第 13 行,Q_target = Q。 每次发作后,update_epsilon()都会降低探索利用的比例,以利用已学习的策略。

“列表 9.6.1”:dqn-cartpole-9.6.1.py

tf.keras中的 DQN:

代码语言:javascript复制
class DQNAgent:
    def __init__(self,
                 state_space,
                 action_space,
                 episodes=500):
        """DQN Agent on CartPole-v0 environment 
代码语言:javascript复制
 Arguments:
            state_space (tensor): state space
            action_space (tensor): action space
            episodes (int): number of episodes to train
        """
        self.action_space = action_space 
代码语言:javascript复制
 # experience buffer
        self.memory = [] 
代码语言:javascript复制
 # discount rate
        self.gamma = 0.9 
代码语言:javascript复制
 # initially 90% exploration, 10% exploitation
        self.epsilon = 1.0
        # iteratively applying decay til 
        # 10% exploration/90% exploitation
        self.epsilon_min = 0.1
        self.epsilon_decay = self.epsilon_min / self.epsilon
        self.epsilon_decay = self.epsilon_decay ** 
                             (1. / float(episodes)) 
代码语言:javascript复制
 # Q Network weights filename
        self.weights_file = 'dqn_cartpole.h5'
        # Q Network for training
        n_inputs = state_space.shape[0]
        n_outputs = action_space.n
        self.q_model = self.build_model(n_inputs, n_outputs)
        self.q_model.compile(loss='mse', optimizer=Adam())
        # target Q Network
        self.target_q_model = self.build_model(n_inputs, n_outputs)
        # copy Q Network params to target Q Network
        self.update_weights() 
代码语言:javascript复制
 self.replay_counter = 0
        self.ddqn = True if args.ddqn else False 
代码语言:javascript复制
 def build_model(self, n_inputs, n_outputs):
        """Q Network is 256-256-256 MLP 
代码语言:javascript复制
 Arguments:
            n_inputs (int): input dim
            n_outputs (int): output dim 
代码语言:javascript复制
 Return:
            q_model (Model): DQN
        """
        inputs = Input(shape=(n_inputs, ), name='state')
        x = Dense(256, activation='relu')(inputs)
        x = Dense(256, activation='relu')(x)
        x = Dense(256, activation='relu')(x)
        x = Dense(n_outputs,
                  activation='linear',
                  name='action')(x)
        q_model = Model(inputs, x)
        q_model.summary()
        return q_model 
代码语言:javascript复制
 def act(self, state):
        """eps-greedy policy
        Return:
            action (tensor): action to execute
        """
        if np.random.rand() < self.epsilon:
            # explore - do random action
            return self.action_space.sample() 
代码语言:javascript复制
 # exploit
        q_values = self.q_model.predict(state)
        # select the action with max Q-value
        action = np.argmax(q_values[0])
        return action 
代码语言:javascript复制
 def remember(self, state, action, reward, next_state, done):
        """store experiences in the replay buffer
        Arguments:
            state (tensor): env state
            action (tensor): agent action
            reward (float): reward received after executing
                action on state
            next_state (tensor): next state
        """
        item = (state, action, reward, next_state, done)
        self.memory.append(item) 
代码语言:javascript复制
 def get_target_q_value(self, next_state, reward):
        """compute Q_max
           Use of target Q Network solves the 
            non-stationarity problem
        Arguments:
            reward (float): reward received after executing
                action on state
            next_state (tensor): next state
        Return:
            q_value (float): max Q-value computed by
                DQN or DDQN
        """
        # max Q value among next state's actions
        if self.ddqn:
            # DDQN
            # current Q Network selects the action
            # a'_max = argmax_a' Q(s', a')
            action = np.argmax(self.q_model.predict(next_state)[0])
            # target Q Network evaluates the action
            # Q_max = Q_target(s', a'_max)
            q_value = self.target_q_model.predict(
                                          next_state)[0][action]
        else:
            # DQN chooses the max Q value among next actions
            # selection and evaluation of action is 
            # on the target Q Network
            # Q_max = max_a' Q_target(s', a')
            q_value = np.amax(
                      self.target_q_model.predict(next_state)[0]) 
代码语言:javascript复制
 # Q_max = reward   gamma * Q_max
        q_value *= self.gamma
        q_value  = reward
        return q_value 
代码语言:javascript复制
 def replay(self, batch_size):
        """experience replay addresses the correlation issue 
            between samples
        Arguments:
            batch_size (int): replay buffer batch 
                sample size
        """
        # sars = state, action, reward, state' (next_state)
        sars_batch = random.sample(self.memory, batch_size)
        state_batch, q_values_batch = [], [] 
代码语言:javascript复制
 # fixme: for speedup, this could be done on the tensor level
        # but easier to understand using a loop
        for state, action, reward, next_state, done in sars_batch:
            # policy prediction for a given state
            q_values = self.q_model.predict(state) 
代码语言:javascript复制
 # get Q_max
            q_value = self.get_target_q_value(next_state, reward) 
代码语言:javascript复制
 # correction on the Q value for the action used
            q_values[0][action] = reward if done else q_value 
代码语言:javascript复制
 # collect batch state-q_value mapping
            state_batch.append(state[0])
            q_values_batch.append(q_values[0]) 
代码语言:javascript复制
 # train the Q-network
        self.q_model.fit(np.array(state_batch),
                         np.array(q_values_batch),
                         batch_size=batch_size,
                         epochs=1,
                         verbose=0) 
代码语言:javascript复制
 # update exploration-exploitation probability
        self.update_epsilon() 
代码语言:javascript复制
 # copy new params on old target after 
        # every 10 training updates
        if self.replay_counter % 10 == 0:
            self.update_weights() 
代码语言:javascript复制
 self.replay_counter  = 1 
代码语言:javascript复制
 def update_epsilon(self):
        """decrease the exploration, increase exploitation"""
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay 

为了在“算法 9.6.1”经验回放replay()中实现第 10 行,对于每个体验单元(s[j]a[j]r[j 1]s[j 1])将动作a[j]的 Q 值设置为Q_max。 所有其他动作的 Q 值保持不变。

这是通过 DQNAgent replay()函数中的以下行实现的:

代码语言:javascript复制
# policy prediction for a given state q_values = self.q_model.predict(state)
# get Q_max
q_value = self.get_target_q_value(next_state)
# correction on the Q value for the action used q_values[0][action] = reward if done else q_value 

如“算法 9.6.1”的第 11 行所示,只有动作a[j]具有等于(Q_max - Q(s[j], a[j]; θ))²的非零损失。 请注意,假设缓冲区中有足够的数据,换句话说,在每个剧集结束后,“列表 9.6.2”中的感知动作学习循环会调用经验回放。 缓冲区的大小大于或等于批量大小)。 在经验回放期间,会随机采样一批体验单元,并将其用于训练 Q 网络。

与 Q 表类似,act()实现了 ε-贪婪策略,“公式 9.6.1”。

体验由remember()存储在重播缓冲区中。 Q 通过get_target_q_value()函数计算。

“列表 9.6.2”总结了智能体的感知-行动-学习循环。 在每个剧集中,通过调用env.reset()重置环境。 要执行的动作由agent.act()选择,并由env.step(action)应用于环境。 奖励和下一状态将被观察并存储在重播缓冲区中。 在执行每个操作之后,智能体会调用replay()来训练 DQN 并调整探索利用比率。

当极点与垂直方向的夹角超过 15 度或与中心的距离超过 2.4 单位时,剧集完成(done = True)。 对于此示例,如果 DQN 智能体无法解决问题,则 Q 学习最多运行 3,000 集。 如果average mean_score奖励在 100 次连续试验win_trials中为 195.0,则认为CartPole-v0问题已解决。

“列表 9.6.2”:dqn-cartpole-9.6.1.py

tf.keras中的 DQN 训练循环:

代码语言:javascript复制
 # Q-Learning sampling and fitting
    for episode in range(episode_count):
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        done = False
        total_reward = 0
        while not done:
            # in CartPole-v0, action=0 is left and action=1 is right
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            # in CartPole-v0:
            # state = [pos, vel, theta, angular speed]
            next_state = np.reshape(next_state, [1, state_size])
            # store every experience unit in replay buffer
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_reward  = reward 
代码语言:javascript复制
 # call experience relay
        if len(agent.memory) >= batch_size:
            agent.replay(batch_size) 
代码语言:javascript复制
 scores.append(total_reward)
        mean_score = np.mean(scores)
        if mean_score >= win_reward[args.env_id] 
                and episode >= win_trials:
            print("Solved in episode %d: 
                   Mean survival = %0.2lf in %d episodes"
                  % (episode, mean_score, win_trials))
            print("Epsilon: ", agent.epsilon)
            agent.save_weights()
            break
        if (episode   1) % win_trials == 0:
            print("Episode %d: Mean survival = 
                   %0.2lf in %d episodes" %
                  ((episode   1), mean_score, win_trials)) 

在平均 10 次运行的中,DQN 在 822 集内解决了。 我们需要注意的是,每次训练运行的结果可能会有所不同。

自从引入 DQN 以来,连续的论文都提出了对“算法 9.6.1”的改进。 一个很好的例子是双 DQN(DDQN),下面将对其进行讨论。

双重 Q 学习(DDQN)

在 DQN 中,目标 Q 网络选择并评估每个动作,从而导致 Q 值过高。 为了解决这个问题,DDQN [3]建议使用 Q 网络选择动作,并使用目标 Q 网络评估动作。

在 DQN 中,如“算法 9.6.1”所概述,第 10 行中 Q 值的估计为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FrJ5Y9SZ-1681704311682)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_097.png)]

  • Q_target选择并评估动作,a[j 1]

DDQN 建议将第 10 行更改为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FyfxAL35-1681704311682)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_100.png)]

argmax[a[j 1]] Q(s[j 1], a[j 1]; θ)使 Q 函数可以选择动作。 然后,该动作由Q_target评估。

“列表 9.6.3”显示了当我们创建一个新的DDQNAgent类时,该类继承自DQNAgent类。 只有get_target_q_value()方法被覆盖,以实现最大 Q 值计算中的更改。

“列表 9.6.3”:dqn-cartpole-9.6.1.py

代码语言:javascript复制
class DDQNAgent(DQNAgent):
    def __init__(self,
                 state_space,
                 action_space,
                 episodes=500):
        super().__init__(state_space,
                         action_space,
                         episodes)
        """DDQN Agent on CartPole-v0 environment 
代码语言:javascript复制
 Arguments:
            state_space (tensor): state space
            action_space (tensor): action space
            episodes (int): number of episodes to train
        """ 
代码语言:javascript复制
 # Q Network weights filename
        self.weights_file = 'ddqn_cartpole.h5' 
代码语言:javascript复制
 def get_target_q_value(self, next_state, reward):
        """compute Q_max
           Use of target Q Network solves the 
            non-stationarity problem
        Arguments:
            reward (float): reward received after executing
                action on state
            next_state (tensor): next state
        Returns:
            q_value (float): max Q-value computed
        """
        # max Q value among next state's actions
        # DDQN
        # current Q Network selects the action
        # a'_max = argmax_a' Q(s', a')
        action = np.argmax(self.q_model.predict(next_state)[0])
        # target Q Network evaluates the action
        # Q_max = Q_target(s', a'_max)
        q_value = self.target_q_model.predict(
                                      next_state)[0][action] 
代码语言:javascript复制
 # Q_max = reward   gamma * Q_max
        q_value *= self.gamma
        q_value  = reward
        return q_value 

为了进行比较,在平均 10 次运行中,CartPole-v0由 DDQN 在 971 个剧集中求解。 要使用 DDQN,请运行以下命令:

代码语言:javascript复制
python3 dqn-cartpole-9.6.1.py -d 

DQN 和 DDQN 均表明,借助 DL,Q 学习能够扩展并解决具有连续状态空间和离散动作空间的问题。 在本章中,我们仅在具有连续状态空间和离散动作空间的最简单问题之一上演示了 DQN。 在原始论文中,DQN [2]证明了它可以在许多 Atari 游戏中达到超人的表现水平。

7. 总结

在本章中,我们已经介绍了 DRL,DRL 是一种强大的技术,许多研究人员认为它是 AI 的最有希望的领先者。 我们已经超越了 RL 的原则。 RL 能够解决许多玩具问题,但是 Q 表无法扩展到更复杂的现实问题。 解决方案是使用深度神经网络学习 Q 表。 但是,由于样本相关性和目标 Q 网络的非平稳性,在 RL 上训练深度神经网络非常不稳定。

DQN 提出了一种使用经验回放并将目标网络与受训 Q 网络分离的解决方案。 DDQN 建议通过将动作选择与动作评估分开来最大程度地降低 Q 值,从而进一步改进算法。 DQN 还提出了其他改进建议。 优先经验回放[6]认为,不应对体验缓冲区进行统一采样。

取而代之的是,应更频繁地采样基于 TD 误差的更重要的经验,以完成更有效的训练。 文献[7]提出了一种对决网络架构来估计状态值函数和优势函数。 这两个函数均用于估计 Q 值,以加快学习速度。

本章介绍的方法是值迭代/拟合。 通过找到最佳值函数间接学习策略。 在下一章中,方法将是使用称为策略梯度方法的一系列算法直接学习最佳策略。 学习策略有很多好处。 特别地,策略梯度方法可以处理离散和连续的动作空间。

8. 参考

  1. Sutton and Barto: Reinforcement Learning: An Introduction, 2017 (http://incompleteideas.net/book/bookdraft2017nov5.pdf).
  2. Volodymyr Mnih et al.: Human-level Control through Deep Reinforcement Learning. Nature 518.7540, 2015: 529 (http://www.davidqiu.com:8888/research/nature14236.pdf).
  3. Hado Van Hasselt, Arthur Guez, and David Silver: Deep Reinforcement Learning with Double Q-Learning. AAAI. Vol. 16, 2016 (http://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/download/12389/11847).
  4. Kai Arulkumaran et al.: A Brief Survey of Deep Reinforcement Learning. arXiv preprint arXiv:1708.05866, 2017 (https://arxiv.org/pdf/1708.05866.pdf).
  5. David Silver: Lecture Notes on Reinforcement Learning (http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html).
  6. Tom Schaul et al.: Prioritized experience replay. arXiv preprint arXiv:1511.05952, 2015 (https://arxiv.org/pdf/1511.05952.pdf).
  7. Ziyu Wang et al.: Dueling Network Architectures for Deep Reinforcement Learning. arXiv preprint arXiv:1511.06581, 2015 (https://arxiv.org/pdf/1511.06581.pdf).

十、策略梯度方法

在本章中,我们将介绍在强化学习中直接优化策略网络的算法。 这些算法统称为“策略梯度方法”。 由于策略网络是在训练期间直接优化的,因此策略梯度方法属于基于策略强化学习算法的族。 就像我们在“第 9 章”,“深度强化学习”中讨论的基于值的方法一样,策略梯度方法也可以实现为深度强化学习算法。

研究策略梯度方法的基本动机是解决 Q 学习的局限性。 我们会回想起 Q 学习是关于选择使状态值最大化的动作。 借助 Q 函数,我们能够确定策略,使智能体能够决定对给定状态采取何种操作。 选择的动作只是使智能体最大化的动作。 在这方面,Q 学习仅限于有限数量的离散动作。 它不能处理连续的动作空间环境。 此外,Q 学习不是直接优化策略。 最后,强化学习是要找到智能体能够使用的最佳策略,以便决定应采取何种行动以最大化回报。

相反,策略梯度方法适用于具有离散或连续动作空间的环境。 另外,我们将在本章中介绍的四种策略梯度方法是直接优化策略网络的表现度量。 这样就形成了一个经过训练的策略网络,智能体可以使用该网络来最佳地在其环境中采取行动。

总之,本章的目的是介绍:

  • 策略梯度定理
  • 四种策略梯度方法: REINFORCE带基线的 REINFORCE演员评论家优势演员评论家(A2C)
  • 在连续动作空间环境中如何在tf.keras中实现策略梯度方法的指南

让我们从定理开始。

1. 策略梯度定理

如“第 9 章”,“深度强化学习”中所讨论的,智能体位于环境中,处于状态s[t]中,它是状态空间S的一个元素。 状态空间S可以是离散的,也可以是连续的。 智能体通过遵循策略π(a[t], s[t])从动作空间A采取动作a[t]A可以是离散的或连续的。 作为执行动作a[t]的结果,智能体会收到奖励r[t 1],并且环境转换为新状态s[t 1]。 新状态仅取决于当前状态和操作。 智能体的目标是学习一种最佳策略π*,该策略可最大化所有状态的回报:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4dhnsaed-1681704311682)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_009.png)] (Equation 9.1.1)

收益R[t]定义为从时间t直到剧集结束或达到最终状态时的折扣累积奖励:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-66YE5TIV-1681704311682)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_010.png)] (Equation 9.1.2)

根据“公式 9.1.2”,还可以通过遵循策略π将返回解释为给定状态的值。 从“公式 9.1.1”可以看出,与通常的γ^k < 1.0相比,与立即奖励相比,未来奖励的权重较低。

到目前为止,我们仅考虑通过优化基于值的函数Q(s, a)来学习策略。

本章的目标是通过参数化π(a[t] | s[t]) -> π(a[t] | s[t], θ)直接学习该策略。 通过参数化,我们可以使用神经网络来学习策略函数。

学习策略意味着我们将最大化某个目标函数J(θ),这是相对于参数θ的一种表现度量。在间歇式强化学习中,表现度量是起始状态的值。 在连续的情况下,目标函数是平均奖励率。

通过执行梯度上升来最大化目标函数J(θ)。 在梯度上升中,梯度更新是在要优化的函数的导数方向上。 到目前为止,我们的所有损失函数都通过最小化或通过执行梯度下降进行了优化。 稍后,在tf.keras实现中,我们将看到可以通过简单地否定目标函数并执行梯度下降来执行梯度上升。

直接学习策略的好处是它可以应用于离散和连续动作空间。 对于离散的动作空间:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qsHrtRk8-1681704311682)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_019.png)] (Equation 10.1.1)

其中a[i]是第i个动作。 a[i]可以是神经网络的预测或状态作用特征的线性函数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hMvweDc8-1681704311683)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_022.png)] (Equation 10.1.2)

φ(s[t], a[i])是将状态操作转换为特征的任何函数,例如编码器。

π(a[t] | s[t], θ)确定每个a[i]的概率。 例如,在上一章中的柱杆平衡问题中,目标是通过沿二维轴向左或向右移动柱车来保持柱子直立。 在这种情况下,a[0]a[1]分别是左右移动的概率。 通常,智能体以最高概率a[t] = max[i] π(a[t] | s[t], θ)采取行动。

对于连续动作空间,π(a[t] | s[t], θ)根据给定状态的概率分布对动作进行采样。 例如,如果连续动作空间在a[t] ∈ [-1.0, 1.0]范围内,则π(a[t] | s[t], θ)通常是高斯分布,其均值和标准差由策略网络预测。 预测动作是来自此高斯分布的样本。 为了确保不会生成任何无效的预测,该操作将被限制在 -1.0 和 1.0 之间。

正式地,对于连续的动作空间,该策略是高斯分布的样本:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-glWnECgF-1681704311683)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_032.png)] (Equation 10.1.3)

平均值μ和标准差σ都是状态特征的函数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4P4g4Tow-1681704311683)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_035.png)] (Equation 10.1.4)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MKouiahR-1681704311683)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_036.png)] (Equation 10.1.5)

φ(s[i])是将状态转换为其特征的任何函数。 ζ(x) = log(1 e^x)是确保标准差为正值的softplus函数。 实现状态特征函数φ(s[t])的一种方法是使用自编码器网络的编码器。 在本章的最后,我们将训练一个自编码器,并将编码器部分用作状态特征。 因此,训练策略网络是优化参数的问题θ = [θ[μ], θ[σ]]

给定连续可微分的策略函数π(a[t] | s[t], θ),策略梯度可以计算为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-N2QnBbkn-1681704311683)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_042.png)] (Equation 10.1.6)

“公式 10.1.6”也被称为策略梯度定理。 它适用于离散和连续动作空间。 根据通过 Q 值缩放的策略操作采样的自然对数,可以计算出相对于参数θ的梯度。“公式 10.1.6”利用了自然对数ᐁx/x = ᐁlnx的特性。

策略梯度定理在某种意义上是直观的,即表现梯度是根据目标策略样本估计的,并且与策略梯度成比例。 策略梯度由 Q 值缩放,以鼓励对状态值产生积极贡献的行动。 梯度还与动作概率成反比,以惩罚对提高性能没有贡献的频繁发生的动作。

有关策略梯度定理的证明,请参阅[2]和 David Silver 关于强化学习的讲义。

与策略梯度方法相关的细微优势。 例如,在某些基于纸牌的游戏中,与基于策略的方法不同,基于值的方法在处理随机性方面没有直接的过程。 在基于策略的方法中,操作概率随参数而平滑变化。

同时,相对于参数的微小变化,基于值的行为可能会发生剧烈变化。 最后,基于策略的方法对参数的依赖性使我们对如何执行表现考核的梯度提升产生了不同的表述。 这些是在后续部分中介绍的四种策略梯度方法。

基于策略的方法也有其自身的缺点。 由于趋向于收敛于局部最优而非全局最优,所以它们通常更难训练。 在本章末尾提出的实验中,智能体很容易适应并选择不一定提供最高值的动作。 策略梯度的特征还在于高差异。

梯度更新经常被高估。 此外,基于训练策略的方法非常耗时。 训练需要成千上万集(即采样效率不高)。 每个剧集仅提供少量样本。 在本章结尾处提供的实现方面的典型训练,大约需要一个小时才能在 GTX 1060 GPU 上进行 1,000 集。

在以下各节中,我们将讨论四种策略梯度方法。 虽然讨论的重点是连续的动作空间,但是该概念通常适用于离散的动作空间。

2. 蒙特卡洛策略梯度(REINFORCE)方法

最简单的策略梯度方法是 REINFORCE [4],这是蒙特卡洛策略梯度方法:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FM9FiOV8-1681704311684)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_045.png)] (Equation 10.2.1)

其中R[t]是返回值,如“公式 9.1.2”所定义。R[t]是策略梯度定理中Q^π(s[t], a[t])的无偏样本。

“算法 10.2.1”总结了 REINFORCE 算法[2]。 REINFORCE 是一种蒙特卡洛算法。 它不需要环境动态知识(换句话说,无需模型)。 仅需要经验样本(s[i], a[i], r[i 1], s[i 1])来优化策略网络π(a[t] | s[t])的参数。 折扣因子γ考虑到奖励随着步数增加而降低的事实。 梯度被γ^k打折。 在后续步骤中采用的梯度贡献较小。 学习率α是梯度更新的比例因子。

通过使用折扣梯度和学习率执行梯度上升来更新参数。 作为蒙特卡洛算法,REINFORCE 要求智能体在处理梯度更新之前先完成一集。 同样由于其蒙特卡洛性质,REINFORCE 的梯度更新具有高方差的特征。

算法 10.2.1 REINFORCE

要求:可微分的参数化目标策略网络π(a[t] | s[t], θ)

要求:折扣因子,γ = [0, 1]和学习率α。 例如,γ = 0.99α = 1e - 3

要求θ[0],初始策略网络参数(例如,θ[0] -> 0)。

  1. 重复。
  2. 通过跟随π(a[t] | s[t], θ)来生成剧集(s[0]a[0]r[1]s[1], s[1]a[1]r[2]s[2], ..., s[T-1]a[T-1]r[T]s[T])
  3. 对于步骤t = 0, ..., T - 1,执行:
  4. 计算返回值R[t] = Σ γ^t r[t k], k = 0, ..., T
  5. 计算折扣的表现梯度ᐁJ(θ) = r^t R[t] ᐁ[θ] ln π(a[t] | s[t], θ)
  6. 执行梯度上升θ = θ αᐁJ(θ)

在 REINFORCE 中,可以通过神经网络对参数化策略进行建模,如图“图 10.2.1”所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WMkCkQMS-1681704311684)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_01.png)]

图 10.2.1:策略网络

如上一节中讨论的,在连续动作空间的情况下,状态输入被转换为特征。 状态特征是策略网络的输入。 代表策略函数的高斯分布具有均值和标准差,均是状态特征的函数。 根据状态输入的性质,策略网络π(θ)可以是 MLP,CNN 或 RNN。 预测的动作只是策略函数的样本。

“列表 10.2.1”显示了REINFORCEAgent 类,该类在tf.keras中实现了“算法 10.2.1”。 train_by_episode()在剧集完成后调用,以计算每个步骤的回报。 train()通过针对目标函数logp_model优化网络来执行“算法 10.2.1”的第 5 行和第 6 行。 父类PolicyAgent在本章介绍的四种策略梯度方法的算法中实现了的通用代码。 在讨论所有策略梯度方法之后,将介绍PolicyAgent

“列表 10.2.1”:policygradient-car-10.1.1.py

代码语言:javascript复制
class REINFORCEAgent(PolicyAgent):
    def __init__(self, env):
        """Implements the models and training of 
           REINFORCE policy gradient method
        Arguments:
            env (Object): OpenAI gym environment
        """
        super().__init__(env) 
代码语言:javascript复制
 def train_by_episode(self):
        """Train by episode
           Prepare the dataset before the step by step training
        """
        # only REINFORCE and REINFORCE with baseline
        # use the ff code
        # convert the rewards to returns
        rewards = []
        gamma = 0.99
        for item in self.memory:
            [_, _, _, reward, _] = item
            rewards.append(reward)

        # compute return per step
        # return is the sum of rewards from t til end of episode
        # return replaces reward in the list
        for i in range(len(rewards)):
            reward = rewards[i:]
            horizon = len(reward)
            discount =  [math.pow(gamma, t) for t in range(horizon)]
            return_ = np.dot(reward, discount)
            self.memory[i][3] = return_ 
代码语言:javascript复制
 # train every step
        for item in self.memory:
            self.train(item, gamma=gamma) 
代码语言:javascript复制
 def train(self, item, gamma=1.0):
        """Main routine for training 
        Arguments:
            item (list) : one experience unit
            gamma (float) : discount factor [0,1]
        """
        [step, state, next_state, reward, done] = item 
代码语言:javascript复制
 # must save state for entropy computation
        self.state = state 
代码语言:javascript复制
 discount_factor = gamma**step
        delta = reward 
代码语言:javascript复制
 # apply the discount factor as shown in Algorithms
        # 10. 2.1, 10.3.1 and 10.4.1
        discounted_delta = delta * discount_factor
        discounted_delta = np.reshape(discounted_delta, [-1, 1])
        verbose = 1 if done else 0 
代码语言:javascript复制
 # train the logp model (implies training of actor model
        # as well) since they share exactly the same set of
        # parameters
        self.logp_model.fit(np.array(state),
                            discounted_delta,
                            batch_size=1,
                            epochs=1,
                            verbose=verbose) 

以下部分提出了对 REINFORCE 方法的改进。

3. 带基线方法的 REINFORCE

REINFORCE 算法可以通过从收益δ = R[t] - B(s[t])中减去基线来概括。 基线函数B(s[t])可以是任何函数,只要它不依赖a[t]即可。 基线不会改变表现梯度的期望:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MxDLz8iG-1681704311684)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_069.png)] (Equation 10.3.1)

“公式 10.3.1”隐含E[π] [B(s[t]) ᐁ[θ] ln π(a[t] | s[t], θ)] = 0,因为B(s[t])不是a[t]的函数。 尽管引入基准不会改变期望值,但会减小梯度更新的方差。 方差的减少通常会加速学习。

在大多数情况下,我们使用值函数B(s[t]) = V(s[t])作为基准。 如果收益被高估,则比例系数将通过值函数成比例地减小,从而导致较低的方差。 值函数也已参数化V(s[t]) = V(s[t]; θ[v]),并与策略网络一起进行了训练。 在连续动作空间中,状态值可以是状态特征的线性函数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-N8ARAJHX-1681704311684)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_075.png)] (Equation 10.3.2)

“算法 10.3.1”用基线方法[1]总结了 REINFORCE。 这与 REINFORCE 相似,只不过将返回值替换为δ。 区别在于我们现在正在训练两个神经网络。

算法 10.3.1 带基线的 REINFORCE

要求:可微分的参数化目标策略网络π(a[t] | s[t], θ)

要求:可微分的参数化值网络V(s[t], θ[v])

要求:折扣因子γ ∈ [0, 1],表现梯度的学习率α和值梯度α[v]的学习率。

要求θ[0],初始策略网络参数(例如,θ[0] -> 0)。 θ[v0],初始值网络参数(例如θ[v0] -> 0)。

  1. 重复。
  2. 通过跟随π(a[t] | s[t], θ)来生成剧集(s[0]a[0]r[1]s[1], s[1]a[1]r[2]s[2], ..., a[T-1]a[T-1]r[T]s[T])
  3. 对于步骤t = 0, ..., T - 1,执行:
  4. 计算返回值: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yvETc34m-1681704311684)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_062.png)]
  5. 减去基线: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Bkt2DSm5-1681704311685)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_090.png)]
  6. 计算折扣值梯度: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-h6rdhNMi-1681704311685)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_091.png)]
  7. 执行梯度上升: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vYmSwQwr-1681704311685)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_092.png)]
  8. 计算折扣的表现梯度: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-nmigTqzj-1681704311685)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_093.png)]
  9. 执行梯度上升: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2ilYZbVU-1681704311685)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_064.png)]

如图“图 10.3.1”所示,除了策略网络π(θ)之外,值网络V(θ)也同时受到训练。 通过表现梯度ᐁJ(θ)更新策略网络参数,而通过梯度ᐁV(θ[v])调整值网络参数。 由于 REINFORCE 是蒙特卡罗算法,因此值函数训练也是蒙特卡罗算法。

学习率不一定相同。 请注意,值网络也在执行梯度上升。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-19Bw9alm-1681704311686)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_02.png)]

图 10.3.1:策略和值网络。 具有基线的 REINFORCE 具有一个计算基线的值网络

“列表 10.3.1”显示了REINFORCEBaselineAgent类,该类在tf.keras中实现了“算法 10.3.1”。 它继承自REINFORCEAgent,因为这两种算法仅在和train()方法上有所不同。 “算法 10.3.1”的第 5 行由delta = reward - self.value(state)[0]计算。 然后,通过调用各自模型的fit()方法来优化第 7 行和第 9 行中用于目标和值函数的网络logp_modelvalue_model

“列表 10.3.1”:policygradient-car-10.1.1.py

代码语言:javascript复制
class REINFORCEBaselineAgent(REINFORCEAgent):
    def __init__(self, env):
        """Implements the models and training of 
           REINFORCE w/ baseline policy 
           gradient method
        Arguments:
            env (Object): OpenAI gym environment
        """
        super().__init__(env) 
代码语言:javascript复制
 def train(self, item, gamma=1.0):
        """Main routine for training 
        Arguments:
            item (list) : one experience unit
            gamma (float) : discount factor [0,1]
        """
        [step, state, next_state, reward, done] = item 
代码语言:javascript复制
 # must save state for entropy computation
        self.state = state 
代码语言:javascript复制
 discount_factor = gamma**step 
代码语言:javascript复制
 # reinforce-baseline: delta = return - value
        delta = reward - self.value(state)[0] 
代码语言:javascript复制
 # apply the discount factor as shown in Algorithms
        # 10. 2.1, 10.3.1 and 10.4.1
        discounted_delta = delta * discount_factor
        discounted_delta = np.reshape(discounted_delta, [-1, 1])
        verbose = 1 if done else 0 
代码语言:javascript复制
 # train the logp model (implies training of actor model
        # as well) since they share exactly the same set of
        # parameters
        self.logp_model.fit(np.array(state),
                            discounted_delta,
                            batch_size=1,
                            epochs=1,
                            verbose=verbose) 
代码语言:javascript复制
 # train the value network (critic)
        self.value_model.fit(np.array(state),
                             discounted_delta,
                             batch_size=1,
                             epochs=1,
                             verbose=verbose) 

在的下一部分中,我们将介绍使用基准线方法对 REINFORCE 的改进。

4. 演员评论家方法

在带有基线的 REINFORCE 方法中,该值用作基线。 它不用于训练值函数。 在本节中,我们介绍 REINFORCE 与基线的变化,称为演员评论家方法。 策略和值网络扮演着参与者和批评者网络的角色。 策略网络是参与者决定给定状态时要采取的操作。 同时,值网络评估参与者或策略网络做出的决策。

值网络充当批评者的角色,可以量化参与者所选择的行动的好坏。 值网络通过将状态值V(s, θ[v]与收到的奖励r和观察到的下一个状态γV(s', θ[v])的折扣值之和来评估状态值。 差异δ表示为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uJ17N2he-1681704311686)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_103.png)] (Equation 10.4.1)

为了简单起见,我们在中删除了rs的下标。“公式 10.4.1”类似于“第 9 章”,“深度强化学习”中讨论的 Q 学习中的时间差异。 下一个状态值被γ = [0.0, 1.0]折扣。估计遥远的未来奖励很困难。 因此,我们的估计仅基于近期r γV(s', θ[v])。 这就是自举技术。

自举技术和“公式 10.4.1”中状态表示的依赖性通常会加速学习并减少差异。 从“公式 10.4.1”,我们注意到值网络评估了当前状态s = s[t],这是由于策略网络的上一个操作a[t-1]。 同时,策略梯度基于当前动作a[t]。 从某种意义上说,评估延迟了一步。

“算法 10.4.1”总结了演员评论家方法[1]。 除了评估用于训练策略和值网络的状态值评估外,还可以在线进行训练。 在每个步骤中,两个网络都经过训练。 这与 REINFORCE 和带有基线的 REINFORCE 不同,在基线之前,智能体完成了一个剧集。 首先,在当前状态的值估计期间向值网络查询两次,其次,为下一个状态的值查询。 这两个值都用于梯度计算中。

算法 10.4.1 演员评论家

要求:可微分的参数化目标策略网络π(a | s, θ)

要求:可微分的参数化值网络V(s, θ[v])

要求:折扣因子γ ∈ [0, 1],表现梯度的学习率α和值梯度α[v]的学习率。

要求θ[0],初始策略网络参数(例如,θ[0] -> 0)。 θ[v0],初始值网络参数(例如θ[v0] -> 0)。

  1. 重复。
  2. 对于步骤t = 0, ..., T - 1,执行:
  3. 对动作a ~ π(a | s, θ)进行采样。
  4. 执行动作并观察奖励r和下一个状态s'
  5. 评估状态值估计: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YoPe1Tdq-1681704311686)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_121.png)]
  6. 计算折扣值梯度: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6GoyqBRO-1681704311686)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_122.png)]
  7. 执行梯度上升: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NzbAvYiT-1681704311686)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_092.png)]
  8. 计算折扣表现梯度: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7BYpD5re-1681704311687)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_124.png)]
  9. 执行梯度上升: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XP1R3ip4-1681704311687)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_064.png)]
  10. s = s'

“图 10.4.1”显示了演员评论家网络:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-55ELCxlu-1681704311687)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_03.png)]

图 10.4.1:演员评论家网络。 通过对值V'的第二次评估,演员评论家与 REINFORCE 的基线有所不同

“列表 10.4.1”显示了ActorCriticAgent类,该类在tf.keras中实现了“算法 10.4.1”。 与两种 REINFORCE 方法不同,演员评论家不等待剧集完成。 因此,它没有实现train_by_episode()。 在每个体验单元,通过调用各自模型的fit()方法,优化第 7 行和第 9 行中用于目标和值函数logp_modelvalue_model的网络。 delta变量存储第 5 行的结果。

“列表 10.4.1”:policygradient-car-10.1.1.py

代码语言:javascript复制
class ActorCriticAgent(PolicyAgent):
    def __init__(self, env):
        """Implements the models and training of 
           Actor Critic policy gradient method
        Arguments:
            env (Object): OpenAI gym environment
        """
        super().__init__(env) 
代码语言:javascript复制
 def train(self, item, gamma=1.0):
        """Main routine for training
        Arguments:
            item (list) : one experience unit
            gamma (float) : discount factor [0,1]
        """
        [step, state, next_state, reward, done] = item 
代码语言:javascript复制
 # must save state for entropy computation
        self.state = state 
代码语言:javascript复制
 discount_factor = gamma**step 
代码语言:javascript复制
 # actor-critic: delta = reward - value 
        #         discounted_next_value
        delta = reward - self.value(state)[0] 
代码语言:javascript复制
 # since this function is called by Actor-Critic
        # directly, evaluate the value function here
        if not done:
            next_value = self.value(next_state)[0]
            # add  the discounted next value
            delta  = gamma*next_value 
代码语言:javascript复制
 # apply the discount factor as shown in Algortihms
        # 10. 2.1, 10.3.1 and 10.4.1
        discounted_delta = delta * discount_factor
        discounted_delta = np.reshape(discounted_delta, [-1, 1])
        verbose = 1 if done else 0 
代码语言:javascript复制
 # train the logp model (implies training of actor model
        # as well) since they share exactly the same set of
        # parameters
        self.logp_model.fit(np.array(state),
                            discounted_delta,
                            batch_size=1,
                            epochs=1,
                            verbose=verbose) 

最终的策略梯度方法是 A2C。

5. 优势演员评论家(A2C)方法

在上一节的演员评论家方法中,目标是使的值函数正确评估状态值。 还有其他用于训练值网络的技术。 一种明显的方法是在值函数优化中使用均方误差MSE),类似于 Q 学习中的算法。 新值梯度等于返回值R[t]与状态值之间的 MSE 偏导数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8foHtRy4-1681704311687)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_127.png)] (Equation 10.5.1)

作为(R[t] - V(s, θ[v])) -> 0,值网络预测在预测给定状态的收益时变得更加准确。 我们将演员评论家算法的这种变化称为“优势演员评论家(A2C)”。 A2C 是[3]提出的“异步优势参与者关键(A3C)”的单线程或同步版本。 数量R[t] - V(s, θ[v])被称为优势

“算法 10.5.1”总结了 A2C 方法。 A2C 和演员评论家之间存在一些差异。演员评论家在线上或根据经验样本进行训练。 A2C 类似于带基线的蒙特卡洛算法,REINFORCE 和 REINFORCE。 一集完成后,将对其进行训练。 从第一个状态到最后一个状态都对演员评论家进行了训练。 A2C 训练从最后一个状态开始,并在第一个状态结束。 此外,γ^t不再打折 A2C 策略和值梯度。

A2C 的相应网络类似于“图 10.4.1”,因为我们仅更改了梯度计算方法。 为了鼓励训练过程中的探员探索,A3C 算法[3]建议将策略函数的加权熵值的梯度添加到到梯度函数β ᐁ[θ] H(π(a[t] | s[t], θ))中。 回想一下,熵是对信息或事件不确定性的度量。

算法 10.5.1 优势演员评论家(A2C)

要求:可微分的参数化目标策略网络π(a[t] | s[t], θ)

要求:可微分的参数化值网络V(s[t], θ[v])

要求:折扣因子γ ∈ [0, 1],表现梯度的学习率α,值梯度的学习率α[v]和熵权β

要求θ[0],初始策略网络参数(例如,θ[0] -> 0)。 θ[v0],初始值网络参数(例如θ[v0] -> 0)。

  1. 重复。
  2. 通过跟随π(a[t] | s[t], θ)来生成剧集(s[0]a[0]r[1]s[1], s[1]a[1]r[2]s[2], ..., a[T-1]a[T-1]r[T]s[T])
  3. [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CLLHaOOX-1681704311688)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_144.png)]
  4. 对于步骤t = 0, ..., T - 1,执行:
  5. 计算返回值: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pp9FSndT-1681704311688)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_146.png)]
  6. 计算值梯度: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DmtWxs3c-1681704311688)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_147.png)]
  7. 累积梯度: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Od7LJ3eu-1681704311688)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_092.png)]
  8. 计算表现梯度: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ks5H3AF3-1681704311688)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_149.png)]
  9. 执行梯度上升: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AEqrkTvV-1681704311688)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_064.png)]

“列表 10.5.1”显示了A2CAgent类,该类在tf.keras中实现了“算法 10.5.1”。 与两个 REINFORCE 方法不同,返回值是从最后一个体验单元或状态到第一个体验单元或状态的计算得出的。 在每个体验单元,通过调用各自模型的fit()方法,优化第 7 行和第 9 行中用于目标和值函数logp_modelvalue_model的网络。 注意,在对象实例化期间,熵损失的beta或权重设置为0.9,以指示将使用熵损失函数。 此外,使用 MSE 损失函数训练value_model

“列表 10.5.1”:policygradient-car-10.1.1.py

代码语言:javascript复制
class A2CAgent(PolicyAgent):
    def __init__(self, env):
        """Implements the models and training of 
           A2C policy gradient method
        Arguments:
            env (Object): OpenAI gym environment
        """
        super().__init__(env)
        # beta of entropy used in A2C
        self.beta = 0.9
        # loss function of A2C value_model is mse
        self.loss = 'mse' 
代码语言:javascript复制
 def train_by_episode(self, last_value=0):
        """Train by episode 
           Prepare the dataset before the step by step training
        Arguments:
            last_value (float): previous prediction of value net
        """
        # implements A2C training from the last state
        # to the first state
        # discount factor
        gamma = 0.95
        r = last_value
        # the memory is visited in reverse as shown
        # in Algorithm 10.5.1
        for item in self.memory[::-1]:
            [step, state, next_state, reward, done] = item
            # compute the return
            r = reward   gamma*r
            item = [step, state, next_state, r, done]
            # train per step
            # a2c reward has been discounted
            self.train(item) 
代码语言:javascript复制
 def train(self, item, gamma=1.0):
        """Main routine for training 
        Arguments:
            item (list) : one experience unit
            gamma (float) : discount factor [0,1]
        """
        [step, state, next_state, reward, done] = item 
代码语言:javascript复制
 # must save state for entropy computation
        self.state = state 
代码语言:javascript复制
 discount_factor = gamma**step 
代码语言:javascript复制
 # a2c: delta = discounted_reward - value
        delta = reward - self.value(state)[0] 
代码语言:javascript复制
 verbose = 1 if done else 0 
代码语言:javascript复制
 # train the logp model (implies training of actor model
        # as well) since they share exactly the same set of
        # parameters
        self.logp_model.fit(np.array(state),
                            discounted_delta,
                            batch_size=1,
                            epochs=1,
                            verbose=verbose) 
代码语言:javascript复制
 # in A2C, the target value is the return (reward
        # replaced by return in the train_by_episode function)
        discounted_delta = reward
        discounted_delta = np.reshape(discounted_delta, [-1, 1]) 
代码语言:javascript复制
 # train the value network (critic)
        self.value_model.fit(np.array(state),
                             discounted_delta,
                             batch_size=1,
                             epochs=1,
                             verbose=verbose) 

在介绍的四种算法中,它们仅在目标函数和值(如果适用)优化方面有所不同。 在下一节中,我们将介绍四种算法的统一代码。

6. 使用 Keras 的策略梯度方法

上一节中讨论的策略梯度方法(“算法 10.2.1”至“算法 10.5.1”)使用相同的策略和值网络模型。“图 10.2.1”至“图 10.4.1”中的策略和值网络具有相同的配置。 四种策略梯度方法的不同之处仅在于:

  • 表现和值梯度公式
  • 训练策略

在本节中,我们将以一个代码讨论tf.keras算法 10.2.1 至“算法 10.5.1”的通用例程在tf.keras中的实现。

完整的代码可以在这个页面中找到。

但是在讨论实现之前,让我们简要探讨训练环境。

与 Q 学习不同,策略梯度方法适用于离散和连续动作空间。 在我们的示例中,我们将在连续动作空间案例示例中演示四种策略梯度方法,例如 OpenAI 健身房的MountainCarContinuous-v0。 如果您不熟悉 OpenAI Gym,请参阅“第 9 章”,“深度强化学习”。

“图 10.6.1”中显示了MountainCarContinuous-v0二维环境的快照。在此二维环境中,一辆功率不太强的汽车停在两座山之间:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NVQ3m8db-1681704311689)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_04.png)]

图 10.6.1:MountainCarContinuous-v0 OpenAI Gym 环境

为了到达右侧山顶的黄旗,它必须来回行驶以获得足够的动力。 应用于汽车的能量越多(即动作的绝对值越大),则奖励越小(或负作用越大)。

奖励始终为负,到达标志时仅为正。 在这种情况下,汽车将获得 100 的奖励。 但是,每个操作都会受到以下代码的惩罚:

代码语言:javascript复制
reward-= math.pow(action[0],2)*0.1 

有效动作值的连续范围是[-1.0, 1.0]。 超出范围时,动作将被剪裁为其最小值或最大值。 因此,应用大于 1.0 或小于 -1.0 的操作值是没有意义的。

MountainCarContinuous-v0环境状态包含两个元素:

  • 车厢位置
  • 车速

通过编码器将状态转换为状态特征。 像动作空间一样,状态空间也是连续的。 预测的动作是给定状态的策略模型的输出。 值函数的输出是状态的预测值。

如图“图 10.2.1”到“图 10.4.1”所示,在建立策略和值网络之前,我们必须首先创建一个将状态转换为特征的函数。 该函数由自编码器的编码器实现,类似于在“第 3 章”,“自编码器”中实现的编码器。

“图 10.6.2”显示了包括编码器和解码器的自编码器:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sZUVJJBz-1681704311689)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_05.png)]

图 10.6.2:自编码器模型

在“图 10.6.3”中,编码器是由Input(2)-Dense(256, activation='relu')-Dense(128, activation='relu')-Dense(32)制成的 MLP。 每个状态都转换为 32 维特征向量:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8UoRlEjm-1681704311689)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_06.png)]

图 10.6.3:编码器模型

在“图 10.6.4”中,解码器也是 MLP,但由Input(32)-Dense(128, activation='relu')-Dense(256, activation='relu')-Dense(2)制成:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ny7SQcGS-1681704311689)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_07.png)]

图 10.6.4:解码器模型

自编码器使用 MSE,损失函数和tf.keras默认的 Adam 优化器训练了 10 个周期。 我们为训练和测试数据集采样了 220,000 个随机状态,并应用了 200,000:20,000 个训练测试拆分。 训练后,将保存编码器权重,以备将来在策略和值网络的训练中使用。“列表 10.6.1”显示了构建和训练自编码器的方法。

tf.keras实现中,除非另有说明,否则我们将在本节中提及的所有例程均作为PolicyAgent类中的方法实现。 PolicyAgent的作用是代表策略梯度方法的常用功能,包括建立和训练自编码器网络模型以及预测动作,对数概率,熵和状态值。 这是“列表 10.2.1”至“列表 10.5.1”中介绍的每个策略梯度方法智能体类的超类。

“列表 10.6.1”:policygradient-car-10.1.1.py

构建和训练特征自编码器的方法:

代码语言:javascript复制
class PolicyAgent:
    def __init__(self, env):
        """Implements the models and training of 
            Policy Gradient Methods
        Argument:
            env (Object): OpenAI gym environment
        """ 
代码语言:javascript复制
 self.env = env
        # entropy loss weight
        self.beta = 0.0
        # value loss for all policy gradients except A2C
        self.loss = self.value_loss 
代码语言:javascript复制
 # s,a,r,s' are stored in memory
        self.memory = [] 
代码语言:javascript复制
 # for computation of input size
        self.state = env.reset()
        self.state_dim = env.observation_space.shape[0]
        self.state = np.reshape(self.state, [1, self.state_dim])
        self.build_autoencoder() 
代码语言:javascript复制
 def build_autoencoder(self):
        """autoencoder to convert states into features
        """
        # first build the encoder model
        inputs = Input(shape=(self.state_dim, ), name='state')
        feature_size = 32
        x = Dense(256, activation='relu')(inputs)
        x = Dense(128, activation='relu')(x)
        feature = Dense(feature_size, name='feature_vector')(x) 
代码语言:javascript复制
 # instantiate encoder model
        self.encoder = Model(inputs, feature, name='encoder')
        self.encoder.summary()
        plot_model(self.encoder,
                   to_file='encoder.png',
                   show_shapes=True) 
代码语言:javascript复制
 # build the decoder model
        feature_inputs = Input(shape=(feature_size,),
                               name='decoder_input')
        x = Dense(128, activation='relu')(feature_inputs)
        x = Dense(256, activation='relu')(x)
        outputs = Dense(self.state_dim, activation='linear')(x) 
代码语言:javascript复制
 # instantiate decoder model
        self.decoder = Model(feature_inputs,
                             outputs,
                             name='decoder')
        self.decoder.summary()
        plot_model(self.decoder,
                   to_file='decoder.png',
                   show_shapes=True) 
代码语言:javascript复制
 # autoencoder = encoder   decoder
        # instantiate autoencoder model
        self.autoencoder = Model(inputs,
                                 self.decoder(self.encoder(inputs)),
                                 name='autoencoder')
        self.autoencoder.summary()
        plot_model(self.autoencoder,
                   to_file='autoencoder.png',
                   show_shapes=True) 
代码语言:javascript复制
 # Mean Square Error (MSE) loss function, Adam optimizer
        self.autoencoder.compile(loss='mse', optimizer='adam') 
代码语言:javascript复制
 def train_autoencoder(self, x_train, x_test):
        """Training the autoencoder using randomly sampled
            states from the environment
        Arguments:
            x_train (tensor): autoencoder train dataset
            x_test (tensor): autoencoder test dataset
        """ 
代码语言:javascript复制
 # train the autoencoder
        batch_size = 32
        self.autoencoder.fit(x_train,
                             x_train,
                             validation_data=(x_test, x_test),
                             epochs=10,
                             batch_size=batch_size) 

在给定MountainCarContinuous-v0环境的情况下,策略(或参与者)模型会预测必须应用于汽车的操作。 如本章第一部分中有关策略梯度方法的讨论所述,对于连续动作空间,策略模型从高斯分布π(a[t] | s[t], θ) = a[t] ~ N(μ(s[t]), σ²(s[t]))中采样一个动作。 在tf. keras中,实现为:

代码语言:javascript复制
import tensorflow_probability as tfp
    def action(self, args):
        """Given mean and stddev, sample an action, clip 
            and return
            We assume Gaussian distribution of probability 
            of selecting an action given a state
        Arguments:
            args (list) : mean, stddev list
        """
        mean, stddev = args
        dist = tfp.distributions.Normal(loc=mean, scale=stddev)
        action = dist.sample(1)
        action = K.clip(action,
                        self.env.action_space.low[0],
                        self.env.action_space.high[0])
        return action 

动作被限制在其最小和最大可能值之间。 在这种方法中,我们使用TensorFlow probability包。 可以通过以下方式单独安装:

代码语言:javascript复制
pip3 install --upgrade tensorflow-probability 

策略网络的作用是预测高斯分布的均值和标准差。“图 10.6.5”显示了为π(a[t] | s[t], θ)建模的策略网络。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AXHFTr1l-1681704311690)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_08.png)]

图 10.6.5:策略模型(参与者模型)

请注意,编码器模型具有冻结的预训练权重。 仅平均值和标准差权重会收到表现梯度更新。 策略网络基本上是“公式 10.1.4”和“公式 10.1.5”的实现,为方便起见在此重复:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1IWoPJZK-1681704311690)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_153.png)] (Equation 10.1.4)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2fNGss3B-1681704311690)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_036.png)] (Equation 10.1.5)

其中φ(s[t])是编码器,θ[μ]是平均值Dense(1)层的权重,θ[σ]是标准差Dense(1)层的权重。 我们使用修改后的softplus函数ζ(·)来避免标准差为零:

代码语言:javascript复制
def softplusk(x):
    """Some implementations use a modified softplus 
        to ensure that the stddev is never zero
    Argument:
        x (tensor): activation input
    """
    return K.softplus(x)   1e-10 

策略模型构建器显示在“列表 10.6.2”中。 对数概率,熵和值模型也包含在此清单中,我们将在下面讨论。

“列表 10.6.2”:policygradient-car-10.1.1.py

根据编码后的状态特征构建策略(角色),logp,熵和值模型的方法:

代码语言:javascript复制
 def build_actor_critic(self):
        """4 models are built but 3 models share the
            same parameters. hence training one, trains the rest.
            The 3 models that share the same parameters 
                are action, logp, and entropy models. 
            Entropy model is used by A2C only.
            Each model has the same MLP structure:
            Input(2)-Encoder-Output(1).
            The output activation depends on the nature 
                of the output.
        """
        inputs = Input(shape=(self.state_dim, ), name='state')
        self.encoder.trainable = False
        x = self.encoder(inputs)
        mean = Dense(1,
                     activation='linear',
                     kernel_initializer='zero',
                     name='mean')(x)
        stddev = Dense(1,
                       kernel_initializer='zero',
                       name='stddev')(x)
        # use of softplusk avoids stddev = 0
        stddev = Activation('softplusk', name='softplus')(stddev)
        action = Lambda(self.action,
                        output_shape=(1,),
                        name='action')([mean, stddev])
        self.actor_model = Model(inputs, action, name='action')
        self.actor_model.summary()
        plot_model(self.actor_model,
                   to_file='actor_model.png',
                   show_shapes=True) 
代码语言:javascript复制
 logp = Lambda(self.logp,
                      output_shape=(1,),
                      name='logp')([mean, stddev, action])
        self.logp_model = Model(inputs, logp, name='logp')
        self.logp_model.summary()
        plot_model(self.logp_model,
                   to_file='logp_model.png',
                   show_shapes=True) 
代码语言:javascript复制
 entropy = Lambda(self.entropy,
                         output_shape=(1,),
                         name='entropy')([mean, stddev])
        self.entropy_model = Model(inputs, entropy, name='entropy')
        self.entropy_model.summary()
        plot_model(self.entropy_model,
                   to_file='entropy_model.png',
                   show_shapes=True) 
代码语言:javascript复制
 value = Dense(1,
                      activation='linear',
                      kernel_initializer='zero',
                      name='value')(x)
        self.value_model = Model(inputs, value, name='value')
        self.value_model.summary()
        plot_model(self.value_model,
                   to_file='value_model.png',
                   show_shapes=True) 
代码语言:javascript复制
 # logp loss of policy network
        loss = self.logp_loss(self.get_entropy(self.state),
                              beta=self.beta)
        optimizer = RMSprop(lr=1e-3)
        self.logp_model.compile(loss=loss, optimizer=optimizer) 
代码语言:javascript复制
 optimizer = Adam(lr=1e-3)
        self.value_model.compile(loss=self.loss, optimizer=optimizer) 

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-iwHwwjug-1681704311690)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_09.png)]

图 10.6.6:策略的高斯对数概率模型

除了策略网络π(a[t] | s[t], θ)之外,我们还必须具有操作日志概率(logp)网络ln π(a[t] | s[t], θ),因为该实际上是计算梯度的系统。 如图“图 10.6.6”所示,logp网络只是一个策略网络,其中附加的 Lambda(1)层在给定了作用,均值和标准差的情况下计算了高斯分布的对数概率。

logp网络和参与者(策略)模型共享同一组参数。 Lambda 层没有任何参数。 它是通过以下函数实现的:

代码语言:javascript复制
 def logp(self, args):
        """Given mean, stddev, and action compute
            the log probability of the Gaussian distribution
        Arguments:
            args (list) : mean, stddev action, list
        """
        mean, stddev, action = args
        dist = tfp.distributions.Normal(loc=mean, scale=stddev)
        logp = dist.log_prob(action)
        return logp 

训练logp网络也可以训练角色模型。 在本节中讨论的训练方法中,仅训练logp网络。

如图“图 10.6.7”所示,熵模型还与策略网络共享参数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3Pha36cK-1681704311690)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_10.png)]

图 10.6.7:熵模型

给定平均值和标准差,使用以下函数,输出Lambda(1)层计算高斯分布的熵:

代码语言:javascript复制
 def entropy(self, args):
        """Given the mean and stddev compute 
            the Gaussian dist entropy
        Arguments:
            args (list) : mean, stddev list
        """
        mean, stddev = args
        dist = tfp.distributions.Normal(loc=mean, scale=stddev)
        entropy = dist.entropy()
        return entropy 

熵模型仅用于 A2C 方法。

“图 10.6.8”显示了值模型:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1Svlkq35-1681704311691)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_11.png)]

图 10.6.8:值模型

该模型还使用具有权重的预训练编码器来实现以下公式“公式 10.3.2”,为方便起见,在此重复:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BWxI2Lq0-1681704311691)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_161.png)] (Equation 10.3.2)

θ[v]Dense(1)层的权重,该层是唯一接收值梯度更新的层。“图 10.6.8”表示“算法 10.3.1”至“算法 10.5.1”中的V(s[t], θ[v])。 值模型可以建立在以下几行中:

代码语言:javascript复制
inputs = Input(shape=(self.state_dim, ), name='state')
self.encoder.trainable = False
x = self.encoder(inputs)
value = Dense(1,
              activation='linear',
              kernel_initializer='zero',
              name='value')(x)
self.value_model = Model(inputs, value, name='value') 

这些行也用build_actor_critic()方法实现,如清单 10.6.2 所示。

建立网络模型后,下一步就是训练。 在“算法 10.2.1”至“算法 10.5.1”中,我们通过梯度上升执行目标函数最大化。 在tf.keras中,我们通过梯度下降执行损失函数最小化。 损失函数只是目标函数最大化的负数。 梯度下降是梯度上升的负值。“列表 10.6.3”显示了logp和值损失函数。

我们可以利用损失函数的通用结构来统一“算法 10.2.1”至“算法 10.5.1”中的损失函数。 表现和值梯度仅在其恒定因子上有所不同。 所有表现梯度都有一个通用项ᐁ[θ] ln π(a[t] | s[t], θ)。 这由策略日志概率损失函数logp_loss()中的y_pred表示。 通用项ᐁ[θ] ln π(a[t] | s[t], θ)的因素取决于哪种算法,并实现为y_true。“表 10.6.1”显示y_true的值。 其余项是熵的加权梯度β ᐁ[θ] H(π(a[t] | s[t], θ))。 这是通过logp_loss()函数中betaentropy的乘积实现的。 仅 A2C 使用此项,因此默认为self.beta=0.0。 对于 A2C,self.beta=0.9

“列表 10.6.3”:policygradient-car-10.1.1.py

logp和值网络的损失函数:

代码语言:javascript复制
 def logp_loss(self, entropy, beta=0.0):
        """logp loss, the 3rd and 4th variables 
            (entropy and beta) are needed by A2C 
            so we have a different loss function structure
        Arguments:
            entropy (tensor): Entropy loss
            beta (float): Entropy loss weight
        """
        def loss(y_true, y_pred):
            return -K.mean((y_pred * y_true) 
                      (beta * entropy), axis=-1) 
代码语言:javascript复制
 return loss 
代码语言:javascript复制
 def value_loss(self, y_true, y_pred):
        """Typical loss function structure that accepts 
            2 arguments only
           this will be used by value loss of all methods 
            except A2C
        Arguments:
            y_true (tensor): value ground truth
            y_pred (tensor): value prediction
        """
        return -K.mean(y_pred * y_true, axis=-1) 

算法

logp_loss的y_true

value_loss的y_true

10.2.1 REINFORCE

γ^t R[t]

不适用

10.3.1 使用基线的 REINFORCE

γ^t δ

γ^t δ

10.4.1 演员评论家

γ^t δ

γ^t δ

10.5.1 A2C

R[t] - V(s, θ[v])

R[t]

表 10.6.1:logp_lossy_true值和value_loss

“表 10.6.2”中显示了用于计算“表 10.6.1”中的y_true的代码实现:

算法

y_true公式

Keras 中的y_true

10.2.1 REINFORCE

γ^t R[t]

reward * discount_factor

10.3.1 使用基线的 REINFORCE

γ^t δ

(reward - self.value(state)[0]) * discount_factor

10.4.1 演员评论家

γ^t δ

(reward - self.value(state)[0] gamma * next_value) * discount_factor

10.5.1 A2C

R[t] - V(s, θ[v])和R[t]

(reward - self.value(state)[0])和reward

表 10.6.2:表 10.6.1 中的y_true

类似地,“算法 10.3.1”和“算法 10.4.1”的值损失函数具有相同的结构。 值损失函数在tf.keras中实现为value_loss(),如“列表 10.6.3”所示。 公共梯度因子ᐁ[θ[v]] V(s[t], θ[v])由张量y_pred表示。 剩余因子由y_true表示。 y_true值也显示在“表 10.6.1”中。 REINFORCE 不使用值函数。 A2C 使用 MSE 损失函数来学习值函数。 在 A2C 中,y_true代表目标值或基本情况。

有了所有网络模型和损失函数,最后一部分是训练策略,每种算法都不同。 每个策略梯度方法的训练算法已在“列表 10.2.1”至“列表 10.5.1”中进行了讨论。 “算法 10.2.1”,“算法 10.3.1”和“算法 10.5.1”等待完整的剧集在训练之前完成,因此它同时运行train_by_episode()train()。 完整剧集保存在self.memory中。 演员评论家“算法 10.4.1”每步训练一次,仅运行train()

“列表 10.6.4”显示了当智能体执行并训练策略和值模型时,一个剧集如何展开。 for循环执行 1,000 集。 当达到 1,000 步或汽车触及旗帜时,剧集终止。 智能体在每个步骤执行策略预测的操作。 在每个剧集或步骤之后,将调用训练例程。

“列表 10.6.4”:policygradient-car-10.1.1.py

代码语言:javascript复制
 # sampling and fitting
    for episode in range(episode_count):
        state = env.reset()
        # state is car [position, speed]
        state = np.reshape(state, [1, state_dim])
        # reset all variables and memory before the start of
        # every episode
        step = 0
        total_reward = 0
        done = False
        agent.reset_memory()
        while not done:
            # [min, max] action = [-1.0, 1.0]
            # for baseline, random choice of action will not move
            # the car pass the flag pole
            if args.random:
                action = env.action_space.sample()
            else:
                action = agent.act(state)
            env.render()
            # after executing the action, get s', r, done
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, state_dim])
            # save the experience unit in memory for training
            # Actor-Critic does not need this but we keep it anyway.
            item = [step, state, next_state, reward, done]
            agent.remember(item) 
代码语言:javascript复制
 if args.actor_critic and train:
                # only actor-critic performs online training
                # train at every step as it happens
                agent.train(item, gamma=0.99)
            elif not args.random and done and train:
                # for REINFORCE, REINFORCE with baseline, and A2C
                # we wait for the completion of the episode before 
                # training the network(s)
                # last value as used by A2C
                if args.a2c:
                    v = 0 if reward > 0 else agent.value(next_state)[0]
                    agent.train_by_episode(last_value=v)
                else:
                    agent.train_by_episode() 
代码语言:javascript复制
 # accumulate reward
            total_reward  = reward
            # next state is the new state
            state = next_state
            step  = 1 

在训练期间,我们收集了数据以确定每个策略梯度算法的表现。 在下一部分中,我们总结了结果。

7. 策略梯度方法的表现评估

通过训练智能体 1000 次剧集,评估了 4 种策略梯度方法。 我们将 1 次训练定义为 1,000 次训练。 第一表现度量标准是通过累计汽车在 1,000 集内达到标志的次数来衡量的。

在此指标中,A2C 达到该标志的次数最多,其次是 REINFORCE(具有基线,演员评论家和 REINFORCE)。 使用基线或批判者可以加速学习。 请注意,这些是训练会话,智能体会在其中不断提高其表现。 在实验中,有些情况下智能体的表现没有随时间改善。

第二个表现指标基于以下要求:如果每集的总奖励至少为 90.0,则认为MountainCarContinuous-v0已解决。 从每种方法的 5 个训练会话中,我们选择了最近 100 个剧集(第 900 至 999 集)中最高总奖励的 1 个训练会话。

“图 10.7.1”至“图 10.7.4”显示了在执行 1000 集时山地车到达标志的次数。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MGmtbabO-1681704311691)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_12.png)]

图 10.7.1:山车使用 REINFORCE 方法到达标志的次数

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2wkUzdYi-1681704311691)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_13.png)]

图 10.7.2:使用基线方法使用 REINFORCE,山地车到达标志的次数

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DK0rQAdO-1681704311691)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_14.png)]

图 10.7.3:使用演员评论家方法山地车到达旗帜的次数

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-q23jTJPx-1681704311692)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_15.png)]

图 10.7.4:山地车使用 A2C 方法到达标志的次数

“图 10.7.5”至“图 10.7.8”显示 1,000 集的总奖励。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vX9jW3k6-1681704311692)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_16.png)]

图 10.7.5:使用 REINFORCE 方法获得的每集总奖励

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JS7kPuIl-1681704311692)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_17.png)]

图 10.7.6:使用带有基线方法的 REINFORCE,每集获得的总奖励。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lvVyIjBL-1681704311692)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_18.png)]

图 10.7.7:使用演员评论家方法获得的每集总奖励

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-t2gX3Y4m-1681704311693)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_10_19.png)]

图 10.7.8:使用 A2C 方法获得的每集总奖励

以为基准的 REINFORCE 是唯一能够在 1,000 次训练中始终获得约 90 的总奖励的方法。 A2C 的表现仅次于第二名,但无法始终达到至少 90 分的总奖励。

在进行的实验中,我们使用相同的学习率1e-3进行对数概率和值网络优化。 折扣系数设置为 0.99(A2C 除外),以 0.95 的折扣系数更容易训练。

鼓励阅读器通过执行以下操作来运行受过训练的网络:

代码语言:javascript复制
python3 policygradient-car-10.1.1.py
--encoder_weights=encoder_weights.h5 --actor_weights=actor_weights.h5 

“表 10.7.1”显示了其他运行policygradient-car-10.1.1.py的模式。 权重文件(即*.h5)可以替换为您自己的预训练权重文件。 请查阅代码以查看其他可能的选项。

目的

运行

从零开始训练 REINFORCE

python3 policygradient-car-10.1.1.py

从头开始使用基线训练 REINFORCE

python3 policygradient-car-10.1.1.py -b

从零开始训练演员评论家

python3 policygradient-car-10.1.1.py -a

从头开始训练 A2C

python3 policygradient-car-10.1.1.py -c

从先前保存的权重中训练 REINFORCE

python3 policygradient-car-10.1.1.py``--encoder-weights=encoder_weights.h5``--actor-weights=actor_weights.h5 --train

使用先前保存的权重使用基线训练 REINFORCE

python3 policygradient-car-10.1.1.py``--encoder-weights=encoder_weights.h5``--actor-weights=actor_weights.h5``--value-weights=value_weights.h5 -b --train

使用先前保存的权重训练演员评论家

python3 policygradient-car-10.1.1.py``--encoder-weights=encoder_weights.h5``--actor-weights=actor_weights.h5``--value-weights=value_weights.h5 -a --train

使用先前保存的权重训练 A2C

python3 policygradient-car-10.1.1.py``--encoder-weights=encoder_weights.h5``--actor-weights=actor_weights.h5``--value-weights=value_weights.h5 -c --train

表 10.7.1:运行 policygradient-car-10.1.1.py 时的不同选项

最后一点,我们在tf.keras中对策略梯度方法的实现存在一些局限性。 例如,训练演员模型需要对动作进行重新采样。 首先对动作进行采样并将其应用于环境,以观察奖励和下一个状态。 然后,采取另一个样本来训练对数概率模型。 第二个样本不一定与第一个样本相同,但是用于训练的奖励来自第一个采样动作,这可能会在梯度计算中引入随机误差。

8. 总结

在本章中,我们介绍了策略梯度方法。 从策略梯度定理开始,我们制定了四种方法来训练策略网络。 详细讨论了四种方法:REINFORCE,带有基线的 REINFORCE,演员评论家和 A2C 算法。 我们探讨了如何在 Keras 中实现这四种方法。 然后,我们通过检查智能体成功达到目标的次数以及每集获得的总奖励来验证算法。

与上一章中讨论的深度 Q 网络[2]相似,基本策略梯度算法可以进行一些改进。 例如,最突出的一个是 A3C [3],它是 A2C 的多线程版本。 这使智能体可以同时接触不同的经验,并异步优化策略和值网络。 但是,在 OpenAI 进行的实验中,与 A2C 相比,A3C 没有强大的优势,因为前者无法利用当今提供强大的 GPU 的优势。

在接下来的两章中,我们将着手于另一个领域-对象检测和语义分割。 对象检测使智能体能够识别和定位给定图像中的对象。 语义分割基于对象类别识别给定图像中的像素区域。

9. 参考

  1. Richard Sutton and Andrew Barto: Reinforcement Learning: An Introduction: http://incompleteideas.net/book/bookdraft2017nov5.pdf (2017)
  2. Volodymyr Mnih et al.: Human-level control through deep reinforcement learning, Nature 518.7540 (2015): 529
  3. Volodymyr Mnih et al.: Asynchronous Methods for Deep Reinforcement Learning, International conference on machine learning, 2016
  4. Ronald Williams: Simple statistical gradient-following algorithms for connectionist reinforcement learning, Machine learning 8.3-4 (1992): 229-256

1 人点赞