tf.Session

2022-09-04 20:54:52 浏览数 (1)

一个运行TensorFlow操作的类。会话对象封装了执行操作对象和计算张量对象的环境。

例:

代码语言:javascript复制
# Build a graph.
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b

# Launch the graph in a session.
sess = tf.Session()

# Evaluate the tensor `c`.
print(sess.run(c))

会话可能拥有资源,比如tf.variable,tf.QueueBase, tf.ReaderBase。当不再需要这些资源时,释放它们是很重要的。为此,可以调用tf.Session。关闭会话上的方法,或将会话用作上下文管理器。下面两个例子是等价的:

代码语言:javascript复制
# Using the `close()` method.
sess = tf.Session()
sess.run(...)
sess.close()

# Using the context manager.
with tf.Session() as sess:
  sess.run(...)

ConfigProto协议缓冲区公开会话的各种配置选项。例如,要创建一个使用设备放置软约束的会话,并记录结果的放置决策,创建一个会话如下:

代码语言:javascript复制
# Launch the graph in a session that allows soft device placement and
# logs the placement decisions.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=True))

性质:

graph:

在此会话中启动的图表。

graph_def

底层张量流图的可序列化版本。

返回值:

  • graph_pb2.GraphDef proto包含底层TensorFlow图中所有操作的节点。

sess_str

Methods

__init__

代码语言:javascript复制
__init__(
    target='',
    graph=None,
    config=None
)

创建一个新的TensorFlow会话。如果在构造会话时没有指定图形参数,则会话中将启动缺省图形。如果在同一过程中使用多个图(使用tf.Graph()创建),则必须为每个图使用不同的会话,但是每个图可以在多个会话中使用。在这种情况下,将要显式启动的图形传递给会话构造函数通常更清楚。

参数:

  • target: (可选)。要连接到的执行引擎。默认使用进程内引擎。有关更多示例,请参见分布式TensorFlow。
  • graph: (可选)。将要启动的图表(如上所述)。
  • config: (可选)带有会话配置选项的ConfigProto协议缓冲区。

__enter__

代码语言:javascript复制
__enter__()

__exit__

代码语言:javascript复制
__exit__(
    exec_type,
    exec_value,
    exec_tb
)

as_default

代码语言:javascript复制
as_default()

返回使此对象成为默认会话的上下文管理器。使用with关键字指定对tf.Operation.run或tf.张量的调用。eval应该在这个会话中执行。

代码语言:javascript复制
c = tf.constant(..)
sess = tf.Session()

with sess.as_default():
  assert tf.get_default_session() is sess
  print(c.eval())

要获取当前默认会话,请使用tf.get_default_session。注意:当你退出上下文时,as_default上下文管理器不会关闭会话,您必须显式地关闭会话。

代码语言:javascript复制
c = tf.constant(...)
sess = tf.Session()
with sess.as_default():
  print(c.eval())
# ...
with sess.as_default():
  print(c.eval())

sess.close()

或者,你可以使用with tf.Session():创建一个在退出上下文时自动关闭的会话,包括在引发未捕获异常时。

注意:默认会话是当前线程的属性。如果您创建了一个新线程,并且希望在该线程中使用默认会话,则必须在该线程的函数中显式地添加一个带有ses .as_default():的会话。注意:使用ssh .as_default():块输入a不会影响当前默认图。如果您正在使用多个图形,那么sess。图与tf值不同。get_default_graph,您必须显式地输入一个带有sess.graph.as_default():块的参数来执行sess。绘制默认图形。

返回值:

  • 使用此会话作为默认会话的上下文管理器。

close

代码语言:javascript复制
close()

关闭会话。调用此方法释放与会话关联的所有资源。

异常:

  • tf.errors.OpError: Or one of its subclasses if an error occurs while closing the TensorFlow session.

list_devices

代码语言:javascript复制
list_devices()

列出此会话中的可用设备。

代码语言:javascript复制
devices = sess.list_devices()
for d in devices:
  print(d.name)

