tf.contrib.cudnn_rnn.RNNParamsSaveable

class tf.contrib.cudnn_rnn.RNNParamsSaveable

Defined in tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py.

SaveableObject implementation that handles the RNN params variable.

Methods

__init__

__init__(
    params_to_canonical,
    canonical_to_params,
    param_variables,
    name='params_canonical'
)

Creates an RNNParamsSaveable object.

RNNParamsSaveable can be saved to and restored from a checkpoint file. It saves and restores the weight and bias parameters in a canonical format, in which the parameters are stored as tensors layer by layer; within each layer, the bias tensors follow the weight tensors. When restoring, a user can name param_variables as desired and restore the weight and bias tensors into those variables.
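As a rough illustration of the ordering described above, the following pure-Python sketch lists the canonical tensors layer by layer, with each layer's weight tensors first and its bias tensors after them. The `(layer, "weight"/"bias", index)` labels are hypothetical and are not the checkpoint key names TensorFlow actually writes; `tensors_per_layer` stands for the per-layer count given for each RNN mode below.

```python
from typing import List, Tuple

# Hypothetical label for one canonical tensor: (layer index, "weight" or "bias", tensor index).
Label = Tuple[int, str, int]


def canonical_order(num_layers: int, tensors_per_layer: int) -> List[Label]:
    """Return canonical tensors in save order: layer by layer, weights before biases."""
    order: List[Label] = []
    for layer in range(num_layers):
        # All weight tensors of this layer come first ...
        order.extend((layer, "weight", j) for j in range(tensors_per_layer))
        # ... followed by all bias tensors of the same layer.
        order.extend((layer, "bias", j) for j in range(tensors_per_layer))
    return order


# Example: a 2-layer CudnnRNNTanh model has 2 weight and 2 bias tensors per layer.
for label in canonical_order(num_layers=2, tensors_per_layer=2):
    print(label)
```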

For CudnnRNNRelu or CudnnRNNTanh, each layer has 2 weight tensors and 2 bias tensors: tensor 0 is applied to the input from the previous layer and tensor 1 to the recurrent input.

For CudnnLSTM, each layer has 8 weight tensors and 8 bias tensors: tensors 0-3 are applied to the input from the previous layer and tensors 4-7 to the recurrent input. Tensors 0 and 4 are for the input gate, tensors 1 and 5 for the forget gate, tensors 2 and 6 for the new memory gate, and tensors 3 and 7 for the output gate.

For CudnnGRU, each layer has 6 weight tensors and 6 bias tensors: tensors 0-2 are applied to the input from the previous layer and tensors 3-5 to the recurrent input. Tensors 0 and 3 are for the reset gate, tensors 1 and 4 for the update gate, and tensors 2 and 5 for the new memory gate.
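The index conventions for the three model families follow one pattern: with n gates, a layer has 2n weight tensors (and 2n bias tensors), the first n act on the previous layer's output, the last n on the recurrent input, and tensor i belongs to gate i mod n. The pure-Python sketch below encodes that pattern; the mode keys (`"lstm"`, `"gru"`, ...) and the gate name strings are hypothetical labels chosen for illustration, not identifiers from the TensorFlow API.

```python
from typing import Dict, List, Optional, Tuple

# Gate order per layer, as listed in the text above. The mode keys are hypothetical
# labels; None stands for the single (gateless) transform of a plain RNN.
GATES: Dict[str, List[Optional[str]]] = {
    "rnn_relu": [None],
    "rnn_tanh": [None],
    "lstm": ["input", "forget", "new_memory", "output"],
    "gru": ["reset", "update", "new_memory"],
}


def tensors_per_layer(mode: str) -> int:
    """Number of weight tensors (and, equally, bias tensors) per layer."""
    return 2 * len(GATES[mode])


def describe(mode: str, index: int) -> Tuple[str, Optional[str]]:
    """Map a per-layer tensor index to (input source, gate)."""
    gates = GATES[mode]
    n = len(gates)
    if not 0 <= index < 2 * n:
        raise ValueError(f"{mode} has tensors 0-{2 * n - 1}, got {index}")
    # The first n tensors act on the previous layer's output, the last n on the recurrent input.
    source = "input" if index < n else "recurrent"
    return source, gates[index % n]


print(describe("lstm", 5))  # ('recurrent', 'forget')
print(describe("gru", 2))   # ('input', 'new_memory')
```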

Args:

  • params_to_canonical: a function that converts params from the format specific to cuDNN or another RNN op into the canonical format. _CudnnRNN.params_to_canonical() should be provided here.
  • canonical_to_params: a function that converts params from the canonical format into the format specific to cuDNN or another RNN op. The function must return a scalar (e.g. in the case of cuDNN) or a tuple. It can be _CudnnRNN.canonical_to_params() or a user-defined function.
  • param_variables: a list of Variables holding the parameters in the op-specific format. For cuDNN RNN ops, this is a single variable that merges both weights and biases; for other RNN ops, it may be several unmerged or partially merged variables for the weights and the biases.
  • name: the name of the RNNParamsSaveable object.

restore

restore(
    restored_tensors,
    restored_shapes
)

© 2017 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/cudnn_rnn/RNNParamsSaveable