【技术分享】深入了解tensorflow模型存储格式

2020-03-18 15:06:43 浏览数 (1)

本文原作者:岳夕涵,经授权后发布。

导语

做模型的同学基本都会使用tensorflow,不知道大家是否会像我一样对tensorflow的模型存储感到疑惑:各种模型保存的方法、保存出的模型文件名称和结构还不一样、加载模型的时候有的需要重新定义一遍计算图而有的不需要、有的格式tfserving能用有的不能用。这篇文章会带大家了解每个模型文件分别包含什么内容、计算图是以什么样的形式保存在文件中的。

以下讨论的api都是基于tensorflow1.15版本。

1 保存模型

先定义一段模型计算,然后用两种不同的格式进行保存。 定义计算如下:

代码语言:javascript复制
import tensorflow as tf
sess = tf.Session()
x = tf.placeholder(tf.float32, [None, 100], name='x')
with tf.variable_scope('layer1') as scope:
    W = tf.get_variable(name='W', shape=[100, 100], dtype=tf.float32)
    b = tf.constant(list(range(100)), dtype=tf.float32, name='b')
    y = tf.matmul(x, W, name='y')   b

with tf.variable_scope('layer2') as scope:
    W = tf.get_variable(name='W', shape=[100, 100], dtype=tf.float32)
    b = tf.constant([0] * 100, dtype=tf.float32, name='b')
    z = tf.matmul(y, W, name='z')   b

sess.run(tf.global_variables_initializer())

1.1 用tf.train.Saver保存

代码语言:javascript复制
tf.train.Saver().save(sess, 'example_0/model')

这是最常用到也最方便的保存方式,得到如下模型文件:

代码语言:javascript复制
example_0
├── checkpoint
├── model.data-00000-of-00001
├── model.index
└── model.meta

checkpoint文件只有两行文本内容

代码语言:javascript复制
model_checkpoint_path: "model"
all_model_checkpoint_paths: "model"

这个文件的意义很好理解,就是把model这个我们起的名字保存起来。而且上面的save方法可以传入一个叫global_step的参数,每次save的时候都会生成不同的文件,checkpoint会指向最后保存的文件。

1.2 用tf.saved_model保存

代码语言:javascript复制
from tensorflow.saved_model.utils import build_tensor_info
from tensorflow.saved_model.signature_def_utils import build_signature_def

builder = tf.saved_model.builder.SavedModelBuilder('example_1')
signature_inputs = {"x": build_tensor_info(x)}
signature_outpts = {
    "y": build_tensor_info(y),
    "z": build_tensor_info(z)
}
signature_def = build_signature_def(signature_inputs, signature_outpts, "default")
builder.add_meta_graph_and_variables(sess, ["serve"], signature_def_map={"default": signature_def})
builder.save(as_text=False)

其中signature_def_map参数也可以不填。得到如下模型文件:

代码语言:javascript复制
example_1
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00001
    └── variables.index

tensorflow还支持其他函数来进行模型保存,如采用Estimator的export_savedmodel方法,但保存下来的模型格式总是包含在上面两种格式之内。

2 保存模型参数的文件

两种模型存储格式中分别都包含了*.data-00000-of-00001和*.index的文件。

这两个文件都是二进制文件,*.index是参数索引文件,保存着参数的基本信息,但不保存参数的值;*.data是参数值文件。

2.1 文件内容

*.index文件是采用特殊的拼接格式将多个protobuf拼接得到的。第一个pb是BundleHeaderProto,记录了一些基本信息;而后所有的pb都是BundleEntryProto,记录着每一个模型参数

代码语言:javascript复制
message BundleHeaderProto {
  int32 num_shards = 1; # 暂时看不出用途
  enum Endianness {
    LITTLE = 0;
    BIG = 1;
  }
  Endianness endianness = 2; # 大端序还是小端序
  VersionDef version = 3; # 版本
}
message BundleEntryProto {
  DataType dtype = 1; # 类型
  TensorShapeProto shape = 2; # 形状
  int32 shard_id = 3; # 暂时看不出用途
  int64 offset = 4; # 在数据文件中的偏移地址
  int64 size = 5; # 在数据文件中的大小
  fixed32 crc32c = 6; # 数据文件内容校验码
  repeated TensorSliceProto slices = 7; # 暂时看不出用途
}