列表中的每个元素都具有以下属性:- name:一个带有设备全名的字符串。例如:/job:worker/ copy:0/task:3/device:CPU:0 - device_type:设备的类型(例如CPU、GPU、TPU) - memory_limit:设备上可用的最大内存量。注意:根据设备的不同,可用内存可能会大大减少。

异常:

  • tf.errors.OpError: If it encounters an error (e.g. session is in an invalid state, or network errors occur).

返回值:

  • 会话中的设备列表。

make_callable

代码语言:javascript复制
make_callable(
    fetches,
    feed_list=None,
    accept_options=False
)

返回一个运行特定步骤的Python可调用函数。返回的可调用函数将接受len(feed_list)参数,其类型必须与feed_list的各个元素的提要值兼容。例如,如果feed_list的元素i是tf。张量,返回的可调用的第i个参数必须是一个numpy ndarray(或可转换为ndarray的东西),它具有匹配的元素类型和形状。返回的可调用函数将具有与tf.Session.run(fetches,…)相同的返回类型。例如,如果fetches是tf。张量,可调用的将返回一个numpy ndarray;如果fetches是tf。操作,它将返回None。

参数:

  • fetches: 要获取的值或值列表。有关允许获取类型的详细信息,请参见tf.Session.run。
  • feed_list: (可选)。feed_dict键的列表。有关允许的提要键类型的详细信息,请参见tf.Session.run。
  • accept_options:(可选)。如果为真,返回的Callable将能够接受tf。RunOptions和tf。RunMetadata分别作为可选的关键字参数选项和run_metadata,具有与tf.Session.run相同的语法和语义,这对于某些用例(分析和调试)是有用的,但是会导致可调用程序性能的显著下降。默认值:False。

返回值:

  • 调用时将执行feed_list定义的步骤并在此会话中获取的函数。

异常:

  • TypeError: If fetches or feed_list cannot be interpreted as arguments to tf.Session.run.

partial_run

代码语言:javascript复制
partial_run(
    handle,
    fetches,
    feed_dict=None
)

使用更多的提要和获取继续执行。这是实验性的,可能会发生变化。 要使用部分执行,用户首先调用partial_run_setup(),然后调用partial_run()序列。partial_run_setup指定将在后续partial_run调用中使用的提要和获取列表。可选的feed_dict参数允许调用者覆盖图中张量的值。有关更多信息,请参见run()。

下面是一个简单的例子:

代码语言:javascript复制
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.multiply(r1, c)

h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
res = sess.partial_run(h, r2, feed_dict={c: res})

参数:

  • handle: 用于一系列部分运行的句柄。
  • fetches: 单个图形元素、一组图形元素或一个字典,其值是图形元素或图形元素列表(请参阅运行文档)。
  • feed_dict:将图形元素映射到值的字典(如上所述)。

返回值:

  • 如果fetches是单个图形元素,则使用单个值;如果fetches是列表,则使用值列表;如果fetches是字典,则使用与之相同的键的字典(有关运行,请参阅文档)。

异常:

  • tf.errors.OpError: Or one of its subclasses on error.

partial_run_setup

代码语言:javascript复制
partial_run_setup(
    fetches,
    feeds=None
)

为部分运行设置一个包含提要和获取的图。这是实验性的,可能会发生变化。注意,与run相反,提要只指定图元素。张量将由后续的partial_run调用提供。

参数:

  • fetches: 单个图元素,或一组图元素。
  • feeds: 单个图元素,或图元素列表。

返回值:

  • 用于部分运行的句柄。

异常:

  • RuntimeError: If this Session is in an invalid state (e.g. has been closed).
  • TypeError: If fetches or feed_dict keys are of an inappropriate type.
  • tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens.

reset

代码语言:javascript复制
@staticmethod
reset(
    target,
    containers=None,
    config=None
)

在目标上重置资源容器,并关闭所有连接的会话。资源容器分布在与目标相同的集群中的所有worker上。当重置目标上的资源容器时,将清除与该容器关联的资源。特别是,容器中的所有变量都将成为未定义的:它们将丢失它们的值和形状。

