用于文本生成的GAN模型

2021-10-08 16:20:56 浏览数 (1)

生成对抗网络(GAN)包含两个部分:一个是生成器(generator),一个是判别模型(discriminator)。生成器的任务是生成看起来逼真与原始数据相似的样本。判别器的任务是判断生成模型生成的样本是真实的还是伪造的。换句话说,生成器要生成能骗过判别器的实例,而判别器要从真假混合的样本中揪出由生成器生成的伪造样本。生成器和判别器的训练过程是一个对抗博弈的过程,最后博弈的结果是在最理想的状态下,生成器可以生成足以“以假乱真”的样本。

一、什么是GAN

生成对抗网络(GAN)包含两个部分:一个是生成器(generator),一个是判别模型(discriminator)。生成器的任务是生成看起来逼真与原始数据相似的样本。判别器的任务是判断生成模型生成的样本是真实的还是伪造的。换句话说,生成器要生成能骗过判别器的实例,而判别器要从真假混合的样本中揪出由生成器生成的伪造样本。生成器和判别器的训练过程是一个对抗博弈的过程,最后博弈的结果是在最理想的状态下,生成器可以生成足以“以假乱真”的样本。

图1. GAN的基本结构

二、GAN在文本生成中遇到的困境

传统的GAN只适用于连续型数据的生成,对于离散型数据效果不佳。文本数据不同于图像数据,文本数据是典型的离散型数据。图像数据在计算机中被表示为矩阵,矩阵中的数值可微分并且直接反映出图像本身的属性,从图像矩阵到图像不需要采样;而文本数据在计算机中表示为one-hot编码的向量,这个向量中有n项是0,只有一项是1,这一项代表词库中某个词,我们在神经网络中操作时,最后得到的都是一个某个词向量每个维度的概率分布而非标准的one-hot编码的向量,只能将这个输出结果过渡到one-hot向量再从词库中查找对应的词,这个操作被称为采样。

神经网络的优化方法大多是基于梯度的,GAN在面对离散型数据时,判别器无法把梯度反向传播给生成器——判别器得到的是生成器采样后的结果,在判别器参数微调后,可能输出优化了一点点但还不足以改变采样的结果,例如生成器网络的最后结果为[0.1,0.1,0.8]经过采样输出的one-hot词向量为[0,0,1],而参数微调后生成器网络最后的结果变为了[0.1,0.2,0.7]经过采样输出的one-hot词向量依旧为[0,0,1],生成器便会再一次将相同答案输入给判别器,这样判别器给出的评价就会毫无意义,生成器的训练也会失去方向。

为了解决GAN在面对离散型数据无法将梯度反向传播给生成器的问题,人们提出了三种方案:1.判别器直接获取生成器的输出;2.使用Gumbel-softmax代替softmax;3.通过强化学习来绕过采样带来的问题。其中第一种方法虽然可以绕过采样操作,避免采样带来的梯度无法反传的问题,但生成的数据与真实数据差距太大,判别器可以很轻易地分辨出生成的数据与真实的数据,因为此时生成的数据是离散的向量,而真实数据是one-hot向量,判别器可以很容易分辨两者的差异,此时GAN是难以训练的。

三、几种用于生成文本的GAN模型

3.1 Seq-GAN

SeqGAN的核心思想是将GAN与强化学习的Policy Gradient算法结合到一起,出发点是意识到了标准的GAN在处理离散数据时会遇到的困难:生成器难以梯度更新,判别器难以评估非完整序列。对于生成器难以梯度更新问题,作者把整个GAN看作一个强化学习系统,用Policy Gradient算法更新Generator的参数;对于判别器难以评估非完整序列问题,作者借鉴了蒙特卡洛树搜索的思想,对任意时刻的非完整序列都可以进行评估。

图2. SeqGAN结构

SeqGAN结构如图2所示,已经存在的红色圆点称为现在的状态(state),要生成的下一个红色圆点称作动作(action),因为D需要对一个完整的序列评分,所以就是用MCTS(蒙特卡洛树搜索)将每一个动作的各种可能性补全,D对这些完整的序列产生reward,回传给G,通过增强学习更新G。这样就是用Reinforcement learning的方式,训练出一个可以产生下一个最优的action的生成网络。

3.2 LeakGAN

基于GAN生成文本的方法大多数场景是生成短文本,对于长文本来说还是存在很多挑战。先前的GAN中判别器的标量指导信号是稀疏的,只有在完整生成文本后才可用,缺少生成过程中的文本结构的中间信息。当生成的文本样本长度很长时效果不好。LeakGAN通过泄露判别器提取的特征作为引导信号,指导生成器更好地生成长文本。同时,借助分层强化学习从判别器向生成器提供更丰富的信息。

图3. LeakGAN结构

3.3 RelGAN

RelGAN由三个主要组件组成:基于关系记忆的生成器、Gumbel-Softmax用于离散数据上训练GAN、鉴别器中嵌入多个表示为生成器提供更多信息。在样品质量和多样性方面,RelGAN相比于其他GAN模型具有一定优势。并且,RelGAN可以通过单个可调参数控制样本质量和多样性之间的权衡。

图4. RelGAN生成器中的注意力机制

t时刻的记忆单元Mt和矩阵Wq相乘得到Q矩阵,Mt与t时刻的输入的词向量xt拼接后分别于WK、Wv相乘得到K矩阵和V矩阵,Q矩阵和K矩阵的转置相乘后经过sofmax函数得到注意力权重,再将注意力全中与V矩阵相乘得到更新后的记忆单元。

图5. RelGAN判别器

判别器结构如图5所示,为了从多方面捕获输入特征,词向量通过多个词向量表示层输入CNN网络,这样子就输出多个判别器损失,综合多个方面的判别器损失,得到最终的损失输出,这样子,可以从多个方面综合评估词向量的差异,提供多样性和更加丰富的信息指导判别器的训练。


参考文献

https://arxiv.org/abs/1609.05473

https://arxiv.org/abs/1709.08624v2

https://arxiv.org/abs/1908.07269


0 人点赞