解析上面模型文件中的index文件得到如下打印:

代码语言:javascript复制
num_shards: 1 version { producer: 1 }
layer1/W
dtype: DT_FLOAT shape { dim { size: 100 } dim { size: 100 } } size: 40000 crc32c: 2796859335
layer2/W
dtype: DT_FLOAT shape { dim { size: 100 } dim { size: 100 } } offset: 40000 size: 40000 crc32c: 2385875586

从上面的打印结果指出在我机器上的*.data文件是以小端序保存着layer1/W和layer2/W的值:layer1/W是从偏移量0开始的40000个字节,layer2/W是从偏移量40000开始的40000个字节。

由此可见,model.index和model.data只保留了参数的信息。上面模型里的x属于输入、y和z属于中间量,两个b是常数,它们都不是参数,因此在index文件和data文件中不出现。

2.2 具体格式

本节深入到index文件的拼接格式细节,可以略过

index文件其实存的是键值对,键就是变量名称,值就是protobuf。但index文件格式极其复杂,给键加了两层索引、还有各种压缩和校验的trick。

其实tf模型的参数数量并不多,并没有对索引性能和文件大小那么强的要求。个人感觉用简单的格式就能应付了,比如整个文件就是一个protobuf,这样扩展性更好且更易理解。估计是tf团队没多想就直接拿谷歌内部的实现代码直接用了。

最外层

文件的最后8个字节是一个魔数:0xdb4775248b80fb57ull。

紧接着魔数的倒数20个字节会解析成两个BlockHandle,BlockHandle存储着64位的offset和size。offset和size的在文件中的编码方式比较神奇,首先看字节的最高位是1还是0,如果是1那表明后面字节还有值,如果是0那表明后面字节没值了,然后把这些字节每个的后7为拼接起来。举例:

代码语言:javascript复制
0b11101110 0b10000100 0b00000120 0b00101000
前3个字节组成一个数0b1101110,0000100,0000120,后一个字节组成一个数0b0101000。

meta_index_handle似乎没用,index_handle指向了数据段中第一层索引对应的Block块。

第一层Block块

Block块的长度由BlockHandle的size决定,在size之后的5个字节分别表示压缩格式和校验码,这也很好理解。

Block的数据区紧密排布着一个个的数据块,每个数据块先按照字节高位1或0的方式解析出3个数shared、non_shared、value_length。shared表示key可与前一变量共享的位数,non_shared就是key的长度,比如"layer1/W"和"layer2/W"可以共享前5位的字符,所以在存储"layer2/W"的时候就只需设shared为5,非共享的部分就是"2/W"存在key处,non_shared值为3。value_length是value的长度,在第一层block块,value又会被解析成BlockHandle,用于指向第二层索引对应的block。

数据区的数据块的key是按照字母表顺序依次排列的,索引区中的每一个索引都指向某个数据块的起始位置且严格递增,因此可以借助索引进行二分查找。

第二层Block块

两层Block块类似于B 树的多层索引结构,第二层Block块跟第一层结构完全一样,唯一的区别是第二层Block块的value是定义参数的protobuf。

3 保存计算图的文件

文章第一章的两种保存格式区别就在于model.meta和saved_model.pb。这两者都保存了模型的计算图,而且这两者都是protobuf文件。

3.1 区别

先看一下二者的区别,saved_model.pb的protobuf定义是SavedModel,如下所示

代码语言:javascript复制
message SavedModel {
  int64 saved_model_schema_version = 1; # 目前好像都设成了1
  repeated MetaGraphDef meta_graphs = 2; # 计算图
}

而model.meta对应的protobuf定义就是上面定义中的MetaGraphDef。可见其实区别是很小的。

还有在第一章,发现第二种保存方式可以添加signature_def,而第一种方式不能添加。

代码语言:javascript复制
builder.add_meta_graph_and_variables(sess, ["serve"], signature_def_map={"default": signature_def})

signature_def的作用是指明计算图中的输入和输出,是专门提供给tfserving用的,saved_model.pb就是专门用来给tfserving加载的格式。所以即便signature_def是定义在第一种保存方式也能处理的MetaGraphDef中的,第一种保存方式也没有提供添加signature_def定义的接口。去看tensorflow的提交历史也能发现SavedModel和signature_def是在同一次提交中加入到tensorflow项目中的。

