tensorflow2.0的函数签名与图结构(推荐)

2020-11-02 16:47:19 浏览数 (1)

input_signature的好处:

1.可以限定函数的输入类型,以防止调用函数时调错,

2.一个函数有了input_signature之后,在tensorflow里边才可以保存成savedmodel。在保存成savedmodel的过程中,需要使用get_concrete_function函数把一个tf.function标注的普通的python函数变成带有图定义的函数。

下面的代码具体体现了input_signature可以限定函数的输入类型这一作用。

代码语言:javascript复制
@tf.function(input_signature=[tf.TensorSpec([None], tf.int32, name='x')])
def cube(z): #实现输入的立方
 return tf.pow(z, 3)
try:
 print(cube(tf.constant([1., 2., 3.])))
except ValueError as ex:
 print(ex)
print(cube(tf.constant([1, 2, 3])))

输出:

Python inputs incompatible with input_signature: inputs: ( tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=’x’)) tf.Tensor([ 1 8 27], shape=(3,), dtype=int32)

get_concrete_function的使用

note:首先说明,下面介绍的函数在模型构建、模型训练的过程中不会用到,下面介绍的函数主要用在两个地方:1、如何保存模型 2、保存好模型后,如何载入进来。

可以给 由@tf.function标注的普通的python函数,给它加上input_signature, 从而让这个python函数变成一个可以保存的tensorflow图结构(SavedModel)

举例说明函数的用法:

代码语言:javascript复制
@tf.function(input_signature=[tf.TensorSpec([None], tf.int32, name='x')])
def cube(z):
 return tf.pow(z, 3)
 
try:
 print(cube(tf.constant([1., 2., 3.])))
except ValueError as ex:
 print(ex)
 
print(cube(tf.constant([1, 2, 3])))
 
# @tf.function py func -  tf graph
# get_concrete_function -  add input signature -  SavedModel
 
cube_func_int32 = cube.get_concrete_function(
 tf.TensorSpec([None], tf.int32)) #tensorflow的类型
print(cube_func_int32)

输出:

<tensorflow.python.eager.function.ConcreteFunction object at 0x00000240E29695C0

从输出结果可以看到:调用get_concrete_function函数后,输出的是一个ConcreteFunction对象

代码语言:javascript复制
#看用新参数获得的对象与原来的对象是否一样
print(cube_func_int32 is cube.get_concrete_function(
 tf.TensorSpec([5], tf.int32))) #输入大小为5
print(cube_func_int32 is cube.get_concrete_function(
 tf.constant([1, 2, 3]))) #传具体数据

输出:

True True

cube_func_int32.graph #图定义

输出:

代码语言:javascript复制
[<tf.Operation 'x' type=Placeholder ,
 <tf.Operation 'Pow/y' type=Const ,
 <tf.Operation 'Pow' type=Pow ,
 <tf.Operation 'Identity' type=Identity ]
代码语言:javascript复制
pow_op = cube_func_int32.graph.get_operations()[2]
print(pow_op)

输出:

name: “Pow” op: “Pow” input: “x” input: “Pow/y” attr { key: “T” value { type: DT_INT32 } }

代码语言:javascript复制
print(list(pow_op.inputs))
print(list(pow_op.outputs))

输出:

[<tf.Tensor ‘x:0’ shape=(None,) dtype=int32 , <tf.Tensor ‘Pow/y:0’ shape=() dtype=int32 ] [<tf.Tensor ‘Pow:0’ shape=(None,) dtype=int32 ]

cube_func_int32.graph.get_operation_by_name(“x”)

输出:

<tf.Operation ‘x’ type=Placeholder

cube_func_int32.graph.get_tensor_by_name(“x:0”) #默认加“:0”

<tf.Tensor ‘x:0’ shape=(None,) dtype=int32

cube_func_int32.graph.as_graph_def() #总名字,针对上面两个

代码语言:javascript复制
node {
 name: "x"
 op: "Placeholder"
 attr {
 key: "_user_specified_name"
 value {
 s: "x"
 }
 }
 attr {
 key: "dtype"
 value {
 type: DT_INT32
 }
 }
 attr {
 key: "shape"
 value {
 shape {
 dim {
  size: -1
 }
 }
 }
 }
}
node {
 name: "Pow/y"
 op: "Const"
 attr {
 key: "dtype"
 value {
 type: DT_INT32
 }
 }
 attr {
 key: "value"
 value {
 tensor {
 dtype: DT_INT32
 tensor_shape {
 }
 int_val: 3
 }
 }
 }
}
node {
 name: "Pow"
 op: "Pow"
 input: "x"
 input: "Pow/y"
 attr {
 key: "T"
 value {
 type: DT_INT32
 }
 }
}
node {
 name: "Identity"
 op: "Identity"
 input: "Pow"
 attr {
 key: "T"
 value {
 type: DT_INT32
 }
 }
}
versions {
 producer: 119
}

到此这篇关于tensorflow2.0的函数签名与图结构的文章就介绍到这了,更多相关tensorflow函数签名与图结构内容请搜索ZaLou.Cn以前的文章或继续浏览下面的相关文章希望大家以后多多支持ZaLou.Cn!

0 人点赞