【连载12】带你看懂最早的卷积神经网络LeNet-5

2020-03-03 11:03:44 浏览数 (1)

LeNet 诞生于 1994 年,是最早的卷积神经网络之一,并且推动了深度学习领域的发展

LeNet-5一共有8层:1个输入层 3个卷积层(C1、C3、C5) 2个下采样层(S2、S4) 1个全连接层(F6) 1个输出层,每层有多个feature map(自动提取的多组特征)。

输入层

采用keras自带的MNIST数据集,输入像素矩阵为28×28的单通道图像数据。

C1卷积层

由6个feature map组成,每个feature map由5×5卷积核生成(feature map中每个神经元与输入层的5×5区域像素相连),考虑每个卷积核的bias,该层需要学习的参数个数为:(5×5 1)×6=156个,神经元连接数为:156×24×24=89856个。

S2下采样层

该层每个feature map一一对应上一层的feature map,由于每个单元的2×2感受野采用不重叠方式移动,所以会产生6个大小为12×12的下采样feature map,如果采用Max Pooling/Mean Pooling,则该层需要学习的参数个数为0个(如果采用非等权下采样——即采样核有权重,则该层需要学习的参数个数为:(2×2 1)×6=30个),神经元连接数为:30×12×12=4320个。

C3卷积层

这层略微复杂,S2神经元与C3是多对多的关系,比如最简单方式:用S2的所有feature map与C3的所有feature map做全连接(也可以对S2抽样几个feature map出来与C3某个feature map连接),这种全连接方式下:6个S2的feature map使用6个独立的5×5卷积核得到C3中1个feature map(生成每个feature map时对应一个bias),C3中共有16个feature map,所以该层需要学习的参数个数为:(5×5×6 1)×16=2416个,神经元连接数为:2416×8×8=154624个。

S4下采样层

同S2,如果采用Max Pooling/Mean Pooling,则该层需要学习的参数个数为0个,神经元连接数为:(2×2 1)×16×4×4=1280个。

C5卷积层

类似C3,用S4的所有feature map与C5的所有feature map做全连接,这种全连接方式下:16个S4的feature map使用16个独立的1×1卷积核得到C5中1个feature map(生成每个feature map时对应一个bias),C5中共有120个feature map,所以该层需要学习的参数个数为:(1×1×16 1)×120=2040个,神经元连接数为:2040个。

F6全连接层

将C5层展开得到4×4×120=1920个节点,并接一个全连接层,考虑bias,该层需要学习的参数和连接个数为:(1920 1)*84=161364个。

输出层

该问题是个10分类问题,所以有10个输出单元,通过softmax做概率归一化,每个分类的输出单元对应84个输入。

Minist(Modified NIST)数据集下使用LeNet-5的训练可视化:

可以看到其实全连接层之前的各层做的就是特征提取的事儿,且比较通用,对于标准化实物(人、车、花等等)可以复用,后面会单独介绍模型的fine-tuning。

LeNet-5代码实践

代码语言:javascript复制
			import copy
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.pyplot import plot,savefig
from keras.datasets import mnist, cifar10
from keras.models import Sequential, Graph
from keras.layers.core import Dense, Dropout, Activation, Flatten, Reshape
from keras.optimizers import SGD, RMSprop
from keras.utils import np_utils
from keras.regularizers import l2
from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D
from keras.callbacks import EarlyStopping
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
import tensorflow as tf
tf.python.control_flow_ops = tf
from PIL import Image
def build_LeNet5():
   model = Sequential()
   model.add(Convolution2D(6, 5, 5, border_mode='valid', input_shape = (28, 28, 1), dim_ordering='tf'))
   model.add(MaxPooling2D(pool_size=(2, 2)))
   model.add(Activation("relu"))
   model.add(Convolution2D(16, 5, 5, border_mode='valid'))
   model.add(MaxPooling2D(pool_size=(2, 2)))
   model.add(Activation("relu"))
   model.add(Convolution2D(120, 1, 1, border_mode='valid'))
   model.add(Flatten())
   model.add(Dense(84))
   model.add(Activation("sigmoid"))
   model.add(Dense(10))
   model.add(Activation('softmax'))
   return model
if __name__=="__main__":
   from keras.utils.visualize_util import plot
   model = build_LeNet5()
   model.summary()
   plot(model, to_file="LeNet-5.png", show_shapes=True)
   (X_train, y_train), (X_test, y_test) = mnist.load_data()
   X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32') / 255
   X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32') / 255
   Y_train = np_utils.to_categorical(y_train, 10)
   Y_test = np_utils.to_categorical(y_test, 10)
   # training
   model.compile(loss='categorical_crossentropy',
             optimizer='adadelta',
             metrics=['accuracy'])
   batch_size = 128
   nb_epoch = 1
   model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
             verbose=1, validation_data=(X_test, Y_test))
   score = model.evaluate(X_test, Y_test, verbose=0)
   print('Test score:', score[0])
   print('Test accuracy:', score[1])
   y_hat = model.predict_classes(X_test)
   test_wrong = [im for im in zip(X_test,y_hat,y_test) if im[1] != im[2]]
   plt.figure(figsize=(10, 10))
   for ind, val in enumerate(test_wrong[:100]):
       plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
       plt.subplot(10, 10, ind   1)
       im = 1 - val[0].reshape((28,28))
       plt.axis("off")
       plt.text(0, 0, val[2], fontsize=14, color='blue')
       plt.text(8, 0, val[1], fontsize=14, color='red')
       plt.imshow(im, cmap='gray')
   savefig('error.jpg')		

网络结构

错误分类可视化

0 人点赞