tf2集成的keras非常好用,对一些简单的模型可以快速搭建,下面以经典mnist数据集为例,做一个demo,展示一些常用的方法
1 导入包并查看版本号
代码语言:javascript复制import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
print(module.__name__, module.__version__)
2 获取数据集并归一化
这里如果不做归一化模型会不收敛,用的sklearn的归一化
这里注意:
fit_transform
指的是训练数据用的归一化,会记录下均值和方差transform
指的是测试集和验证集用训练集保存下来的方差和均值来做归一化- 归一化时候要做除法运算,所以先用
astype(np.float32)
转换成浮点 - 接着归一化的时候需要二维的输入,这里是三维,所以用reshape:x_train: [None, 28, 28] -> [None, 784]
- 归一化完了之后要再变回来,所以再用一个reshape
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
# x = (x - u) / std
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
# x_train: [None, 28, 28] -> [None, 784]
x_train_scaled = scaler.fit_transform(
x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(
x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(
x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
代码语言:javascript复制(5000, 28, 28) (5000,)
(55000, 28, 28) (55000,)
(10000, 28, 28) (10000,)
3 构建模型
- 用
Sequential()
构建模型,有两种构建方法,一种被注释了。 - 由于输入的时候是28x28的图片,所以在输入层需要一个Flatten拉平
- loss使用的是
sparse_categorical_crossentropy
,他可以自动把类别变成one-hot形式的概率分布,如果标签已经是概率分布,那就用categorical_crossentropy
- 优化器还有adam之类的,直接给名字就行,具体见官方api
- metrics还有mes之类的,具体见官方api
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3' # 使用 GPU 3
# tf.keras.models.Sequential()
model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28, 28]))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
model.add(keras.layers.Dense(10, activation="softmax"))
# model = keras.models.Sequential([
# keras.layers.Flatten(input_shape=[28, 28]),
# keras.layers.Dense(300, activation='relu'),
# keras.layers.Dense(100, activation='relu'),
# keras.layers.Dense(10, activation='softmax')
# ])
# relu: y = max(0, x)
# softmax: 将向量变成概率分布. x = [x1, x2, x3],
# y = [e^x1/sum, e^x2/sum, e^x3/sum], sum = e^x1 e^x2 e^x3
# reason for sparse: y->index. y->one_hot->[]
model.compile(loss="sparse_categorical_crossentropy",
optimizer = "sgd",
metrics = ["acc"])
4 训练模型
- 注意2.0和2. 的输出日志有一点不同,2. 版本后默认batchsize是32
- 和sklearn很像,使用fit函数,返回一个history里面有相关历史信息
- callbacks是回调函数,有很多种,这里只举3个例子,剩下的可以看api。使用的时候在fit里面增加一个callbacks参数,并以list的形式传入
Tensorboard
需要一个目录ModelCheckpoint
需要保存的文件目录,后缀名是h5好像也可以说ckpt,h5便于移植caffe或keras。save_best_only
保存最好的模型,不加这个默认保存是最近的一个模型EarlyStopping
提前终止,patience是可以保持多看几步的耐心,具体见api;min_delta是停止的阈值。可以看出来我设置的是30epoch,在20epoch的时候就earlystopping了
# Tensorboard, earlystopping, ModelCheckpoint
logdir = './callbacks'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir,
"fashion_mnist_model.h5")
callbacks = [
keras.callbacks.TensorBoard(logdir),
keras.callbacks.ModelCheckpoint(output_model_file,
save_best_only = True),
keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
history = model.fit(x_train_scaled, y_train, epochs=30,
validation_data=(x_valid_scaled, y_valid),
callbacks = callbacks)
输出日志:
代码语言:javascript复制Epoch 1/30
1/1719 [..............................] - ETA: 0s - loss: 3.0171 - acc: 0.0000e 00WARNING:tensorflow:From /data1/home/zly/anaconda3/envs/tf2.3/lib/python3.7/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use tf.profiler.experimental.stop instead.
WARNING:tensorflow:Callbacks method on_train_batch_end is slow compared to the batch time (batch time: 0.0027s vs on_train_batch_end time: 0.0207s). Check your callbacks.
60/1719 [>.............................] - ETA: 4s - loss: 1.2842 - acc: 0.5719
1719/1719 [==============================] - 4s 2ms/step - loss: 0.5355 - acc: 0.8089 - val_loss: 0.4270 - val_acc: 0.8524
Epoch 2/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.3914 - acc: 0.8575 - val_loss: 0.3716 - val_acc: 0.8652
Epoch 3/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.3516 - acc: 0.8746 - val_loss: 0.3620 - val_acc: 0.8680
Epoch 4/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.3261 - acc: 0.8819 - val_loss: 0.3487 - val_acc: 0.8736
Epoch 5/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.3078 - acc: 0.8892 - val_loss: 0.3266 - val_acc: 0.8856
Epoch 6/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2926 - acc: 0.8930 - val_loss: 0.3133 - val_acc: 0.8812
Epoch 7/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2790 - acc: 0.8979 - val_loss: 0.3315 - val_acc: 0.8774
Epoch 8/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2669 - acc: 0.9030 - val_loss: 0.3103 - val_acc: 0.8900
Epoch 9/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2563 - acc: 0.9064 - val_loss: 0.3039 - val_acc: 0.8900
Epoch 10/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2461 - acc: 0.9108 - val_loss: 0.3175 - val_acc: 0.8836
Epoch 11/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2366 - acc: 0.9135 - val_loss: 0.3059 - val_acc: 0.8894
Epoch 12/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2276 - acc: 0.9162 - val_loss: 0.3144 - val_acc: 0.8846
Epoch 13/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2209 - acc: 0.9196 - val_loss: 0.3020 - val_acc: 0.8900
Epoch 14/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2113 - acc: 0.9239 - val_loss: 0.3216 - val_acc: 0.8854
Epoch 15/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2043 - acc: 0.9265 - val_loss: 0.2941 - val_acc: 0.8926
Epoch 16/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1967 - acc: 0.9297 - val_loss: 0.3036 - val_acc: 0.8920
Epoch 17/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1892 - acc: 0.9328 - val_loss: 0.3082 - val_acc: 0.8894
Epoch 18/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1835 - acc: 0.9336 - val_loss: 0.2951 - val_acc: 0.8936
Epoch 19/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1771 - acc: 0.9363 - val_loss: 0.3003 - val_acc: 0.8956
Epoch 20/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1704 - acc: 0.9385 - val_loss: 0.3261 - val_acc: 0.8854
接着查看历史信息:
代码语言:javascript复制# 查看历史信息
history.history
输出:
代码语言:javascript复制{'loss': [0.5354679226875305,
0.39141619205474854,
0.3516489565372467,
0.32609444856643677,
0.3078019320964813,
0.2926427125930786,
0.27895301580429077,
0.2669494152069092,
0.2563493847846985,
0.24608077108860016,
0.23657047748565674,
0.2275625765323639,
0.2209150642156601,
0.21127846837043762,
0.20427322387695312,
0.19672366976737976,
0.1892261505126953,
0.1835436224937439,
0.1771387904882431,
0.17038710415363312],
'acc': [0.8089091181755066,
0.8575454354286194,
0.8745636343955994,
0.8818908929824829,
0.889163613319397,
0.8929636478424072,
0.8978727459907532,
0.902999997138977,
0.9064363837242126,
0.9108181595802307,
0.913527250289917,
0.9161636233329773,
0.9196363687515259,
0.923909068107605,
0.9265090823173523,
0.9297090768814087,
0.9327636361122131,
0.9336363673210144,
0.9363454580307007,
0.9385091066360474],
'val_loss': [0.4270329475402832,
0.3716042935848236,
0.3619808852672577,
0.34866154193878174,
0.32663166522979736,
0.31333956122398376,
0.3315422832965851,
0.3103165328502655,
0.3038505017757416,
0.3175022304058075,
0.30592072010040283,
0.3144492208957672,
0.3020140528678894,
0.32157447934150696,
0.2940865457057953,
0.3035639524459839,
0.30824777483940125,
0.29505348205566406,
0.3002834916114807,
0.326121985912323],
'val_acc': [0.852400004863739,
0.8651999831199646,
0.8679999709129333,
0.8736000061035156,
0.8855999708175659,
0.8812000155448914,
0.8773999810218811,
0.8899999856948853,
0.8899999856948853,
0.8835999965667725,
0.8894000053405762,
0.8845999836921692,
0.8899999856948853,
0.8853999972343445,
0.8925999999046326,
0.8920000195503235,
0.8894000053405762,
0.8935999870300293,
0.8956000208854675,
0.8853999972343445]}
5 画出loss、acc图
根据history,先转换成dataframe,再画出图
代码语言:javascript复制def plot_learning_curves(history):
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)# 使用网格
plt.gca().set_ylim(0, 1)# 设置y坐标轴范围
plt.show()
plot_learning_curves(history)
6 测试模型
代码语言:javascript复制model.evaluate(x_test_scaled, y_test)
输出:
代码语言:javascript复制1/313 [..............................] - ETA: 0s - loss: 0.3203 - acc: 0.8750WARNING:tensorflow:Callbacks method on_test_batch_end is slow compared to the batch time (batch time: 0.0010s vs on_test_batch_end time: 0.0038s). Check your callbacks.
313/313 [==============================] - 0s 2ms/step - loss: 0.3541 - acc: 0.8774A: 0s - loss: 0.3632 - acc:
[0.35407549142837524, 0.8773999810218811]
Tensorboard样式: