Tensorflow读取数据(二)

2021-01-29 10:36:22 浏览数 (1)

上一篇介绍了利用tensorflow的QueueRunner和coord进行数据读取的简单框架。

其实在tf1.4之后新增了tf.data.Dataset,官方推出的一些源码也都转为使用dataset的API来进行数据读取,所以今天就来介绍下利用dataset来进行数据读取。

项目中一般使用最多的就是datasetiterator,关于dataset官方提供了API使用和介绍:https://github.com/tensorflow/docs/blob/r1.8/site/en/api_docs/python/tf/data/Dataset.md

https://zhuanlan.zhihu.com/p/30751039这篇也介绍的比较详细。

我就直接用代码来介绍下如何使用tf.data.dataset读取数据。

还是使用上一篇的数据结构和代码框架,只是把QueueRunner和coord相关的代码删除,替换为tf.data.dataset的API

代码语言:javascript复制
# -*- coding: utf-8 -*-
# @Time    : 2019-10-08 21:24
# @Author  : LanguageX

import tensorflow as tf
import os

class DataReader:

    def get_data_lines(self, filename):
        with open(filename) as txt_file:
            lines = txt_file.readlines()
            return lines

    def gen_datas(self, train_files):
        paths = []
        labels = []
        for line in train_files:
            line = line.replace("n","")
            path, label = line.split(" ")
            paths.append(path)
            labels.append(label)
        return paths, labels

    def __init__(self,root_dir,train_filepath,batch_size,img_size):
         self.dir = root_dir
         self.batch_size = batch_size
         self.img_size = img_size
         #读取生成的path-label列表
         self.train_files = self.get_data_lines(train_filepath)
         #获取对应的paths和labels
         self.paths,self.labels = self.gen_datas(self.train_files)
         self.data_nums  = len(self.train_files)

    # 把图片文件解码并进行预处理,在这里进行你需要的数据增强代码
    def preprocess(self, filepath, label):
        img = tf.read_file(filepath)
        img = tf.image.decode_jpeg(img, channels=3)
        shape = img.get_shape()

        image_resized = tf.image.resize_images(img, [128, 128])
        return filepath,image_resized, label


    def get_batch(self, batch_size):

        self.paths = tf.cast(self.paths, tf.string)
        self.labels = tf.cast(self.labels, tf.string)
        #利用tf.data.Dataset,输出dataset的一个元素的格式:(path,label)
        dataset = tf.data.Dataset.from_tensor_slices((self.paths,self.labels))
        #通过preprocess后,现在的dataset_prs的一个元素格式:(path,image,label)
        #这个map函数比较强大,参数是一个函数,在函数里面可以为所欲为
        dataset_prs = dataset.map(self.preprocess)
        #通过dataset的一系列API,随机打乱,返回一个batch的数据,数据集重复5次
        dataset_prs = dataset_prs.shuffle(buffer_size=20).batch(batch_size).repeat(5)
        #创建一个one shot iterator
        _iterator = dataset_prs.make_one_shot_iterator()
        #利用迭代器返回下一个batch的数据
        bathc_data = _iterator.get_next()
        return bathc_data



if __name__ == '__main__':
    root_dir = "../images/"
    filename = "./images/train.txt"
    batch_size = 4
    image_size = 256
    dataset = DataReader(root_dir,filename,batch_size,image_size)

    bathc_data = dataset.get_batch(batch_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        for i in range(1):
            _bathc_data = sess.run(bathc_data)
            for i in range(batch_size):
                print("_name", _bathc_data[0][i])
                print("_image", _bathc_data[1][i].shape)
                print("_label", _bathc_data[2][i])

运行下,我们可以输入图片路径,数据,标签~

和上一篇对比,我们的大致流程没有修改,只是替换使用了高阶API读取数据而已,因为没在大数据集上进行性能实验对比,所以不敢说在同样的数据格式下tf.dataset会快些,不过在代码使用上确实便捷不少,在最新的tf2.0对dataset有更进一步的优化尤其对文本任务。

我的博客即将同步至腾讯云 社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=2zfyzsld89q8w

0 人点赞