tensorflow中所有的tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的,不能对其进行判值操作,如if ... else...等,在实际问题中,我们可能需要将一个tensor转换成numpy array 然后进行一些 np的运算,然后返回tensor这样可以加强tensorflow的灵活性。在目标检测算法Faster R-CNN中,需要计算各种ground truth,接口比较复杂。因此,使用tf.py_func是一个比较好的途径。对于tf.py_func的使用,可以参见计算RPN的ground truth和计算proposals的ground truth时的使用方法。可以看到,都是将tensor转化成numpy array,再使用np.操作完成复杂运算。封装一个python函数并将其用作TensorFlow op。
代码语言:javascript复制tf.py_func(
func,
inp,
Tout,
stateful=True,
name=None
)
参数:
- func: 一个Python函数,它接受ndarray对象作为参数并返回一个ndarray对象列表(或单个ndarray)。这个函数必须接受inp中有多少张量就有多少个参数,这些参数类型将匹配相应的tf.inp中的tf.tensor。返回的ndarrays必须匹配已定义的Tout的数字和类型。重要提示: func的输入和输出numpy ndarrays不能保证是副本。在某些情况下,它们的底层内存将与相应的TensorFlow张量共享。在没有显式(np.)复制的python数据结构中,就地修改或存储func输入或返回值可能会产生不确定的结果。
- inp: 一个张量对象的列表。
- Tout: tensorflow数据类型的列表或元组,如果只有一个tensorflow数据类型,则使用单个tensorflow数据类型,指示func返回什么。
- stateful: (布尔)。如果为真,则应该认为该函数是有状态的。如果一个函数是无状态的,当给定相同的输入时,它将返回相同的输出,并且没有可观察到的副作用。诸如公共子表达式消除之类的优化只在无状态操作上执行。
- name: 操作的名称(可选)。
返回值:
- func计算的张量或单个张量的列表。
例:
代码语言:javascript复制def my_func(array1,array2):
return array1 array2, array1 - array2
if __name__ =='__main__':
array1 = np.array([[1, 2], [3, 4]])
array2 = np.array([[1, 2], [3, 4]])
a1 = tf.placeholder(tf.float32,[2,2],name = 'array1')
a2 = tf.placeholder(tf.float32,[2,2],name = 'array2')
y1,y2 = tf.py_func(my_func,[a1,a2],[tf.float32, tf.float32])
with tf.Session() as sess:
y1_,y2_ = sess.run([y1,y2],feed_dict={a1:array1,a2:array2})
print(y1_)
print('*'*10)
print(y2_)
Output:
-----------
[[2. 4.]
[6. 8.]]
**********
[[0. 0.]
[0. 0.]]
-----------
直接用array的方式操作:
代码语言:javascript复制import tensorflow as tf
import numpy as np
def my_func(array1,array2):
return array1 array2, array1 - array2
with tf.Session() as sess:
array1 = np.array([[1, 2], [3, 4]])
array2 = np.array([[1, 2], [3, 4]])
y1, y2 = my_func(array1, array2)
print(y1)
print('*' * 10)
print(y2)
Output:
-----------
[[2 4]
[6 8]]
**********
[[0 0]
[0 0]]
-----------
原链接:https://tensorflow.google.cn/api_docs/python/tf/py_func?hl=en