图像重构是计算机视觉领域里一种经典的图像处理技术,而自编码器算法便是实现该技术的核心算法之一。在了解了自编码器的基本原理之后,本节就通过实例讲解如何利用Tensorflow2.X来一步步地搭建出一个自编码器并将其应用于MNIST手写图像数据的重构当中。
01 编译器模块搭建
在本节中,使用MNIST手写数据集来进行自编码器模型的训练。首先需要搭建的是编码器网络,如前面所述,它的作用是使网络中的输入数据不断地降维变成低维度的隐变量。
首先导入相关的第三方库:
代码语言:javascript复制import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
加载MNIST数据集并对其进行预处理:
代码语言:javascript复制(x_train,y_train),(x_test,y_test) = tf.keras.datasets.mnist.load_data() #加载MNIST数据集
x_train = x_train.reshape(-1,784).astype('float32')/255 #训练集图像打平并归一化
x_test = x_test.reshape(-1,784).astype('float32')/255
在这里使用3层全连接层作为编码器的网络结构,即输入维度为784的图像会不断地经过3层网络并降维变成512,256和60。其中每一层网络都使用ReLu作为激活函数并对神经元权重进行正态分布初始化:
代码语言:javascript复制# 编码器网络
Encoder = tf.keras.models.Sequential([
layers.Dense(512, kernel_initializer = 'normal', activation = 'relu'),
layers.Dense(256, kernel_initializer = 'normal', activation = 'relu'),
layers.Dense(60, kernel_initializer = 'normal', activation = 'relu')
])
02 解码器模块搭建
解码器网络实质上就是对编码器输出的隐变量进行一次次的上采样,最后输出再还原成和原输入数据相同维度的数据。在这里将之前得到的维度为60的数据再依次升维到256,512和784。同样地,对每一层网络的神经元权重进行正态分布初始化,并将最后一层激活函数换成Sigmoid函数以便于将输出转为像素值:
代码语言:javascript复制# 解码器网络
Decoder = tf.keras.models.Sequential([
layers.Dense(256, kernel_initializer = 'normal', activation = 'relu'),
layers.Dense(512, kernel_initializer = 'normal', activation = 'relu'),
layers.Dense(784, kernel_initializer = 'normal', activation = 'sigmoid')
])
03 自编码器模型
将上述的编码器和解码器进行结合便可得到完整的自编码器模型。整个自编码器网络结构如图1所示。
图1 自编码器网络
为了方便,可以将编码器和解码器代码封装成类,并将传播过程实现在call函数当中:
代码语言:javascript复制class Autoencoder(tf.keras.Model):
def __init__(self):
super(Autoencoder,self).__init__()
self.Encoder = tf.keras.models.Sequential([ #编码器网络
layers.Dense(512, kernel_initializer = 'normal', activation = 'relu'),
layers.Dense(256, kernel_initializer = 'normal', activation = 'relu'),
layers.Dense(60, kernel_initializer = 'normal', activation = 'relu')
])
self.Decoder = tf.keras.models.Sequential([ #解码器网络
layers.Dense(256, kernel_initializer = 'normal', activation = 'relu'),
layers.Dense(512, kernel_initializer = 'normal', activation = 'relu'),
layers.Dense(784, kernel_initializer = 'normal', activation = 'sigmoid')
])
def call(self,input_features,training = None): #前向传播
code = self.Encoder(input_features) #数据编码
reconstructed = self.Decoder(code) #数据解码
return reconstructed
搭建好自编码器的网络模型之后,下一步便是对该网络进行训练。在训练之前,首先需要配置训练过程所需的优化器及损失函数等参数。在这里选择了Adam优化器以及使用经典的二元交叉熵作为模型的损失函数:
代码语言:javascript复制model = Autoencoder()
model.compile(optimizer = 'adam', loss = 'binary_crossentropy')
配置好所需参数之后,可以正式开始训练已搭建好的模型。这里选用测试集的前4000张作为验证集,而其余作为测试集,由于自编码器模型为无监督训练模型,因此这里输入数据的标签等于输入自身:
代码语言:javascript复制model.fit(x_train,x_train, epochs = 10, batch_size = 256, shuffle = True, validation_data = (x
_test[:4000], x_test[:4000]))
输出的结果为:
代码语言:javascript复制Train on 60000 samples, validate on 4000 samples
……….
Epoch 5/10
60000/60000 [==============================] - 11s 184us/sample - loss: 0.0907 - val_loss: 0.0896
Epoch 6/10
60000/60000 [==============================] - 10s 170us/sample - loss: 0.0877 - val_loss: 0.0862
Epoch 7/10
60000/60000 [==============================] - 10s 171us/sample - loss: 0.0855 - val_loss: 0.0845
Epoch 8/10
60000/60000 [==============================] - 11s 175us/sample - loss: 0.0838 - val_loss: 0.0832
Epoch 9/10
60000/60000 [==============================] - 11s 190us/sample - loss: 0.0823 - val_loss: 0.0821
Epoch 10/10
60000/60000 [==============================] - 11s 182us/sample - loss: 0.0811 - val_loss: 0.0814
在使用训练集训练好模型之后,还需要对其进行进一步的测试。测试集的后6000张图像被应用于测试图像的重构效果,之后再构建可视化函数对其显示:
代码语言:javascript复制#模型测试
decoded_imgs = model.predict(x_test[4000:])
#原图像与重构后的图像对比
plt.figure(figsize = (20, 4))
n = 10
for i in range(n):
ax = plt.subplot(2, n, i 1)
plt.imshow(tf.reshape(x_test[4000 i], [28, 28]))
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2, n, i 1 n)
plt.imshow(decoded_imgs[i].reshape(28, 28))
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
图像重构的结果如图2所示,图中分别展示了训练1和10个epoch的模型测试效果:
图2 1,10个epoch训练模型的图像重构效果对比
本文选自水利水电出版社的《深度学习实战:基于TensorFlow2.X的计算机视觉开发应用 》一书,略有修改,经出版社授权刊登于此。