Keras图像数据预处理范例——Cifar2图片分类

2020-07-20 14:11:57 浏览数 (1)

本文将以Cifar2数据集为范例,介绍Keras对图片数据进行预处理并喂入神经网络模型的方法。

Cifar2数据集为Cifar10数据集的子集,只包括前两种类别airplane和automobile。

训练集有airplane和automobile图片各5000张,测试集有airplane和automobile图片各1000张。

我们将重点介绍Keras中可以对图片进行数据增强的ImageDataGenerator工具和对内存友好的训练方法fit_generator的使用。让我们出发吧!

一,准备数据

1,获取数据

公众号后台回复关键字:Cifar2,可以获得Cifar2数据集下载链接,数据大约10M,解压后约1.5G。

我们准备的Cifar2数据集的文件结构如下所示。

直观感受一下。

2,数据增强

利用keras中的图片数据预处理工具ImageDataGenerator我们可以轻松地对训练集图片数据设置旋转,翻转,缩放等数据增强。

代码语言:javascript复制
from keras.preprocessing.image import ImageDataGenerator

train_dir = 'cifar2_datasets/train'
test_dir = 'cifar2_datasets/test'

# 对训练集数据设置数据增强
train_datagen = ImageDataGenerator(
            rescale = 1./,
            rotation_range=,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest')

# 对测试集数据无需使用数据增强
test_datagen = ImageDataGenerator(rescale=1./)

数据增强相关参数说明:

  • rotation_range是角度值(在 0~180 范围内),表示图像随机旋转的角度范围。
  • width_shift 和 height_shift 是图像在水平或垂直方向上平移的范围(相对于总宽 度或总高度的比例)。
  • shear_range是随机错切变换的角度。
  • zoom_range是图像随机缩放的范围。
  • horizontal_flip 是随机将一半图像水平翻转。如果没有水平不对称的假设(比如真 实世界的图像),这种做法是有意义的。
  • fill_mode是用于填充新创建像素的方法,这些新像素可能来自于旋转或宽度/高度平移。

查看数据增强效果

代码语言:javascript复制
import os
from keras.preprocessing import image
from matplotlib import pyplot as plt 

%matplotlib inline
%config InlineBackend.figure_format = 'png'
fnames = [os.path.join('cifar2_datasets/train/0_airplane', fname) for 
          fname in os.listdir('cifar2_datasets/train/0_airplane')]

# 载入第3张图像
img_path = fnames[]
img = image.load_img(img_path, target_size=(, ))
x = image.img_to_array(img)
plt.figure(,figsize = (,))
plt.subplot(,,)
plt.imshow(image.array_to_img(x))
plt.title('original image')

# 数据增强后的图像
x = x.reshape((,)   x.shape)
i  = 
for batch in train_datagen.flow(x, batch_size=):
    plt.subplot(,,i )
    plt.imshow(image.array_to_img(batch[]))
    plt.title('after augumentation %d'%(i ))
    i = i    
    if i %  == :
        break
plt.show()

3,导入数据

使用ImageDataGenerator的flow_from_directory方法可以从文件夹中导入图片数据,转换成固定尺寸的张量,这个方法将得到一个可以读取图片数据的生成器generator。

代码语言:javascript复制
train_generator = train_datagen.flow_from_directory(
                    train_dir,
                    target_size=(, ),
                    batch_size=,
                    shuffle = True,
                    class_mode='binary')

test_generator = test_datagen.flow_from_directory(
                    test_dir,
                    target_size=(, ),
                    batch_size=,
                    shuffle = False,
                    class_mode='binary')

print(train_generator.class_indices)

二,构建模型

代码语言:javascript复制
from keras import models,layers,optimizers
from  keras import backend as K

K.clear_session()
model = models.Sequential()
model.add(layers.Flatten(input_shape = (,,)))
model.add(layers.Dense(, activation='relu'))
model.add(layers.Dense(, activation='sigmoid'))

model.compile(loss='binary_crossentropy',
            optimizer=optimizers.RMSprop(lr=1e-4),
            metrics=['acc'])

model.summary()

三,训练模型

代码语言:javascript复制
# 计算每轮次需要的步数 
import numpy as np 
train_steps_per_epoch  = np.ceil(/)
test_steps_per_epoch  = np.ceil(/)

# 使用内存友好的fit_generator方法进行训练
history = model.fit_generator(
        train_generator,
        steps_per_epoch = train_steps_per_epoch,
        epochs = ,
        validation_data= test_generator,
        validation_steps=test_steps_per_epoch,
        workers=, # 读取数据的进程数
        use_multiprocessing=False #linux上可使用多进程读取数据
        )

四,评估模型

五,使用模型

代码语言:javascript复制
from sklearn.metrics import roc_auc_score

test_datagen = ImageDataGenerator(rescale=1./)

# 注意,使用模型进行预测时要设置生成器shuffle = False
test_generator = test_datagen.flow_from_directory(
                 test_dir,
                 target_size=(, ),
                 batch_size=,
                 class_mode='binary',
                 shuffle = False)

# 计算auc
y_pred = model.predict_generator(test_generator,steps = len(test_generator))
y_pred = np.reshape(y_pred,(-1,))
y_true = np.concatenate([test_generator[i][] 
                         for i in range(len(test_generator))])
auc = roc_auc_score(y_true,y_pred)
print('test auc:',auc)

六,保存模型

代码语言:javascript复制
model.save('cifar2_model.h5')

0 人点赞