TFRecord读写简介

2022-08-11 16:51:17 浏览数 (1)

为了高效地读取数据,比较有帮助的一种做法是对数据进行序列化并将其存储在一组可线性读取的文件(每个文件 100-200MB)中。这尤其适用于通过网络进行流式传输的数据。这种做法对缓冲任何数据预处理也十分有用。TFRecord 格式是一种用于存储二进制记录序列的简单格式。

TFRecordTFRecord

1. 写入TFRecord

TFRecord写入流程TFRecord写入流程
  • 特征数据
代码语言:python代码运行次数:0复制
feature_data = {
    'name': 'xiaoming',
    'age': 20,
    'height': 172.8,
    'scores': [[120,130,140],[82,95,43]]
}
  • tf.Example 消息(或 protobuf)是一种灵活的消息类型,表示 {"string": value} 映射。它专为 TensorFlow 而设计,并被用于 TFX 等高级 API。
代码语言:python代码运行次数:0复制
example_proto = tf.train.Example(
    features=tf.train.Features(feature={
        # 将标准 TensorFlow 类型转换为兼容 tf.Example 的 tf.train.Feature  
        'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'xiaoming'])),
        'age': tf.train.Feature(int64_list=tf.train.Int64List(value=[20])),
        'height': tf.train.Feature(float_list=tf.train.FloatList(value=[172.8])),
        'scores': tf.train.Feature(bytes_list=tf.train.BytesList(
        # 要处理非标量特征,最简单的方法是使用 tf.io.serialize_tensor 将张量转换为二进制字符串
        value=[tf.io.serialize_tensor([[120,130,140],[82,95,43]]).numpy()]))
    })
)
  
""" 输出结果:  
features {
    feature {
        key: "age"
        value {
            int64_list {
                value: 20
            }
        }
    }
    feature {
        key: "height"
        value {
            float_list {
                value: 172.8000030517578
            }
        }
    }
    feature {
        key: "name"
        value {
            bytes_list {
                value: "xiaoming"
            }
        }
    }
    feature {
        key: "scores"
        value {
            bytes_list {
                value: "100322102202100222021003"30x000000202000000214000000R000000_000000 000000"
            }
        }
    }
}
"""
  • 使用 .SerializeToString 方法将所有协议消息序列化为二进制字符串
代码语言:python代码运行次数:0复制
serialized_example = example_proto.SerializeToString()

# 输出结果:b'nnn4nx06scoresx12*n(n&x08x03x12x08x12x02x08x02x12x02x08x03"x18xx00x00x00x82x00x00x00x8cx00x00x00Rx00x00x00_x00x00x00 x00x00x00nx14nx04namex12x0cnnnx08xiaomingnx12nx06heightx12x08x12x06nx04xcdxcc,Cnx0cnx03agex12x05x1ax03nx01x14'
  • Write TFRecord
代码语言:python代码运行次数:0复制
with tf.io.TFRecordWriter(file_path) as writer:
    writer.write(serialized_example)

2. 读取TFRecord

TFRecord读取流程TFRecord读取流程
  • feature_description 是必需的,因为数据集使用计算图执行,并且需要以下描述来构建它们的形状和类型签名
代码语言:python代码运行次数:0复制
feature_description = {    
    'name': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'age': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'height': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
    'scores': tf.io.FixedLenFeature([], tf.string, default_value=''),
}
  • 解析
代码语言:python代码运行次数:0复制
def parse_from_example(serialized_example):    
    # tf.parse_example 函数会将 tf.Example 字段解压缩为标准张量
    feature_data = tf.io.parse_single_example(serialized_example, feature_description)
    # 使用 tf.io.parse_tensor 可将二进制字符串转换回张量
    feature_data['scores'] = tf.reshape(tf.io.parse_tensor(feature_data['scores'], out_type=tf.int32), (2, 3))
    return feature_data

parse_from_example(serialized_example)
"""输出结果:
{'age': <tf.Tensor: id=15, shape=(), dtype=int64, numpy=20>, 'height': <tf.Tensor: id=16, shape=(), dtype=float32, numpy=172.8>, 'name': <tf.Tensor: id=17, shape=(), dtype=string, numpy=b'xiaoming'>, 'scores': <tf.Tensor: id=21, shape=(2, 3), dtype=int32, numpy=
array([[120, 130, 140],
        [ 82,  95,  43]], dtype=int32)>}
"""
  • Read TFRecord
代码语言:python代码运行次数:0复制
# 使用 tf.data.Dataset.map 方法可将函数应用于 Dataset 的每个元素
# Tips: You can convert tensor into numpy array using tensor.numpy(), But you can't do the same in case of MapDataset. Try tf.numpy_function / tf.py_function
dataset = tf.data.TFRecordDataset(file_path).map(parse_from_example)

0 人点赞