VAE,即变分自编码器,是常见的生成模型其中一类。常见的生成模型类型还有GAN、flow、DDPM等。
前置知识
1. KL散度
KL散度可以衡量两个分布的相似程度,当KL散度为0时,代表两个分布完全相同。注意KL散度不是距离,因为 KL(p||q)不等于KL(q||p). KL散度的计算公式为:
$$ begin{align} KL(p||q)&=int p(x)logfrac{p(x)}{q(x)}dx \ &=sum_{i=1}^{N}p(x_i)log{frac{p(x_i)}{q(x_i)}} end{align} $$
高斯分布的KL散度(公式推导):
2. 重参数化技巧
由于直接从N(mu,sigma^2)的分布中采样对分布的参数是不可导的,因此先从N(0,1)采样出 z,再得到 sigma z mu,这样采样出来对 sigma 和 mu 就是可导的。
AE-VAE-CVAE
AE,即自动编码器,由编码器和解码器两部分组成,编码器将输入映射成一种“数值编码”,解码器将“数值编码”映射回图像。如果训练时没有将某个图片编码,那么我们就不太可能生成这个图片。因此,AE适合用于数据压缩和恢复,不太适合于数据生成。
VAE不将输入图片映射成“数值编码”,而将其映射为“分布”,VAE可以生成没有见过的图片。以下是AE和VAE的对比图:
VAE的结构图如下:
训练VAE时,损失函数包括两部分:
- 为了让输出和输入尽可能像,所以要让输出和输入的差距尽可能小,此部分用MSELoss来计算,即最小化 MSELoss。
- 训练过程中,如果仅仅使输入和输出的误差尽可能小,那么随着不断训练,会使得sigma趋近于0,这样就使得VAE越来越像AE,对数据产生了过拟合,编码的噪声也会消失,导致无法生成未见过的数据。因此为了解决这个问题,我们要对mu和sigma加以约束,使其构成的正态分布尽可能像标准正态分布,具体做法是计算N(mu,sigma^2)与N(0,1)之间的KL散度。
即 loss = MSE(X, X') KL(N(mu, sigma^2), N(0,1)),代码如下:
代码语言:javascript复制def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
recon_loss = F.mse_loss(recon, x, reduction="sum") # use "mean" may have a bad effect on gradients
kl_loss = -0.5 * (1 2*log_std - mu.pow(2) - torch.exp(2*log_std))
kl_loss = torch.sum(kl_loss)
loss = recon_loss kl_loss
return loss
两部分loss对隐变量z的生成的可视化效果如下图:
注:VAE的缺点是生成的图像不一定那么“真”,如果要使生成的数据“真”,则要用到GAN。
begin{align} p(x)&=int p_{theta}(x|z)p(z) dz \ p(x)&=int q_{phi}(z|x) frac{p_{theta}(x|z)p(z)}{q_{phi}(z|x)} dz \ log p(x) &= log E_{q_phi(z|x)}left[frac{p_{theta}(x|z)p(z)}{q_{phi}(z|x)}right] \ &geq E_{q_phi(z|x)}left[log frac{p_{theta}(x|z)p(z)}{q_{phi}(z|x)}right] end{align}
前面所说的AE适合数据压缩与还原,不适合生成未见过的数据。VAE适合生成未见过的数据,但不能控制生成内容。而CVAE(Conditional VAE)可以在生成数据时通过指定其标签来生成想生成的数据。CVAE的结构图如下所示:
整体结构和VAE差不多,区别是在将数据输入Encoder时把数据内容与其标签(label)合并(cat)一起输入,将编码(Z)输入Decoder时把编码内容与数据标签(label)合并(cat)一起输入。注意label并不参与Loss计算,CVAE的Loss和VAE的Loss计算方式相同。