3.2 计算图的存储

本节将深入计算图是如何在protobuf中存储的。

首先看一下MetaGraphDef的定义,其中图就存储在graph_def中。

代码语言:javascript复制
message MetaGraphDef {
  MetaInfoDef meta_info_def = 1; // 版本、算子等信息
  GraphDef graph_def = 2; // 图定义
  SaverDef saver_def = 3; // 指定存取相关的节点和参数
  map<string, CollectionDef> collection_def = 4; // 定义了可训练的节点集合、要保存的节点集合
  map<string, SignatureDef> signature_def = 5; // 输入输出定义
  repeated AssetFileDef asset_file_def = 6; // 不太明确作用
  SavedObjectGraph object_graph_def = 7; // 不太明确作用
}
message GraphDef {
  repeated NodeDef node = 1; // 节点定义
  VersionDef versions = 4; // 版本信息
  int32 version = 3 [deprecated = true];
  FunctionDefLibrary library = 2; // 自定义函数
};
输入节点的定义

输入节点包含了name、op这两个基本信息,同时在attr中还包含了数据类型和形状,很好理解。

代码语言:javascript复制
node {
 name: "x"
 op: "Placeholder" 
 attr { key: "_output_shapes" value { list { shape { dim { size: -1 } dim { size: 100 } } } } }
 attr { key: "dtype" value { type: DT_FLOAT } }
 attr { key: "shape" value { shape { dim { size: -1 } dim { size: 100 } } } }
}
常量节点的定义

常量节点除了输入节点的特征之外,还额外多了一个value特征,在其中存储了常量节点的值

代码语言:javascript复制
node {
 name: "layer1/b" 
 op: "Const" 
 attr { key: "_output_shapes" value { list { shape { dim { size: 100 } } } } } 
 attr { key: "dtype" value { type: DT_FLOAT } } 
 attr { key: "value" value { tensor { dtype: DT_FLOAT tensor_shape { dim { size: 100 } } tensor_content: "000000000000200?000000@0000@...." } } }
}
变量节点的定义

变量节点跟上面二者的区分不大

代码语言:javascript复制
node { 
 name: "layer1/W" 
 op: "VariableV2" 
 attr { key: "_class" value { list { s: "loc:@layer1/W" } } } 
 attr { key: "_output_shapes" value { list { shape { dim { size: 100 } dim { size: 100 } } } } } 
 attr { key: "container" value { s: "" } } 
 attr { key: "dtype" value { type: DT_FLOAT } } 
 attr { key: "shape" value { shape { dim { size: 100 } dim { size: 100 } } } } 
 attr { key: "shared_name" value { s: "" } } 
}
中间节点的定义

layer1/y就是一个中间节点,在计算图中我们定义它是由x和layer1/W做矩阵乘法得到的。

代码语言:javascript复制
 node {
  name: "layer1/y" 
  op: "MatMul" 
  input: "x" 
  input: "layer1/W/read" 
  attr { key: "T" value { type: DT_FLOAT } } 
  attr { key: "_output_shapes" value { list { shape { dim { size: -1 } dim { size: 100 } } } } } 
  attr { key: "transpose_a" value { b: false } } 
  attr { key: "transpose_b" value { b: false } } 
} 

node { 
 name: "layer1/W/read" 
 op: "Identity" 
 input: "layer1/W" 
 attr { key: "T" value { type: DT_FLOAT } } 
 attr { key: "_class" value { list { s: "loc:@layer1/W" } } } 
 attr { key: "_output_shapes" value { list { shape { dim { size: 100 } dim { size: 100 } } } } } 
}

从上面就可以看出layer1/y节点有两个input,x我们是知道的,那么layer1/W/read又是什么呢?进一步看,layer1/W/read也是一个中间节点,其输入正是layer1/W,op是Identity同等变换。

同时,tf.matmul方法有两个参数transpose_a和transpose_b,默认为False,同样体现在了节点的定义中。

计算图正是通过这种节点定义的方式,用input属性将节点关联起来,从而形成了从输入到输出的有向无环图。

