TensorFlow变量函数:tf.make_template

函数:tf.make_template
make_template(
    name_,
    func_,
    create_scope_now_=False,
    unique_name_=None,
    custom_getter_=None,
    **kwargs
)

定义在:tensorflow/python/ops/template.py

参见指南:变量>共享变量

给定一个任意函数,将其包装,以便它进行变量共享.
这将 func_ 在模板中进行包装,并对其进行部分评估.模板是在第一次被调用时创建变量并在其后重用它们的函数.为了使 func_ 与模板兼容,它必须具有以下属性:

  • 该函数应该创建所有可训练的变量和任何通过调用 tf. get_variable 来重用的变量.如果使用 tf.Variable 创建训练变量,则将抛出一个 ValueError 异常.可以通过指定 tf.Variable(..., trainable=false) 来创建用于局部变量的变数.
  • 该函数可以在内部使用变量范围和其他模板来创建和重用变量,但它不应该使用 tf. global_variables 来捕获在函数范围之外定义的变量.
  • 内部范围和变量名称不应取决于任何未提供给 make_template 的任何参数.一般来说,你会得到一个 ValueError 异常,告诉你,如果你犯了一个错误,你试图重用一个不存在的变量.

在下面的例子中,z 和 w 将用相同的 y 缩放.重要的是要注意,如果我们没有分配 scalar_name 和为 z 和 w 分配不同的名称,ValueError 异常将抛出,因为它不能重用该变量.

def my_op(x, scalar_name):
  var1 = tf.get_variable(scalar_name,
                         shape=[],
                         initializer=tf.constant_initializer(1))
  return x * var1

scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')

z = scale_by_y(input1)
w = scale_by_y(input2)

作为安全防护装置, 如果训练变量是通过调用 tf.Variable 创建的,则返回的函数将在第一次调用后引发 ValueError.

如果所有这些都是真的,则模板将强制执行以下2个属性:

  1. 多次调用相同的模板将共享所有非局部变量.
  2. 两个不同的模板保证是唯一的,除非您重新输入与模板的初始定义相同的变量范围并对其进行重定义,以下是此异常的示例:
    def my_op(x, scalar_name):
      var1 = tf.get_variable(scalar_name,
                             shape=[],
                             initializer=tf.constant_initializer(1))
      return x * var1
    
    with tf.variable_scope('scope') as vs:
      scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')
      z = scale_by_y(input1)
      w = scale_by_y(input2)
    
    # Creates a template that reuses the variables above.
    with tf.variable_scope(vs, reuse=True):
      scale_by_y2 = tf.make_template('scale_by_y', my_op, scalar_name='y')
      z2 = scale_by_y2(input1)
      w2 = scale_by_y2(input2)

根据 create_scope_now_ 的值,完全可变范围可能在第一次调用时或构造时捕获.如果此选项设置为 True,则通过对模板的重复调用创建的所有张量将有一个额外的尾部 _ N+1 到它们的名称,因为在模板构造函数中首次输入作用域时,不会创建任何张量.

注:name_、func_ 和 create_scope_now_ 有一个尾部下划线,以减少与 kwargs 冲突的可能性.

参数:

  • name_:此模板创建的范围的名称.如有必要,该名称将通过将 _N 追加到名称中而成为唯一的.
  • func_:要包装的函数.
  • create_scope_now_:布尔控制是否应在构造模板时或调用模板时创建作用域,默认值为 False,表示在调用模板时将创建作用域.
  • unique_name_:使用时,它会覆盖 name_,而不是唯一的.如果已存在同一 scope/unique_name 的模板,并且重用为 false,则会引发错误.默认为 None.
  • custom_getter_:func_ 中使用的变量的可选自定义 getter.有关详细信息,请参阅文档.
  • ** kwargs:要应用于 func_ 的关键字参数.

返回值:

一个函数,用于封装一组应创建并重用的变量.将创建一个封闭作用域,无论在何处调用 make_template,或者在调用结果的地方,这取决于 create_scope_now_ 的值.无论值如何,首次调用该模板时,都将进入该范围,而不重复使用,并调用 func_ 创建变量,这些变量得保证是唯一的.所有后续调用都将重新输入范围并重新使用这些变量.

可能引发的异常:

  • ValueError:如果名称为 None.
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号

意见反馈
返回顶部