前言
Google官方推荐在对于中大数据集来说,先将数据集转化为TFRecord数据,这样可加快你在数据读取,预处理中的速度。除了“快”,还有另外一个优点就是,在多模态学习(比如视频 音频 文案作为特征)中可以将各种形式的特征预处理后统一放在TFRecord中,避免了读取数据时候的麻烦。
1. 制作
以MNIST数据集为例(不论文本、图片、声音,都是先转化成numpy,在转化成TFRecord),在这里下载好之后,还需要像这样预处理一下。下一步就是把每一张图片读成numpy再写入TFRecord了。读成numpy的过程因人而异因项目而异,个人比较喜欢通过手动制作一个索引文件来读取。具体说来就是用一个文本文件,每行存放一个样本的label、图片路径等信息。大概长这样:
代码语言:javascript复制label,file
5,~/data/Mnist/0.png
0,~/data/Mnist/1.png
4,~/data/Mnist/2.png
1,~/data/Mnist/3.png
... ...
提供一下制作索引文件的逻辑:
代码语言:javascript复制# make index file
label_list = open("./Mnist_Label/label.txt").readlines()[0].split(",")
# output to index_file
index_file = "./index_file.csv"
with open(index_file, "w") as f:
head = "label,filename" "n"
f.write(head)
for i in range(len(label_list)):
filename = "./Mnist/" str(i) ".png"
label = label_list[i]
line = label "," filename "n"
f.write(line)
这样做的好处是,可以不用一口气把数据读进内存,对于大数据集任务比较友好。而且在多模态的任务中,通过“索引文件”的方式也能够使多种形式的多个文件的读取更加简洁,灵活。
接下来就是Step 1 : 把文件特征读取成numpy
代码语言:javascript复制import numpy as np
from PIL import image
index_file = "./index_file.csv"
index_list = open(index_file, "r").readlines()[1:] # 读取索引文件,去掉首行
for line in index_list:
label = int(line.split(",")[0]) # 将每行第一个元素读成int,作为label
img = np.array(Image.open(line.rstrip("n").split(",")[1])) # 根据每行中文件名读取文件,并转化为numpy
"""
这张图片转化为numpy之后,在这里将它写入到TFRecord文件里
"""
现在我们有了numpy形式的图片和int形式的label,怎么写入到TFRecord里呢?
代码语言:javascript复制# 首先我们需要将label和img捏在一起
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
})) # example对象对label和img数据进行封装
# 然后把这个封装好的example写入到文件里
writer = tf.python_io.TFRecordWriter("./data/mnist.tfrecord")
writer.write(example.SerializeToString())
writer.close()
这个过程很简单,但是有一个地方需要说明一下。构建example的时候,这个tf.train.Feature()函数可以接收三种数据:
- bytes_list: 可以存储string 和byte两种数据类型。
- float_list: 可以存储float(float32)与double(float64) 两种数据类型。
- int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64。
对于只有一个值(比如label)可以用float_list或int64_list,而像图片、视频、embedding这种列表型的数据,通常转化为bytes格式储存。下面把整个过程梳理一遍:
代码语言:javascript复制import numpy as np
from PIL import image
import tensorflow as tf
index_file = "./index_file.csv"
writer = tf.python_io.TFRecordWriter("./data/mnist.tfrecord") # 打开文件
index_list = open(index_file, "r").readlines()[1:] # 读取索引文件,去掉首行
for line in index_list:
# 获取label和图片的numpy形式
label = int(line.split(",")[0]) # 将每行第一个元素读成int,作为label
img = np.array(Image.open(line.split(",")[1])) # 根据每行中文件名读取文件,并转化为numpy
# 将label和img捏在一起
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
})) # example对象对label和img数据进行封装
# 将构建好的 example 写入到 TFRecord
writer.write(example.SerializeToString())
# 关闭文件
writer.close()
这就是制作TFRecord的流程啦。这里买下一个伏笔,本例中图片的numpy是np.uint8格式的,每个像素点取值在[0, 255]之间。
代码语言:javascript复制print(img.dtype)
# 输出 dtype('uint8')
2. 读取
TFRecord做好了,要怎么读取呢?我们可以通过tf.data来生成一个迭代器,每次调用都返回一个大小为batch_size的batch。
代码语言:javascript复制def read_and_decode(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
"""
每次调用,从TFRecord文件中读取一个大小为batch_size的batch
Args:
filenames: TFRecord文件
batch_size: batch_size大小
num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
perform_shuffle: 是否乱序
Returns:
tensor格式的,一个batch的数据
"""
def _parse_fn(record):
features = {
"label": tf.FixedLenFeature([], tf.int64),
"image": tf.FixedLenFeature([], tf.string),
}
parsed = tf.parse_single_example(record, features)
# image
image = tf.decode_raw(parsed["image"], tf.uint8)
image = tf.reshape(image, [28, 28])
# label
label = tf.cast(parsed["label"], tf.int64)
return {"image": image}, label
# Extract lines from input files using the Dataset API, can pass one filename or filename list
dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000) # multi-thread pre-process then prefetch
# Randomizes input using a window of 256 elements (read into memory)
if perform_shuffle:
dataset = dataset.shuffle(buffer_size=256)
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size) # Batch size to use
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
这个read_and_decode过程基本上是通用的,直接复制过去改一改就行。对于不同的数据,只需要改动_parse_fn函数就可以。这里有一点很重要,就是在_parse_fn函数中,tf.decode_raw的第二个参数(解码格式),必须和保存TFRecord时候的numpy的格式是一样的,否则会报TypeError,我们保存图片时候采用的是np.uint8,这里解码的时候也要用tf.uint8。
接下来我们来试一试把
代码语言:javascript复制batch_features, batch_labels = read_and_decode("./data/mnist.tfrecord")
with tf.Session() as sess:
print(sess.run(batch_features["image"][0]))
print(sess.run(batch_labels[0]))
3. 使用
会写会读之后,我们来简单尝试下怎么用吧!假设我们要用简单的DNN预测MNIST的label。
代码语言:javascript复制# 调用 read_and_decode 获取一个 batch 的数据
batch_features, batch_labels = read_and_decode("./data/mnist.tfrecord")
# input
X = tf.cast(batch_features["image"], tf.float32, name="input_image")
X = tf.reshape(X, [-1, 28*28]) / 255 # 将像素点的值标准化到[0,1]之间
label = tf.one_hot(tf.cast(batch_labels, tf.int32, name="input_label"), depth=10, name="label")
# DNN Layer
deep_inputs = X
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=128)
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=64)
y_deep = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=10)
y = tf.reshape(y_deep, shape=[-1, 10])
pred = tf.nn.softmax(y, name="pred")
# 构建损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=label))
# 构建train_op
train_op = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss)
上面就是简单的,通过read_and_decode函数读取数据,并作为DNN模型的输入的例子。下面的代码是输出一下pred,然后训练10个step,在输出新的pred的例子。
代码语言:javascript复制with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(pred))
for i in range(10):
print(sess.run(train_op))
print(sess.run(pred))
这里我们发现,与tf.placeholder不同,如果采用tf.placeholder作为模型的输入,需要在sess.run()的时候手动的设置feed_dict,来喂一个batch的数据;而如果采用TFRecord,每次sess.run()时,根据向前追溯的计算逻辑,都会自动的调用一次read_and_decode获得一个batch的数据,所以就不需要手动feed数据。