TensorFlow2.0+的API结构梳理

2021-03-18 17:23:30 浏览数 (1)

本文梳理了tf 2.0以上版本的API结构,用于帮助国内的初学者更好更快的了解这个框架,并为检索官方的API文档提供一些关键词。

官方API文档:https://tensorflow.google.cn/api_docs/python/tf?hl=zh-cn

1. 数据类型

tf中的数据类型为张量:tf.Tensor(),可以类比numpy中的np.array()

一些特殊的张量:

  • tf.Variable:变量。用来存储需要被修改、需要被持久化保存的张量,模型的参数一般都是用变量来存储的。
  • tf.constant:常量,定义后值和维度不可改变。
  • tf.sparse.SparseTensor:稀疏张量。

除上述特殊张量外,其余创建方式同numpy类似,示例:

代码语言:javascript复制
t = tf.ones([5,3], dtype=tf.float32)
a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(a.shape)

与numpy类似,可以对Tensor进行切片、索引;可以对这些Tensor做各种运算,例如:加减乘除、地板除、布尔运算。

2. 架构

  1. 使用tf.data加载数据,高效的数据输入管道也可以极大的减少模型训练时间,管道执行的过程包括:从硬盘中读取数据(Extract)、数据的预处理如数据清洗、格式转换(Transform)、加载到计算设备(Load)
  2. 使用tf.keras构建、训练和验证模型,另外tf.estimator中打包了一些标准的机器学习模型供我们直接使用,当我们不想从头开始训练一个模型时,可以使用TensorFlow Hub模块来进行迁移学习。
  3. 使用tf.distribute.Strategy实现分布式的训练
  4. 使用CheckpointsSavedModel存储模型,前者依赖于创建模型的源代码;而后者与源代码无关,可以用于其他语言编写的模型。

加载数据示例代码:

代码语言:javascript复制
import tensorflow as tf
import multiprocessing
import matplotlib.pyplot as plt

N_CPUS = multiprocessing.cpu_count()
BATCH_SIZE = 32
SEED = 0

def load_and_preproess_image(path):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [192,192])
    image /= 255.0
    return image

# 1. 构建图片路径
# 其中 all_image_paths = ['图片1路径','图片2路径',...,'图片n路径']
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
# 2. 构建图片数据的数据集
image_ds = path_ds.map(load_and_preproess_image, num_parallel_calls=N_CPUS)
# 3. 构建标签的数据集
label_ds = tf.data.Dataset.from_tensor_slices(all_image_labels)
# 4. 将图片和类标压缩为(图片,标签)对
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
# 5. 可视化数据
plt.figure(figsize=(8,8))
for n,image_label in enumerate(image_label_ds.take(4)):
    plt.subplot(2,2,n 1)
    plt.imshow(image_label[0])
    plt.grid(False)
    plt.xlabel(image_label[1])
# 6. 打乱数据集
image_count = len(all_image_paths)
ds = image_label_ds.shuffle(buffer_size=image_count, seed=SEED)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) # 让训练和每批次数据加载并行

构建和训练模型示例代码

  • 类式构建:
代码语言:javascript复制
from tensorflow.keras import layers
# 创建网络,两种方法二选一
model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)), # 全连接层
    layers.Dense(10, activation='softmax')
])
# 或者
# model = tf.keras.Sequential()
# model.add(layers.Dense(64, activation='relu', input_shape=(32,)))
# model.add(layers.Dense(10, activation='softmax'))
# 编译网络
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
             loss='categorical_crossentropy',
             metrics=['accuracy'])
# 网络训练(可以是numpy数据(见官方文档),也可以是Dataset数据)
# verbose=1表示以进度条的形式显示训练信息, 验证集可以直接给也可以设置比例
model.fit(ds, epochs=2, validation_split=0.2, verbose=1)
# 模型评估(可以是numpy数据(见官方文档),也可以是Dataset数据)
model.evaluate(ds, steps=30)
# 预测
result = model.predict(data, batch_size=50)
print(result[0])
  • 函数式构建:
代码语言:javascript复制
inputs = tf.keras.Input(shape=(32,))
# 网络层像函数一样被调用,输出和输入都是张量
x = layers.Dense(64, activation='relu')(inputs)
predictions = layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
# 编译和训练同上

模型训练的技巧——callbacks的使用

代码语言:javascript复制
callbacks = [
    # 若验证集上的损失“val_loss”连续两个epoch都没有变化,则提前结束训练
    tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
    # 使用TensorBoard把训练的记录保存到 "./logs"
    tf.keras.callbacks.TensorBoard(log_dir='./logs')
]
model.fit(ds, epochs=5, callbacks=callbacks, validation_data=val_dataset)

如果安装的是gpu版本的TensorFlow会自动使用gpu,查看可用的GPU的代码:

代码语言:javascript复制
from tensorflow.python.client import device_lib

def get_available_gpus():
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type=='GPU']

print(get_available_gpus())

单机环境下的多GPU训练:

代码语言:javascript复制
strategy = tf.distribute.MirroredStrategy()

# 优化器及模型的构建和编译必须放在scope()中
with strategy.scope():
    model = tf.keras.Sequential([
        layers.Dense(64, activation='relu', input_shape=(32,)), # 全连接层
        layers.Dense(10, activation='softmax')
	])
    model.compile(optimizer=tf.keras.optimizers.SGD(0.2), loss='binary_crossentropy')

模型的保存和恢复示例代码:

代码语言:javascript复制
# 完整模型的保存和读取
model.save('my_model')
model = tf.keras.models.load_model('my_model')
# 模型的权重参数的保存和读取
model.save_weights('my_model.h5', save_format='h5')
model.load_weights('my_model.h5')
# 单独保存模型的结构
json_string = model.to_json()

3. 模块

加载数据tf.data

构建、训练和验证模型tf.keras

  • activations: tf.keras.activations 中包含了当前主流的激活函数,可以直接通过该API进行激活函数的调用。
  • applications: tf.keras.applications 中包含的是已经进行预训练的神经网络模型,可以直接进行预测或者迁移学习。目前该模块中包含了主流的神经网络结构。
  • backend: tf.keras.backend中包含了Keras后台的一些基础API接口,用于实现高阶API或者自己构建神经网络。
  • datasets: tf.keras.datasets 中包含了常用的公开数据训练集,可以直接进行使用,数据集有CIFAR-100、Boston Housing等。
  • layers: tf.keras.layers 中包含了已经定义好的常用的神经网络层。
  • losses: tf.keras.losses 中包含了常用的损失函数,可以根据实际需求直接进行调用。
  • optimizers: tf.keras.optimizers 中包含了主流的优化器,可以直接调用API使用。比如Adm等优化器可以直接调用,然后配置所需要的参数即可。
  • preprocessing: tf.keras.preprocessing 中包含了数据处理的一些方法,分为图片数据处理、语言序列处理、文本数据处理等,比如 NLP 常用的pad_sequences等,在神经网络模型训练前的数据处理上提供了非常强大的功能。
  • regularizers: tf.keras.regularizers 中提供了常用的正则化方法,包括L1、L2等正则化方法。
  • wrappers: tf.keras.wrappers 是一个 Keras 模型的包装器,当需要进行跨框架迁移时,可以使用该API接口提供与其他框架的兼容性。
  • Sequential类:tf.keras.Sequential 可以让我们将神经网络层进行线性组合形成神经网络结构。

兼容模块tf.compat.v1,这个模块里有完整的TensorFlow1.x的API。

参考文献

[1] 侯伦青, 王飞, 邓昕, 史周安. TensorFlow 从零开始学[M]. 电子工业出版社, 2020.

[2] 赵英俊. 走向TensorFlow2.0深度学习应用编程快速入门[M]. 电子工业出版社, 2019.

0 人点赞