注意:(i) reset()目前仅为分布式会话实现。(ii)按目标命名的关于主机的任何会话都将关闭。如果没有提供资源容器,则重置所有容器。

参数:

  • target: 要连接到的执行引擎。
  • containers: 资源容器名称字符串的列表,如果要重置所有容器,则为None。
  • config: (可选)带有配置选项的协议缓冲区。

异常:

  • tf.errors.OpError: Or one of its subclasses if an error occurs while resetting containers.

run

代码语言:javascript复制
run(
    fetches,
    feed_dict=None,
    options=None,
    run_metadata=None
)

在读取中运行操作并计算张量。该方法运行TensorFlow计算的一个“步骤”,通过运行必要的图片段来执行每一个操作,并在fetches中计算每个张量,用feed_dict中的值替换相应的输入值。fetches参数可以是一个单独的图形元素,也可以是一个任意嵌套的列表、元组、namedtuple、dict或OrderedDict,它的叶子中包含图形元素。图形元素可以是以下类型之一:

  • 一个tf.Operation。对应的获取值将为None。
  • tf.Tensor。相应的获取值将是一个包含该张量值的numpy ndarray。
  • tf.SparseTensor。对应的获取值将是tf。包含稀疏张量的值。
  • 一个get_tensor_handle操作符。相应的获取值将是一个包含该张量句柄的numpy ndarray。
  • 一个字符串,它是图中张量或运算的名称。

run()返回的值具有与fetches参数相同的形状,其中叶子被TensorFlow返回的相应值替换。例:

代码语言:javascript复制
   a = tf.constant([10, 20])
   b = tf.constant([1.0, 2.0])
   # 'fetches' can be a singleton
   v = session.run(a)
   # v is the numpy array [10, 20]
   # 'fetches' can be a list.
   v = session.run([a, b])
   # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
   # 1-D array [1.0, 2.0]
   # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
   MyData = collections.namedtuple('MyData', ['a', 'b'])
   v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
   # v is a dict with
   # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
   # 'b' (the numpy array [1.0, 2.0])
   # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
   # [10, 20].

可选的feed_dict参数允许调用者覆盖图中张量的值。feed_dict中的每个键都可以是以下类型之一:

  • 如果键是tf.Tensor,其值可以是Python标量、字符串、列表或numpy ndarray,可以转换为与该张量相同的dtype。此外,如果键是tf。将检查值的形状是否与占位符兼容。
  • 如果键是tf.Tensorsparse,这个值应该是tf.SparseTensorValue。
  • 如果键是张量或稀疏张量的嵌套元组,则该值应该是嵌套元组,其结构与上面映射到其对应值的结构相同。

feed_dict中的每个值必须转换为对应键的dtype的numpy数组。可选选项参数预期会出现[runo]。这些选项允许控制此特定步骤的行为(例如打开跟踪)。可选的run_metadata参数需要一个[RunMetadata]原型。在适当的时候,这个步骤的非张量输出将被收集到这里。例如,当用户打开跟踪选项时,所分析的信息将被收集到这个参数中并传递回去。

参数:

  • fetches:单个图元素、图元素列表或字典,其值是图元素或图元素列表(如上所述)。
  • feed_dict:将图形元素映射到值的字典(如上所述)。
  • options:[runo]协议缓冲区
  • run_metadata:一个[RunMetadata]协议缓冲区

返回值:

如果fetches是单个图形元素,则使用单个值;如果fetches是列表,则使用值列表;如果fetches是字典,则使用与之相同的键的字典(如上所述)。未定义在调用中计算获取操作的顺序。

异常:

  • RuntimeError: If this Session is in an invalid state (e.g. has been closed).
  • TypeError: If fetches or feed_dict keys are of an inappropriate type.
  • ValueError: If fetches or feed_dict keys are invalid or refer to a Tensor that doesn't exist.

原链接: https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/Session?hl=en

0 人点赞