tf.py_func()

2022-09-04 21:39:19 浏览数 (1)

tensorflow中所有的tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的,不能对其进行判值操作,如if ... else...等,在实际问题中,我们可能需要将一个tensor转换成numpy array 然后进行一些 np的运算,然后返回tensor这样可以加强tensorflow的灵活性。在目标检测算法Faster R-CNN中,需要计算各种ground truth,接口比较复杂。因此,使用tf.py_func是一个比较好的途径。对于tf.py_func的使用,可以参见计算RPN的ground truth和计算proposals的ground truth时的使用方法。可以看到,都是将tensor转化成numpy array,再使用np.操作完成复杂运算。封装一个python函数并将其用作TensorFlow op。

代码语言:javascript复制
tf.py_func(
    func,
    inp,
    Tout,
    stateful=True,
    name=None
)

参数:

  • func: 一个Python函数,它接受ndarray对象作为参数并返回一个ndarray对象列表(或单个ndarray)。这个函数必须接受inp中有多少张量就有多少个参数,这些参数类型将匹配相应的tf.inp中的tf.tensor。返回的ndarrays必须匹配已定义的Tout的数字和类型。重要提示: func的输入和输出numpy ndarrays不能保证是副本。在某些情况下,它们的底层内存将与相应的TensorFlow张量共享。在没有显式(np.)复制的python数据结构中,就地修改或存储func输入或返回值可能会产生不确定的结果。
  • inp: 一个张量对象的列表。
  • Tout: tensorflow数据类型的列表或元组,如果只有一个tensorflow数据类型,则使用单个tensorflow数据类型,指示func返回什么。
  • stateful: (布尔)。如果为真,则应该认为该函数是有状态的。如果一个函数是无状态的,当给定相同的输入时,它将返回相同的输出,并且没有可观察到的副作用。诸如公共子表达式消除之类的优化只在无状态操作上执行。
  • name: 操作的名称(可选)。

返回值:

  • func计算的张量或单个张量的列表。

例:

代码语言:javascript复制
def my_func(array1,array2):
    return array1   array2, array1 - array2

if __name__ =='__main__':
    array1 = np.array([[1, 2], [3, 4]])
    array2 = np.array([[1, 2], [3, 4]])

    a1 = tf.placeholder(tf.float32,[2,2],name = 'array1')
    a2 = tf.placeholder(tf.float32,[2,2],name = 'array2')
    y1,y2 = tf.py_func(my_func,[a1,a2],[tf.float32, tf.float32])

    with tf.Session() as sess:
        y1_,y2_ = sess.run([y1,y2],feed_dict={a1:array1,a2:array2})
        print(y1_)
        print('*'*10)
        print(y2_)


Output:
-----------
[[2. 4.]
[6. 8.]]
**********
[[0. 0.]
[0. 0.]]
-----------

直接用array的方式操作:

代码语言:javascript复制
import tensorflow as tf
import numpy as np

def my_func(array1,array2):
    return array1   array2, array1 - array2

with tf.Session() as sess:
  array1 = np.array([[1, 2], [3, 4]])
  array2 = np.array([[1, 2], [3, 4]])
  y1, y2 = my_func(array1, array2)
  print(y1)
  print('*' * 10)
  print(y2)


Output:
-----------
[[2 4]
 [6 8]]
**********
[[0 0]
 [0 0]]
-----------

原链接:https://tensorflow.google.cn/api_docs/python/tf/py_func?hl=en

0 人点赞