阅读(3898)
赞(9)
TensorFlow函数教程:tf.nn.dynamic_rnn
2018-12-22 14:24:38 更新
tf.nn.dynamic_rnn函数
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
定义在:tensorflow/python/ops/rnn.py.
请参阅指南:神经网络>递归神经网络
创建由 RNNCellcell
指定的递归神经网络
执行inputs
的完全动态展开.
示例:
# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
initial_state=initial_state,
dtype=tf.float32)
# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
# create a RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is a N-tuple where N is the number of LSTMCells containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
inputs=data,
dtype=tf.float32)
参数:
cell
:RNNCell的一个实例.inputs
:RNN输入.如果time_major == False
(默认),则是一个shape为[batch_size, max_time, ...]
的Tensor
,或者这些元素的嵌套元组.如果time_major == True
,则是一个shape为[max_time, batch_size, ...]
的Tensor
,或这些元素的嵌套元组.这也可能是满足此属性的Tensors(可能是嵌套的)元组.前两个维度必须匹配所有输入,否则秩和其他形状组件可能不同.在这种情况下,在每个时间步输入到cell
将复制这些元组的结构,时间维度(从中获取时间)除外.在每个时间步输入到个cell
将是一个Tensor
或(可能是嵌套的)Tensors元组,每个元素都有维度[batch_size, ...]
.sequence_length
:(可选)大小为[batch_size]
的int32/int64的向量.超过批处理元素的序列长度时用于复制状态和零输出.所以它更多的是正确性而不是性能.initial_state
:(可选)RNN的初始状态.如果cell.state_size
是整数,则必须是具有适当类型和shape为[batch_size, cell.state_size]
的Tensor
.如果cell.state_size
是一个元组,则应该是张量元组,在cell.state_size
中为s设置shape[batch_size, s]
.dtype
:(可选)初始状态和预期输出的数据类型.如果未提供initial_state或RNN状态具有异构dtype,则是必需的.parallel_iterations
:(默认值:32).并行运行的迭代次数.适用于那些没有任何时间依赖性并且可以并行运行的操作.该参数用于交换空间的时间.远大于1的值会使用更多内存但占用更少时间,而较小值使用较少内存但计算时间较长.swap_memory
:透明地交换推理中产生的张量,但是需要从GPU到CPU的支持.这允许训练通常不适合单个GPU的RNN,具有非常小的(或没有)性能损失.time_major
:inputs
和outputs
Tensor的形状格式.如果是true,则这些Tensors
的shape必须为[max_time, batch_size, depth]
.如果是false,则这些Tensors
的shape必须为[batch_size, max_time, depth]
.使用time_major = True
更有效,因为它避免了RNN计算开始和结束时的转置.但是,大多数TensorFlow数据都是batch-major,因此默认情况下,此函数接受输入并以batch-major形式发出输出.scope
:用于创建子图的VariableScope;默认为“rnn”.
返回:
一对(outputs, state),其中:
-
outputs
:RNN输出Tensor
.如果time_major == False(默认),这将是shape为
[batch_size, max_time, cell.output_size]
的Tensor
.如果time_major == True,这将是shape为
[max_time, batch_size, cell.output_size]
的Tensor
.注意,如果
cell.output_size
是整数或TensorShape
对象的(可能是嵌套的)元组,那么outputs
将是一个与cell.output_size
具有相同结构的元祖,它包含与cell.output_size
中的形状数据有对应shape的Tensors. -
state
:最终的状态.如果cell.state_size
是int,则会形成[batch_size, cell.state_size]
.如果它是TensorShape
,则将形成[batch_size] + cell.state_size
.如果它是一个(可能是嵌套的)int或TensorShape
元组,那么这将是一个具有相应shape的元组.如果单元格是LSTMCells
,则state
将是包含每个单元格的LSTMStateTuple
的元组.
可能引发的异常:
TypeError
:如果cell
不是RNNCell的实例.ValueError
:如果输入为None或是空列表.