TensorFlow函数教程:tf.nn.static_rnn

tf.nn.static_rnn函数

别名:

  • tf.contrib.rnn.static_rnn
  • tf.nn.static_rnn
tf.nn.static_rnn(
    cell,
    inputs,
    initial_state=None,
    dtype=None,
    sequence_length=None,
    scope=None
)

定义在:tensorflow/python/ops/rnn.py。

创建由RNNCell cell指定的循环神经网络。

生成的最简单的RNN网络形式是:

  state = cell.zero_state(...)
  outputs = []
  for input_ in inputs:
    output, state = cell(input_, state)
    outputs.append(output)
  return (outputs, state)

但是,还有一些其他选项:

可以提供初始状态。如果提供sequence_length向量,则执行动态计算。这种计算方法不计算超过最小批处理的最大序列长度的RNN步骤(从而节省计算时间),并且将示例的序列长度的状态适当地传播到最终状态输出。

在批处理行b的时间t上执行的动态计算:

  (output, state)(b, t) =
    (t >= sequence_length(b))
      ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
      : cell(input(b, t), state(b, t - 1))

参数:

  • cell:RNNCell的一个实例。
  • inputs:输入的长度为T的列表,每个Tensor具有shape [batch_size, input_size];或这些元素的嵌套元组。
  • initial_state:(可选)RNN的初始状态。如果cell.state_size是整数,则必须是具有适当的类型和shape为[batch_size, cell.state_size]的Tensor。如果cell.state_size是一个元组,这应该是具有shape [batch_size, s]的张量元组,其中s位于cell.state_size。
  • dtype:(可选)初始状态和预期输出的数据类型。如果未提供initial_state或RNN状态具有异构类型,则为必需。
  • sequence_length:指定输入中每个序列的长度。int32或int64向量(张量),大小为[batch_size],值位于[0, T)。
  • scope:用于创建子图的VariableScope;默认为“rnn”。

返回:

(outputs, state)对,其中:

  • outputs的长度为T的列表(每个输入一个),或这些元素的嵌套元组。
  • state是最终状态

可能引发的异常:

  • TypeError:如果cell不是RNNCell的实例。
  • ValueError:如果inputs为None或是一个空列表,或者无法通过形状推断从输入推断输入深度(列大小)。
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号

意见反馈
返回顶部