Tensorflow笔记:TFRecord的制作与读取

2021-09-24 15:06:11 浏览数 (1)

前言

Google官方推荐在对于中大数据集来说,先将数据集转化为TFRecord数据,这样可加快你在数据读取,预处理中的速度。除了“快”,还有另外一个优点就是,在多模态学习(比如视频 音频 文案作为特征)中可以将各种形式的特征预处理后统一放在TFRecord中,避免了读取数据时候的麻烦。

1. 制作

以MNIST数据集为例(不论文本、图片、声音,都是先转化成numpy,在转化成TFRecord),在这里下载好之后,还需要像这样预处理一下。下一步就是把每一张图片读成numpy再写入TFRecord了。读成numpy的过程因人而异因项目而异,个人比较喜欢通过手动制作一个索引文件来读取。具体说来就是用一个文本文件,每行存放一个样本的label、图片路径等信息。大概长这样:

代码语言:javascript复制
label,file
5,~/data/Mnist/0.png
0,~/data/Mnist/1.png
4,~/data/Mnist/2.png
1,~/data/Mnist/3.png
... ...

提供一下制作索引文件的逻辑:

代码语言:javascript复制
# make index file
label_list = open("./Mnist_Label/label.txt").readlines()[0].split(",")

# output to index_file
index_file = "./index_file.csv"
with open(index_file, "w") as f:
    head = "label,filename"   "n"
    f.write(head)
    for i in range(len(label_list)):
        filename = "./Mnist/"   str(i)   ".png"
        label = label_list[i]
        line = label   ","   filename   "n"
        f.write(line)

这样做的好处是,可以不用一口气把数据读进内存,对于大数据集任务比较友好。而且在多模态的任务中,通过“索引文件”的方式也能够使多种形式的多个文件的读取更加简洁,灵活。

接下来就是Step 1 : 把文件特征读取成numpy

代码语言:javascript复制
import numpy as np
from PIL import image

index_file = "./index_file.csv"
index_list = open(index_file, "r").readlines()[1:]    # 读取索引文件,去掉首行
for line in index_list:
    label = int(line.split(",")[0])    # 将每行第一个元素读成int,作为label
    img = np.array(Image.open(line.rstrip("n").split(",")[1]))    # 根据每行中文件名读取文件,并转化为numpy
    """
    这张图片转化为numpy之后,在这里将它写入到TFRecord文件里
    """

现在我们有了numpy形式的图片和int形式的label,怎么写入到TFRecord里呢?

代码语言:javascript复制
# 首先我们需要将label和img捏在一起
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
}))  # example对象对label和img数据进行封装

# 然后把这个封装好的example写入到文件里
writer = tf.python_io.TFRecordWriter("./data/mnist.tfrecord")
writer.write(example.SerializeToString())
writer.close()

这个过程很简单,但是有一个地方需要说明一下。构建example的时候,这个tf.train.Feature()函数可以接收三种数据:

  • bytes_list: 可以存储string 和byte两种数据类型。
  • float_list: 可以存储float(float32)与double(float64) 两种数据类型。
  • int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64。

对于只有一个值(比如label)可以用float_list或int64_list,而像图片、视频、embedding这种列表型的数据,通常转化为bytes格式储存。下面把整个过程梳理一遍:

代码语言:javascript复制
import numpy as np
from PIL import image
import tensorflow as tf

index_file = "./index_file.csv"
writer = tf.python_io.TFRecordWriter("./data/mnist.tfrecord")    # 打开文件

index_list = open(index_file, "r").readlines()[1:]    # 读取索引文件,去掉首行
for line in index_list:
    # 获取label和图片的numpy形式
    label = int(line.split(",")[0])    # 将每行第一个元素读成int,作为label
    img = np.array(Image.open(line.split(",")[1]))    # 根据每行中文件名读取文件,并转化为numpy
    
    # 将label和img捏在一起
    example = tf.train.Example(features=tf.train.Features(feature={
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
    }))  # example对象对label和img数据进行封装

    # 将构建好的 example 写入到 TFRecord
    writer.write(example.SerializeToString())
# 关闭文件
writer.close()

这就是制作TFRecord的流程啦。这里买下一个伏笔,本例中图片的numpy是np.uint8格式的,每个像素点取值在[0, 255]之间。

代码语言:javascript复制
print(img.dtype)
# 输出 dtype('uint8')

2. 读取

TFRecord做好了,要怎么读取呢?我们可以通过tf.data来生成一个迭代器,每次调用都返回一个大小为batch_size的batch。

代码语言:javascript复制
def read_and_decode(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
    """
    每次调用,从TFRecord文件中读取一个大小为batch_size的batch
    Args:
        filenames: TFRecord文件
        batch_size: batch_size大小
        num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
        perform_shuffle: 是否乱序

    Returns:
        tensor格式的,一个batch的数据
    """
    def _parse_fn(record):
        features = {
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string),
        }
        parsed = tf.parse_single_example(record, features)
        # image
        image = tf.decode_raw(parsed["image"], tf.uint8)
        image = tf.reshape(image, [28, 28])
        # label
        label = tf.cast(parsed["label"], tf.int64)
        return {"image": image}, label

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

这个read_and_decode过程基本上是通用的,直接复制过去改一改就行。对于不同的数据,只需要改动_parse_fn函数就可以。这里有一点很重要,就是在_parse_fn函数中,tf.decode_raw的第二个参数(解码格式),必须和保存TFRecord时候的numpy的格式是一样的,否则会报TypeError,我们保存图片时候采用的是np.uint8,这里解码的时候也要用tf.uint8。

接下来我们来试一试把

代码语言:javascript复制
batch_features, batch_labels = read_and_decode("./data/mnist.tfrecord")
with tf.Session() as sess:
    print(sess.run(batch_features["image"][0]))
    print(sess.run(batch_labels[0]))

3. 使用

会写会读之后,我们来简单尝试下怎么用吧!假设我们要用简单的DNN预测MNIST的label。

代码语言:javascript复制
# 调用 read_and_decode 获取一个 batch 的数据
batch_features, batch_labels = read_and_decode("./data/mnist.tfrecord")

# input
X = tf.cast(batch_features["image"], tf.float32, name="input_image")
X = tf.reshape(X, [-1, 28*28]) / 255    # 将像素点的值标准化到[0,1]之间
label = tf.one_hot(tf.cast(batch_labels, tf.int32, name="input_label"), depth=10, name="label")

# DNN Layer
deep_inputs = X
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=128)
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=64)
y_deep = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=10)
y = tf.reshape(y_deep, shape=[-1, 10])
pred = tf.nn.softmax(y, name="pred")

# 构建损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=label))
# 构建train_op
train_op = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss)

上面就是简单的,通过read_and_decode函数读取数据,并作为DNN模型的输入的例子。下面的代码是输出一下pred,然后训练10个step,在输出新的pred的例子。

代码语言:javascript复制
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(pred))
    for i in range(10):
        print(sess.run(train_op))
    print(sess.run(pred))

这里我们发现,与tf.placeholder不同,如果采用tf.placeholder作为模型的输入,需要在sess.run()的时候手动的设置feed_dict,来喂一个batch的数据;而如果采用TFRecord,每次sess.run()时,根据向前追溯的计算逻辑,都会自动的调用一次read_and_decode获得一个batch的数据,所以就不需要手动feed数据。

0 人点赞