参考 tf.python_io.TFRecordWriter() - 云 社区 - 腾讯云
目录
1、Setup
2、tf.Example
1、Data types for tf.Example
2、Creating a tf.Example message
3、TFRecords format details
4、TFRecord files using tf.data
1、Writing a TFRecord file
2、Reading a TFRecord file
5、TFRecord files in Python
1、Writing a TFRecord file
2、Reading a TFRecord file
6、Walkthrough: Reading and writing image data
1、Fetch the images
2、Write the TFRecord file
3、Read the TFRecord file
为了有效地读取数据,将数据序列化并将其存储在一组文件(每个文件100-200MB)中是很有帮助的,这些文件可以线性读取。如果数据是通过网络传输的,这一点尤其正确。这对于缓存任何数据预处理也很有用。TFRecord格式是一种用于存储二进制记录序列的简单格式。协议缓冲区是一个跨平台、跨语言的库,用于高效地序列化结构化数据。协议消息由.proto文件定义,这通常是理解消息类型的最简单方法。特遣部队。示例消息(或protobuf)是一种灵活的消息类型,它表示{“string”:value}映射。它被设计为与TensorFlow一起使用,并在更高级别的api(如TFX)中使用。本笔记本将演示如何创建、解析和使用tf。示例消息,然后序列化、写入和读取tf。与.tfrecord文件之间的示例消息。
注意:虽然有用,但这些结构是可选的。不需要将现有代码转换为使用TFRecords,除非使用tf。数据和阅读数据仍然是训练的瓶颈。有关数据集性能技巧,请参阅数据输入管道性能。
1、Setup
代码语言:javascript复制from __future__ import absolute_import, division, print_function, unicode_literals
try:
# %tensorflow_version only exists in Colab.
%tensorflow_version 2.x
except Exception:
pass
import tensorflow as tf
import numpy as np
import IPython.display as display
代码语言:javascript复制ERROR: tensorflow-gpu 2.0.0b1 has requirement tb-nightly<1.14.0a20190604,>=1.14.0a20190603, but you'll have tb-nightly 1.15.0a20190806 which is incompatible.
2、tf.Example
1、Data types for tf.Example
Fundamentally, a tf.Example
is a {"string": tf.train.Feature}
mapping.The tf.train.Feature message type can accept one of the following three types (See the .proto file for reference). Most other generic types can be coerced into one of these:
- tf.train.BytesList (the following types can be coerced)
string
byte
- tf.train.FloatList (the following types can be coerced)
float
(float32
)double
(float64
)
- tf.train.Int64List (the following types can be coerced)
bool
enum
int32
uint32
int64
uint64
以便将标准TensorFlow类型转换为tf。Example-compatible tf.train。功能,您可以使用下面的快捷功能。注意,每个函数都接受一个标量输入值并返回一个tf.train。包含上述三种列表类型之一的功能:
代码语言:javascript复制# The following functions can be used to convert a value to a type compatible
# with tf.Example.
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
注意:为了保持简单,本例仅使用标量输入。处理非标量特性的最简单方法是使用tf。serialize_张量将张量转换成二进制字符串。字符串是tensorflow中的标量。使用tf.parse_tensor
将二进制字符串转换回张量。
下面是这些函数如何工作的一些例子。注意不同的输入类型和标准化的输出类型。如果函数的输入类型与上面所述的任何一种可强制类型不匹配,函数将会引发异常(例如_int64_feature(1.0)将会出错,因为1.0是一个浮点数,所以应该与_float_feature函数一起使用):
代码语言:javascript复制print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))
print(_float_feature(np.exp(1)))
print(_int64_feature(True))
print(_int64_feature(1))
代码语言:javascript复制bytes_list {
value: "test_string"
}
bytes_list {
value: "test_bytes"
}
float_list {
value: 2.7182817459106445
}
int64_list {
value: 1
}
int64_list {
value: 1
}
所有原始消息都可以使用.SerializeToString方法序列化为二进制字符串:
代码语言:javascript复制feature = _float_feature(np.exp(1))
feature.SerializeToString()
代码语言:javascript复制b'x12x06nx04Txf8-@'
2、Creating a tf.Example
message
假设您想创建一个tf。来自现有数据的示例消息。实际上,数据集可以来自任何地方,但是创建tf的过程除外。来自单个观察的示例消息将是相同的:
- 在每个观察中,需要将每个值转换为tf.train。使用上面的函数之一,包含3种兼容类型之一的特性。
- 您可以创建一个映射(字典),从特性名称字符串到#1中生成的编码特性值。
- 步骤2中生成的映射被转换为一个功能消息。
在这个笔记本中,您将使用NumPy创建一个数据集。这个数据集将有4个特点:*一个布尔值特性,或真或假,等概率*整数特性均匀随机选择从[0,5]*的字符串生成特性从一个字符串表使用整数特性作为指数*浮动特性从一个独立标准正态distributionConsider样本组成的10000和恒等分布的观察从上述每个发行版:
代码语言:javascript复制# The number of observations in the dataset.
n_observations = int(1e4)
# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)
# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)
# String feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]
# Float feature, from a standard normal distribution
feature3 = np.random.randn(n_observations)
这些特性都可以强制转换为tf.Example-compatible使用_bytes_feature、_float_feature、_int64_feature中的一个作为示例兼容类型。然后可以创建tf。来自这些编码特性的示例消息:
代码语言:javascript复制def serialize_example(feature0, feature1, feature2, feature3):
"""
Creates a tf.Example message ready to be written to a file.
"""
# Create a dictionary mapping the feature name to the tf.Example-compatible
# data type.
feature = {
'feature0': _int64_feature(feature0),
'feature1': _int64_feature(feature1),
'feature2': _bytes_feature(feature2),
'feature3': _float_feature(feature3),
}
# Create a Features message using tf.train.Example.
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
例如,假设您从数据集中有一个观察值,[False, 4, bytes('goat'), 0.9876]。您可以创建并打印tf。使用create_message()为该观察提供示例消息。每个单独的观察结果都将按照上面所述作为一个特性消息来编写。注意tf.Example示例消息只是功能消息的一个包装:
代码语言:javascript复制# This is an example observation from the dataset.
example_observation = []
serialized_example = serialize_example(False, 4, b'goat', 0.9876)
serialized_example
代码语言:javascript复制b'nRnx11nx08feature0x12x05x1ax03nx01x00nx11nx08feature1x12x05x1ax03nx01x04nx14nx08feature2x12x08nx06nx04goatnx14nx08feature3x12x08x12x06nx04[xd3|?'
使用tf.train.Example解码消息。FromString方法。
代码语言:javascript复制example_proto = tf.train.Example.FromString(serialized_example)
example_proto
代码语言:javascript复制features {
feature {
key: "feature0"
value {
int64_list {
value: 0
}
}
}
feature {
key: "feature1"
value {
int64_list {
value: 4
}
}
}
feature {
key: "feature2"
value {
bytes_list {
value: "goat"
}
}
}
feature {
key: "feature3"
value {
float_list {
value: 0.9876000285148621
}
}
}
}
3、TFRecords format details
TFRecord文件包含一系列记录。该文件只能按顺序读取。每个记录包含一个字节字符串,用于数据有效负载,加上数据长度,以及用于完整性检查的CRC32C(使用Castagnoli多项式的32位CRC)散列。每条记录以下列格式储存:
代码语言:javascript复制uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
这些记录被连接在一起产生文件。这里描述CRCs, CRC的掩码为:
代码语言:javascript复制masked_crc = ((crc >> 15) | (crc << 17)) 0xa282ead8ul
注意:没有使用tf.Example的要求TFRecord文件中的例子。Example只是将字典序列化为字节字符串的一种方法。文本行、编码图像数据或序列化张量(使用tf.io)。serialize_tensor, tf.io.parse_tensor转载)。看到特遣部队。io模块提供更多选项。
4、TFRecord files using tf.data
数据模块还提供了在TensorFlow中读写数据的工具。
1、Writing a TFRecord file
将数据放入数据集中最简单的方法是使用from_tensor_sections方法。应用于数组,它返回一个标量数据集:
代码语言:javascript复制tf.data.Dataset.from_tensor_slices(feature1)
代码语言:javascript复制<TensorSliceDataset shapes: (), types: tf.int64>
应用于数组的元组,它返回一个元组数据集:
代码语言:javascript复制features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
代码语言:javascript复制<TensorSliceDataset shapes: ((), (), (), ()), types: (tf.bool, tf.int64, tf.string, tf.float64)>
代码语言:javascript复制# Use `take(1)` to only pull one example from the dataset.
for f0,f1,f2,f3 in features_dataset.take(1):
print(f0)
print(f1)
print(f2)
print(f3)
代码语言:javascript复制tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(b'chicken', shape=(), dtype=string)
tf.Tensor(2.250539710963092, shape=(), dtype=float64)
使用tf.data.Dataset。方法将函数应用于数据集的每个元素。所映射的函数必须在张量流图模式中操作—它必须操作并返回tf.张量。一个非张量函数,比如create_example,可以用tf封装。py_function使其兼容。使用tf。py_function需要指定否则不可用的形状和类型信息:
代码语言:javascript复制def tf_serialize_example(f0,f1,f2,f3):
tf_string = tf.py_function(
serialize_example,
(f0,f1,f2,f3), # pass these args to the above function.
tf.string) # the return type is `tf.string`.
return tf.reshape(tf_string, ()) # The result is a scalar
代码语言:javascript复制tf_serialize_example(f0,f1,f2,f3)
代码语言:javascript复制<tf.Tensor: id=30, shape=(), dtype=string, numpy=b'nUnx11nx08feature0x12x05x1ax03nx01x01nx11nx08feature1x12x05x1ax03nx01x02nx17nx08feature2x12x0bntnx07chickennx14nx08feature3x12x08x12x06nx04xd8x08x10@'>
将此函数应用于数据集中的每个元素:
代码语言:javascript复制serialized_features_dataset = features_dataset.map(tf_serialize_example)
serialized_features_dataset
代码语言:javascript复制<MapDataset shapes: (), types: tf.string>
代码语言:javascript复制def generator():
for features in features_dataset:
yield serialize_example(*features)
代码语言:javascript复制serialized_features_dataset = tf.data.Dataset.from_generator(
generator, output_types=tf.string, output_shapes=())
代码语言:javascript复制serialized_features_dataset
代码语言:javascript复制<DatasetV1Adapter shapes: (), types: tf.string>
并写入TFRecord文件:
代码语言:javascript复制filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)
2、Reading a TFRecord file
您还可以使用tf.data读取TFRecord文件。TFRecordDataset类。有关使用tf使用TFRecord文件的更多信息。数据可以在这里找到。使用TFRecordDatasets对于标准化输入数据和优化性能非常有用。
代码语言:javascript复制filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
代码语言:javascript复制<TFRecordDatasetV2 shapes: (), types: tf.string>
此时,数据集包含序列化的tf.train。消息示例。当对其进行迭代时,将返回这些标量字符串张量。使用.take方法只显示前10条记录。
注意:遍历tf.data.Dataset只在启用紧急执行时工作。
代码语言:javascript复制for raw_record in raw_dataset.take(10):
print(repr(raw_record))
代码语言:javascript复制<tf.Tensor: id=50093, shape=(), dtype=string, numpy=b'nUnx11nx08feature0x12x05x1ax03nx01x01nx11nx08feature1x12x05x1ax03nx01x02nx17nx08feature2x12x0bntnx07chickennx14nx08feature3x12x08x12x06nx04xd8x08x10@'>
<tf.Tensor: id=50094, shape=(), dtype=string, numpy=b'nUnx11nx08feature0x12x05x1ax03nx01x01nx11nx08feature1x12x05x1ax03nx01x02nx17nx08feature2x12x0bntnx07chickennx14nx08feature3x12x08x12x06nx04x9bUIxbf'>
<tf.Tensor: id=50095, shape=(), dtype=string, numpy=b'nQnx11nx08feature0x12x05x1ax03nx01x00nx11nx08feature1x12x05x1ax03nx01x00nx13nx08feature2x12x07nx05nx03catnx14nx08feature3x12x08x12x06nx04x9ex9a$>'>
<tf.Tensor: id=50096, shape=(), dtype=string, numpy=b'nSnx11nx08feature0x12x05x1ax03nx01x01nx11nx08feature1x12x05x1ax03nx01x03nx15nx08feature2x12tnx07nx05horsenx14nx08feature3x12x08x12x06nx04xccxc4x82xbf'>
<tf.Tensor: id=50097, shape=(), dtype=string, numpy=b'nSnx11nx08feature0x12x05x1ax03nx01x01nx11nx08feature1x12x05x1ax03nx01x03nx15nx08feature2x12tnx07nx05horsenx14nx08feature3x12x08x12x06nx04k~x1dxc0'>
<tf.Tensor: id=50098, shape=(), dtype=string, numpy=b'nUnx11nx08feature0x12x05x1ax03nx01x00nx11nx08feature1x12x05x1ax03nx01x02nx17nx08feature2x12x0bntnx07chickennx14nx08feature3x12x08x12x06nx04`xacxab?'>
<tf.Tensor: id=50099, shape=(), dtype=string, numpy=b'nRnx14nx08feature2x12x08nx06nx04goatnx14nx08feature3x12x08x12x06nx04xdcx1b3>nx11nx08feature0x12x05x1ax03nx01x01nx11nx08feature1x12x05x1ax03nx01x04'>
<tf.Tensor: id=50100, shape=(), dtype=string, numpy=b'nRnx11nx08feature0x12x05x1ax03nx01x01nx11nx08feature1x12x05x1ax03nx01x04nx14nx08feature2x12x08nx06nx04goatnx14nx08feature3x12x08x12x06nx04xb5x96x1b?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'nQnx11nx08feature0x12x05x1ax03nx01x00nx11nx08feature1x12x05x1ax03nx01x00nx13nx08feature2x12x07nx05nx03catnx14nx08feature3x12x08x12x06nx047x11<xbf'>
<tf.Tensor: id=50102, shape=(), dtype=string, numpy=b'nSnx15nx08feature2x12tnx07nx05horsenx14nx08feature3x12x08x12x06nx04xecx96xb1=nx11nx08feature0x12x05x1ax03nx01x00nx11nx08feature1x12x05x1ax03nx01x03'>
可以使用下面的函数解析这些张量。注意feature_description在这里是必要的,因为数据集使用图形执行,并且需要这个描述来构建它们的形状和类型签名:
代码语言:javascript复制# Create a description of the features.
feature_description = {
'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}
def _parse_function(example_proto):
# Parse the input `tf.Example` proto using the dictionary above.
return tf.io.parse_single_example(example_proto, feature_description)
另外,使用tf。一次性解析整个批处理的解析示例。使用tf.data.Dataset将此函数应用于数据集中的每个项tf.data.Dataset.map的方法:
代码语言:javascript复制parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
代码语言:javascript复制<MapDataset shapes: {feature1: (), feature3: (), feature2: (), feature0: ()}, types: {feature1: tf.int64, feature3: tf.float32, feature2: tf.string, feature0: tf.int64}>
使用快速执行来显示数据集中的观察结果。这个数据集中有10,000个观察值,但是您将只显示前10个。数据显示为特征字典。每个项目都是tf。张量,该张量的numpy元表示特征值:
代码语言:javascript复制for parsed_record in parsed_dataset.take(10):
print(repr(parsed_record))
代码语言:javascript复制{'feature1': <tf.Tensor: id=50135, shape=(), dtype=int64, numpy=2>, 'feature0': <tf.Tensor: id=50134, shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: id=50136, shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: id=50137, shape=(), dtype=float32, numpy=2.2505398>}
{'feature1': <tf.Tensor: id=50139, shape=(), dtype=int64, numpy=2>, 'feature0': <tf.Tensor: id=50138, shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: id=50140, shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: id=50141, shape=(), dtype=float32, numpy=-0.7864625>}
{'feature1': <tf.Tensor: id=50143, shape=(), dtype=int64, numpy=0>, 'feature0': <tf.Tensor: id=50142, shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: id=50144, shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: id=50145, shape=(), dtype=float32, numpy=0.16074607>}
{'feature1': <tf.Tensor: id=50147, shape=(), dtype=int64, numpy=3>, 'feature0': <tf.Tensor: id=50146, shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: id=50148, shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: id=50149, shape=(), dtype=float32, numpy=-1.0216308>}
{'feature1': <tf.Tensor: id=50151, shape=(), dtype=int64, numpy=3>, 'feature0': <tf.Tensor: id=50150, shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: id=50152, shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: id=50153, shape=(), dtype=float32, numpy=-2.460841>}
{'feature1': <tf.Tensor: id=50155, shape=(), dtype=int64, numpy=2>, 'feature0': <tf.Tensor: id=50154, shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: id=50156, shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: id=50157, shape=(), dtype=float32, numpy=1.341198>}
{'feature1': <tf.Tensor: id=50159, shape=(), dtype=int64, numpy=4>, 'feature0': <tf.Tensor: id=50158, shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: id=50160, shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: id=50161, shape=(), dtype=float32, numpy=0.17491096>}
{'feature1': <tf.Tensor: id=50163, shape=(), dtype=int64, numpy=4>, 'feature0': <tf.Tensor: id=50162, shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: id=50164, shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: id=50165, shape=(), dtype=float32, numpy=0.60776836>}
{'feature1': <tf.Tensor: id=50167, shape=(), dtype=int64, numpy=0>, 'feature0': <tf.Tensor: id=50166, shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: id=50168, shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: id=50169, shape=(), dtype=float32, numpy=-0.7346377>}
{'feature1': <tf.Tensor: id=50171, shape=(), dtype=int64, numpy=3>, 'feature0': <tf.Tensor: id=50170, shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: id=50172, shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: id=50173, shape=(), dtype=float32, numpy=0.08671364>}
这里的tf.parse_example函数解压缩。将示例字段转换为标准张量。
5、TFRecord files in Python
tf.io模块还包含用于读取和写入TFRecord文件的纯python函数。
1、Writing a TFRecord file
接下来,将10,000个观察结果写入test.tfrecord文件。每次观测都转换为tf。示例消息,然后写入文件。然后可以验证文件测试。已创建tfrecord:
代码语言:javascript复制# Write the `tf.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
for i in range(n_observations):
example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
writer.write(example)
代码语言:javascript复制!du -sh {filename}
代码语言:javascript复制984K test.tfrecord
2、Reading a TFRecord file
这些序列化的张量可以很容易地使用tf.train.Example.ParseFromString解析:
代码语言:javascript复制filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
代码语言:javascript复制<TFRecordDatasetV2 shapes: (), types: tf.string>
代码语言:javascript复制for raw_record in raw_dataset.take(1):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
print(example)
代码语言:javascript复制features {
feature {
key: "feature0"
value {
int64_list {
value: 1
}
}
}
feature {
key: "feature1"
value {
int64_list {
value: 2
}
}
}
feature {
key: "feature2"
value {
bytes_list {
value: "chicken"
}
}
}
feature {
key: "feature3"
value {
float_list {
value: 2.250539779663086
}
}
}
}
6、Walkthrough: Reading and writing image data
这是一个如何使用TFRecords读写图像数据的示例。这样做的目的是显示如何端到端输入数据(在本例中是图像)并将数据写入TFRecord文件,然后读取文件并显示图像。例如,如果希望在同一个输入数据集上使用多个模型,这将非常有用。它可以被预处理成TFRecords格式,而不是存储原始的图像数据,并且可以用于所有进一步的处理和建模。首先,让我们下载这张猫在雪地里的照片和这张正在建设中的纽约威廉斯堡大桥的照片。
1、Fetch the images
代码语言:javascript复制cat_in_snow = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
代码语言:javascript复制Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg
16384/15477 [===============================] - 0s 0us/step
代码语言:javascript复制display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))
代码语言:javascript复制display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))
2、Write the TFRecord file
与前面一样,将特性编码为与tf.Example兼容的类型。它存储原始图像字符串特性,以及高度、宽度、深度和任意标签特性。后者用于在编写文件时区分cat图像和桥接图像。cat图像使用0,桥梁图像使用1:
代码语言:javascript复制image_labels = {
cat_in_snow : 0,
williamsburg_bridge : 1,
}
代码语言:javascript复制# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()
label = image_labels[cat_in_snow]
# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
image_shape = tf.image.decode_jpeg(image_string).shape
feature = {
'height': _int64_feature(image_shape[0]),
'width': _int64_feature(image_shape[1]),
'depth': _int64_feature(image_shape[2]),
'label': _int64_feature(label),
'image_raw': _bytes_feature(image_string),
}
return tf.train.Example(features=tf.train.Features(feature=feature))
for line in str(image_example(image_string, label)).split('n')[:15]:
print(line)
print('...')
代码语言:javascript复制features {
feature {
key: "depth"
value {
int64_list {
value: 3
}
}
}
feature {
key: "height"
value {
int64_list {
value: 213
}
...
注意,所有特性现在都存储在tf。Example消息示例。接下来,对上面的代码进行函数化,并将示例消息写入名为images.tfrecords的文件中:
代码语言:javascript复制# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
for filename, label in image_labels.items():
image_string = open(filename, 'rb').read()
tf_example = image_example(image_string, label)
writer.write(tf_example.SerializeToString())
代码语言:javascript复制!du -sh {record_file}
代码语言:javascript复制36K images.tfrecords
3、Read the TFRecord file
现在有了文件图像。tfrecords——现在可以遍历其中的记录来读取所写的内容。假设在本例中,您将只复制图像,那么您需要的惟一特性就是原始图像字符串。使用上面描述的getter方法提取它,即example.features.feature['image_raw'].bytes_list.value[0]。你还可以使用标签来确定哪条记录是猫,哪条记录是桥:
代码语言:javascript复制raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')
# Create a dictionary describing the features.
image_feature_description = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'label': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
def _parse_image_function(example_proto):
# Parse the input tf.Example proto using the dictionary above.
return tf.io.parse_single_example(example_proto, image_feature_description)
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
代码语言:javascript复制<MapDataset shapes: {label: (), width: (), image_raw: (), height: (), depth: ()}, types: {label: tf.int64, width: tf.int64, image_raw: tf.string, height: tf.int64, depth: tf.int64}>
Recover the images from the TFRecord file:
代码语言:javascript复制for image_features in parsed_image_dataset:
image_raw = image_features['image_raw'].numpy()
display.display(display.Image(data=image_raw))