回过头去看x的定义,shape很好理解,就是x的形状是两维,第一维维度待定,第二维为100维。但有一个疑惑的地方,_output_shapes为什么是一个list。其实可以看到每个节点的输出都是一个list,虽然大部分节点的输出的list都只有一个元素,但也有少量节点会输出两个以上的元素,比如分支操作cond/Switch。当其作为输入出现在其他节点input中时,如果input中只有名称,那默认就是list中的第一个元素。如果显式的在input中指定,比如"cond/Switch:1",那么就是cond/Switch输出中的第二个元素。

基本的节点能组成复杂的操作,很多tensorflow在python代码中的函数,如tf.nn.moments,其实是由很多如加法、乘法、求平均等op节点构成的,在pb中并不会出现一个op是moments的节点。所以tensorflow虽然提供了种类繁多的运算,但其实在底层实现的op节点并不算多。

特殊节点

甚至连模型参数的加载也是通过节点定义的。

代码语言:javascript复制
node {
 name: "save/restore_all"
 op: "NoOp"
 input: "^save/Assign"
 input: "^save/Assign_1"
}

save/restore_all的input中,节点名称出现了奇怪的开头"^",该符号叫做控制符,表示如果要执行save/restore_all,则要在之前先执行save/Assign,save/Assign_1。save/restore_all节点本身没有操作,所以起的作用是唤起save/Assign和save/Assign_1的执行。

代码语言:javascript复制
node {
 name: "save/Assign_1" 
 op: "Assign" 
 input: "layer2/W" 
 input: "save/RestoreV2:1" 
 attr { key: "T" value { type: DT_FLOAT } } 
 attr { key: "_class" value { list { s: "loc:@layer2/W" } } } 
 attr { key: "_output_shapes" value { list { shape { dim { size: 100 } dim { size: 100 } } } } } 
 attr { key: "use_locking" value { b: true } } 
 attr { key: "validate_shape" value { b: true } } }
} 

几乎所有的节点都不会影响其input节点,save/Assign_1是少部分例外,其操作会将save/RestoreV2:1的值赋给layer2/W。正因为这一点,所以会对并发的其他线程造成影响,需要加锁。

代码语言:javascript复制
node {
 name: "save/RestoreV2" 
 op: "RestoreV2" 
 input: "save/Const" 
 input: "save/RestoreV2/tensor_names" 
 input: "save/RestoreV2/shape_and_slices" 
 device: "/device:CPU:0" 
 attr { key: "_output_shapes" value { list { shape { unknown_rank: true } shape { unknown_rank: true } } } 
} 
node { 
 name: "save/Const" 
 op: "PlaceholderWithDefault" 
 input: "save/filename" 
 attr { key: "_output_shapes" value { list { shape { } } } } 
 attr { key: "dtype" value { type: DT_STRING } } 
 attr { key: "shape" value { shape { } } } 
}
node {
 name: "save/SaveV2/tensor_names" 
 op: "Const" 
 attr { key: "_output_shapes" value { list { shape { dim { size: 2 } } } } } 
 attr { key: "dtype" value { type: DT_STRING } } 
 attr { key: "value" value { tensor { dtype: DT_STRING tensor_shape { dim { size: 2 } } string_val: "layer1/W" string_val: "layer2/W" } } } 
}

save/RestoreV2是真正的去读取index文件和data文件的节点,文件名由save/Const提供,这也是一个输入节点,SaveV2/tensor_names提供了需要加载的变量的名称。

3.3 指定输入输出

在example_1/saved_model.pb能够看到signature_def,在example_0/model.meta中找不到这个定义。

代码语言:javascript复制
signature_def {
 key: "default"
 value {
  inputs { key: "x" value { name: "x:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 100 } } } }
  outputs { key: "y" value { name: "layer1/add:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 100 } } } } 
  outputs { key: "z" value { name: "layer2/add:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 100 } } } } 
  method_name: "default" 
 }
}

有了上面的信息,tfserving加载模型的时候就知道了在收到的请求中获取x的赋值,将计算得到的y和z的放到回包中返还给用户。通过saved_mode提供的api我们其实可以任意指定某个节点作为输入或输出,比如指定inputs为空、outputs为layer2/W,这样tfserving加载模型之后就知道不需要请求中有任何赋值,将layer2/W的值放到回包中返还给用户。

0 人点赞