TensorFlow处理RNN参数变量

由 Carrie 创建, 最后一次修改 2017-09-01

tf.contrib.cudnn_rnn.RNNParamsSaveable


tf.contrib.cudnn_rnn.RNNParamsSaveable 类

定义在:tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py

用于处理 RNN 参数变量的 SaveableObject 实现.

方法


__init__

__init__ (
params_to_canonical ,
canonical_to_params ,
param_variables ,
name = 'params_canonical'
)

创建一个 RNNParamsSaveable 对象.

RNNParams 可以在检查点文件中保存/恢复,用于以规范格式保存/恢复权重和偏置参数,其中参数逐层保存为张量.对于每个层,偏差张量在重量张量之后被保存.恢复时,用户可以根据需要命名 param_variables,并将权重和偏差张量恢复到这些变量.

对于 CudnnRNNRelu 或 CudnnRNNTanh,每个层的每个权重和每个偏移量都有两个张量:张量0被用于从前一层输入,张量1用于循环输入.

对于 CudnnLSTM,每个层的每个权重和每个偏移量有8个张量;张量0-3被用于从前一层输入;张量4-7用于循环输入;张量0和4用于输入门;张量1和5忘记门;张量2和6新的存储门; 张量3和7是输出门.

对于 CudnnGRU,每个层的每个权重和每个偏移量有6张张量;张量0-2被用于从前一层输入;张量3-5用于循环输入;张量0和3用于复位门;张量1和4更新门;张量2和5新的存储门.

ARGS:

  • params_to_canonical:一种函数, 用于将参数从特定格式转换为 cuDNN 或其他 RNN ops 转换到规范格式._CudnnRNN params_to_canonical () 应在这里提供.
  • canonical_to_params:用于将参数从规范格式转换为 cuDNN 或其他 RNN ops 的特定格式的函数.函数必须返回一个标量 (如 cuDNN) 或元组.此函数可以是 _CudnnRN.
  • param_variables:特定窗体中参数的变量列表.对于 cuDNN RNN ops,这是一个单一的加权和偏见合并变量;对于其他 RNN ops, 这可能是多个未或部分合并的变量, 分别用于权重和偏差.
  • name:RNNParamsSaveable 对象的名称.

restore

restore(
restored_tensors ,
restored_shapes
)


以上内容是否对您有帮助:

二维码
建议反馈
二维码