TensorFlow函数:tf.while_loop

tf.while_loop函数

tf.while_loop(
    cond,
    body,
    loop_vars,
    shape_invariants=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    name=None,
    maximum_iterations=None
)

定义在:tensorflow/python/ops/control_flow_ops.py.

请参阅指南:控制流程>控制流程操作

在条件cond成立时重复body.

cond是可返回的布尔标量张量;body是一个可调用的函数,它返回一个(可能是嵌套的)元组、namedtuple或者与loop_vars具有相同arity(长度和结构)和类型的张量列表;loop_vars是一个(可能是嵌套的)元组,namedtuple或被传递给cond和body的张量的列表.cond和body的参数都尽可能的和loop_vars一样多.

除了常规的Tensors或IndexedSlices,body可以接受并返回TensorArray对象.TensorArray对象的流将在循环之间和梯度计算期间适当地转发.

请注意,while_loop只调用cond和body一次(调用while_loop内,而不是在所有Session.run()期间).while_loop将在cond和body调用期间创建的图片段拼接在一起,并添加一些额外的图形节点,以创建重复body的图形流,直到cond返回false为止.

为了正确性,tf.while_loop()严格执行循环变量的形状不变量.形状不变是一个(可能是部分)形状,在整个迭代循环中不变.如果迭代之后的循环变量的形状被确定为比形状不变性更一般或不兼容,则会引发错误.例如,[11,None]的形状比[11,17]的形状更普遍,[11,21]与[11,17]不兼容.默认情况下(如果没有指定shape_invariants参数),则假设每个迭代中的loop_vars中的每个张量的初始形状是相同的.该shape_invariants参数允许调用者为每个循环变量指定一个不太具体的形状不变量,如果形状在迭代之间变化,则需要该变量.该tf.Tensor.set_shape函数也可以用在body函数中来指示输出循环变量具有特定的形状.

SparseTensor和IndexedSlices的形状不变的定义如下:

a)如果循环变量是SparseTensor,则形状不变量必须是TensorShape([r]),其中r是由稀疏张量表示的稠密张量的秩.这意味着SparseTensor的三个张量的形状是([None],[None,r],[r]).注意:此处不变的形状是SparseTensor.dense_shape属性的形状.它必须是矢量的形状.

b)如果循环变量是IndexedSlices,则形状不变量必须是IndexedSlices的值张量的形状不变量.这意味着IndexedSlices的三个张量的形状是(shape,[shape [0]],[shape.ndims]).

while_loop实现非严格的语义,允许多个迭代并行运行.并行迭代的最大数量可以通过parallel_iterations控制,这使用户可以控制内存消耗和执行顺序.对于正确的程序,while_loop应该为任何parallel_iterations>0返回相同的结果.

对于训练,TensorFlow存储在正向推断中生成的张量,并且需要反向传播.这些张量是内存消耗的主要来源,并且在GPU上训练时经常会导致OOM错误.当标志swap_memory为true时,我们将这些张量从GPU交换到CPU.例如,这允许我们训练具有很长序列和大批量的RNN模型.

函数参数:

  • cond:代表循环终止条件的可调用对象.
  • body:代表循环体的可调用对象.
  • loop_vars:一个(可能是嵌套的)元组,namedtuple或numpy数组、Tensor以及TensorArray对象的列表.
  • shape_invariants:循环变量的形状不变量.
  • parallel_iterations:允许并行运行的迭代次数.它必须是一个正整数.
  • back_prop:表示是否为此while循环启用backprop.
  • swap_memory:此循环是否启用GPU-CPU内存交换.
  • name:返回张量的可选名称前缀.
  • maximum_iterations:要运行的while循环的可选最大迭代次数.如果提供了,则cond输出将与附加条件进行AND运算,以确保执行的迭代次数不超过maximum_iterations.要运行的 while 循环的最大迭代次数.

返回值:

执行循环后循环变量的输出张量.当loop_vars的长度为1时,这是一个Tensor、TensorArray或IndexedSlice,当loop_vars的长度大于1时,它返回一个列表.

可能引发的异常:

  • TypeError:如果cond或者body是不可调用的.
  • ValueError:如果loop_vars是空的.

使用示例:

i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])

嵌套和namedtuple的示例:

import collections
Pair = collections.namedtuple('Pair', 'j, k')
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
c = lambda i, p: i < 10
b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(c, b, ijk_0)

使用shape_invariants的示例:

i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m: i < 10
b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号

意见反馈
返回顶部