Tensorflow读取数据(一)

2021-01-26 22:41:16 浏览数 (1)

数据算法是深度学习最重要的两大块。而更基础的首先是要熟练掌握一个框架来支撑算法的执行。 我个人使用最多的是tensorflow平台。就从最基础的数据输入开始记录吧。

AI算法基本流程

个人总结的AI项目基础流程(除开更复杂的工程化工作) (1)数据预处理:get每个迭代的输入和标签。图像,音频,文本对数据处理方式又各有不同;不同的需求对标签的格式也不相同。 (2)算法建模:设计网络模型,输入:训练数据;输出:预测值 (3)优化参数:通过输出和真实label设计loss,还需要设计一个优化算法,让网络参数去学习得到最优解 (4)迭代训练:不断更新数据,在大数据上优化参数 (5)保存网络参数以及设计评价指标 以上步骤还只是算法部分,而且每个模块都有很可以展开出很多内容,其他更多工程上模块就不提了~

数据模块

今天先从数据模块下手。在训练过程中,我们对需求就是要不断的从所有数据中取一个batch数据输入到模型中。如果是python,那比较简单,伪代码如下:

代码语言:javascript复制
#随机从datas里面抽取batch_size个数据
def get_batch(batch_size,datas):
batch_datas = []
datas.shuffle()
for i in range(batch_size):
	batch_datas.append(datas[i])
return batch_datas

但是在tensorflow框架中,我们就要利用它的优势来进行数据的读取。今天先介绍通过tf.Coordinatortf.QueueRunner来利用多线程管理数据。 tf.QueueRunner()就是负责开启线程以及线程队列 tf.train.Coordinator()就是创建一个线程管理器,管理我们开启的线程

准备数据

我们先准备两类图片数据,结构如下

为了方便,我们建立数据集文件夹Images,里面两类图片数据1,2。 然后我们生成一个文件列表,代码如下:

代码语言:javascript复制
# -*- coding: utf-8 -*-
# @Time    : 2019-09-21 22:35
# @Author  : LanguageX
import os

root_dir = os.getcwd()
fw = open("./train.txt","w")
for root, dirs, files in os.walk(root_dir):
    for file in files:
        if file.endswith("jpg") or file.endswith("png"):
            filename = os.path.join(root, file)
            class_name = filename.split("/")[-2]
            print(class_name,filename)
            fw.write(filename " " class_name "n")

目的就是生成train.txt文本列表(格式:图片路径–类别)

数据准备好了~下面就可以开始实现取数据的代码了~

代码框架比较简单,添加了比较详细的注释,就直接上代码吧:

代码语言:javascript复制
# -*- coding: utf-8 -*-
# @Time    : 2019-09-21 22:24
# @Author  : LanguageX

import tensorflow as tf
import os

class DataReader:

    def get_data_lines(self, filename):
        with open(filename) as txt_file:
            lines = txt_file.readlines()
            return lines

    def gen_datas(self, train_files):
        paths = []
        labels = []
        for line in train_files:
            line = line.replace("n","")
            path, label = line.split(" ")
            paths.append(path)
            labels.append(label)
        return paths, labels

    def __init__(self,root_dir,train_filepath,batch_size,img_size):
         self.dir = root_dir
         self.batch_size = batch_size
         self.img_size = img_size
         #读取生成的path-label列表
         self.train_files = self.get_data_lines(train_filepath)
         #获取对应的paths和labels
         self.paths,self.labels = self.gen_datas(self.train_files)
         self.data_nums  = len(self.train_files)



    def get_batch(self, batch_size):
        self.paths = tf.cast(self.paths, tf.string)
        self.labels = tf.cast(self.labels, tf.string)
        #slice_input_producer构建了取数据队列
        input_queue = tf.train.slice_input_producer([self.paths, self.labels], num_epochs=10, shuffle=True)

        # 从文件名称队列中读取文件放入文件队列
        image_batch, label_batch= tf.train.batch(input_queue, batch_size=batch_size, num_threads=2, capacity=64,
                                                  allow_smaller_final_batch=False)

        return image_batch, label_batch



if __name__ == '__main__':
    root_dir = "../images/"
    filename = "./images/train.txt"
    batch_size = 4
    image_size = 256
    dataset = DataReader(root_dir,filename,batch_size,image_size)

    images,labels = dataset.get_batch(batch_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        #coord线程管理器
        coord = tf.train.Coordinator()
        #tf的线程队列
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(5):
            _imgs,_labesl = sess.run([images,labels])
            print("_imgs ", _imgs)
            print("_labes ", _labesl)
        #通知线程停止
        coord.request_stop()
        coord.join(threads)
        sess.close()

运行就可以在每个迭代获取到batch_size个数据了。基本本文获取数据的基本框架,其他任务的数据读取都可以举一反三添加业务需求了~

0 人点赞