代码语言:javascript复制
tf.py_func(
func,
inp,
Tout,
stateful=True,
name=None
)
封装一个python函数并将其用作TensorFlow op。给定一个python函数func,它以numpy数组作为参数并返回numpy数组作为输出,将这个函数包装为张量流图中的一个操作。下面的代码片段构造了一个简单的TensorFlow图,它调用np.sinh() NumPy函数作为图中的操作:
代码语言:javascript复制def my_func(x):
# x will be a numpy array with the contents of the placeholder below
return np.sinh(x)
input = tf.placeholder(tf.float32)
y = tf.py_func(my_func, [input], tf.float32)
注意:tf.py_func()操作有以下已知的限制:
- 函数体(即func)不会在GraphDef中序列化。因此,如果需要序列化模型并在不同的环境中恢复模型,则不应使用此函数。
- 该操作必须在与调用tf.py_func()的Python程序相同的地址空间中运行。如果使用分布式TensorFlow,则必须运行tf.train。服务器与调用tf.py_func()的程序处于相同的进程中,您必须将创建的操作固定到该服务器中的设备上(例如,使用tf.device():)。
参数:
- func: 一个Python函数,它接受ndarray对象作为参数并返回一个ndarray对象列表(或单个ndarray)。这个函数必须接受inp中有多少张量就有多少个参数,这些参数类型将匹配相应的tf。inp中的张量对象。返回的ndarrays必须匹配已定义的Tout的数字和类型。重要提示:func的输入和输出numpy ndarrays不能保证是副本。在某些情况下,它们的底层内存将与相应的TensorFlow张量共享。就地修改或在py中存储func输入或返回值。
- inp: 一个张量对象的列表。
- Tout: tensorflow数据类型的列表或元组,如果只有一个tensorflow数据类型,则使用单个tensorflow数据类型,指示func返回什么。
stateful
: (布尔)。如果为真,则应该认为该函数是有状态的。如果一个函数是无状态的,当给定相同的输入时,它将返回相同的输出,并且没有可观察到的副作用。诸如公共子表达式消除之类的优化只在无状态操作上执行。- name: 操作的名称(可选)。
返回值:
- func计算的张量或单个张量的列表。
原链接: https://tensorflow.google.cn/versions/r1.10/api_docs/python/tf/py_func?hl=en