TensorFlow中与model_fn相关的类和方法

2018-04-28 11:10 更新
#版权所有2016年TensorFlow作者.版权所有.
#根据Apache许可证2.0版(“许可证”)获得许可;
#除了符合许可证外,您不得使用此文件.
#您可以在以下网址获得许可证副本:
#http://www.apache.org/licenses/LICENSE-2.0
#除非适用法律要求或以书面形式同意,否则根据许可证分发的软件
#按“现状”分发,
#没有任何形式的保证或条件,无论是明示还是暗示.
#请参阅许可证,了解许可证下关于权限和
#限制的具体规定.
#==============================================================================
"""与model_fn相关的类和方法."""
from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import six from tensorflow.python.estimator.export.export_output import ExportOutput from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest class ModeKeys(object): """Standard names for model modes. The following standard keys are defined: * `TRAIN`: training mode. * `EVAL`: evaluation mode. * `PREDICT`: inference mode. """ TRAIN = 'train' EVAL = 'eval' PREDICT = 'infer' LOSS_METRIC_KEY = 'loss' AVERAGE_LOSS_METRIC_KEY = 'average_loss' class EstimatorSpec( collections.namedtuple('EstimatorSpec', [ 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops', 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold', 'evaluation_hooks', 'prediction_hooks' ])): """Ops and objects returned from a `model_fn` and passed to an `Estimator`. `EstimatorSpec` fully defines the model to be run by an `Estimator`. """ def __new__(cls, mode, predictions=None, loss=None, train_op=None, eval_metric_ops=None, export_outputs=None, training_chief_hooks=None, training_hooks=None, scaffold=None, evaluation_hooks=None, prediction_hooks=None): """Creates a validated `EstimatorSpec` instance. Depending on the value of `mode`, different arguments are required. Namely * For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`. * For `mode == ModeKeys.EVAL`: required field is `loss`. * For `mode == ModeKeys.PREDICT`: required fields are `predictions`. model_fn can populate all arguments independent of mode. In this case, some arguments will be ignored by an `Estimator`. E.g. `train_op` will be ignored in eval and infer modes. 
Example: ```python def my_model_fn(mode, features, labels): predictions = ... loss = ... train_op = ... return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions, loss=loss, train_op=train_op) ``` Alternatively, model_fn can just populate the arguments appropriate to the given mode. Example: ```python def my_model_fn(mode, features, labels): if (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL): loss = ... else: loss = None if mode == tf.estimator.ModeKeys.TRAIN: train_op = ... else: train_op = None if mode == tf.estimator.ModeKeys.PREDICT: predictions = ... else: predictions = None return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions, loss=loss, train_op=train_op) ``` Args: mode: A `ModeKeys`. Specifies if this is training, evaluation or prediction. predictions: Predictions `Tensor` or dict of `Tensor`. loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`. train_op: Op for the training step. eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the `update_op` or requires any input fetching. export_outputs: Describes the output signatures to be exported to `SavedModel` and used during serving. A dict `{name: output}` where: * name: An arbitrary name for this output. * output: an `ExportOutput` object such as `ClassificationOutput`, `RegressionOutput`, or `PredictOutput`. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to run on the chief worker during training. 
training_hooks: Iterable of `tf.train.SessionRunHook` objects to run on all workers during training. scaffold: A `tf.train.Scaffold` object that can be used to set initialization, saver, and more to be used in training. evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to run during evaluation. prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to run during predictions. Returns: A validated `EstimatorSpec` object. Raises: ValueError: If validation fails. TypeError: If any of the arguments is not the expected type. """ # Validate train_op. if train_op is None: if mode == ModeKeys.TRAIN: raise ValueError('Missing train_op.') else: _check_is_tensor_or_operation(train_op, 'train_op') # Validate loss. if loss is None: if mode in (ModeKeys.TRAIN, ModeKeys.EVAL): raise ValueError('Missing loss.') else: loss = _check_is_tensor(loss, 'loss') loss_shape = loss.get_shape() if loss_shape.num_elements() not in (None, 1): raise ValueError('Loss must be scalar, given: {}'.format(loss)) if not loss_shape.is_compatible_with(tensor_shape.scalar()): loss = array_ops.reshape(loss, []) # Validate predictions. if predictions is None: if mode == ModeKeys.PREDICT: raise ValueError('Missing predictions.') predictions = {} else: if isinstance(predictions, dict): predictions = { k: _check_is_tensor(v, 'predictions[{}]'.format(k)) for k, v in six.iteritems(predictions) } else: predictions = _check_is_tensor(predictions, 'predictions') # Validate eval_metric_ops. 
if eval_metric_ops is None: eval_metric_ops = {} else: if not isinstance(eval_metric_ops, dict): raise TypeError( 'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops)) for key, metric_value_and_update in six.iteritems(eval_metric_ops): if (not isinstance(metric_value_and_update, tuple) or len(metric_value_and_update) != 2): raise TypeError( 'Values of eval_metric_ops must be (metric_value, update_op) ' 'tuples, given: {} for key: {}'.format( metric_value_and_update, key)) metric_value, metric_update = metric_value_and_update for metric_value_member in nest.flatten(metric_value): # Allow (possibly nested) tuples for metric values, but require that # each of them be Tensors or Operations. _check_is_tensor_or_operation(metric_value_member, 'eval_metric_ops[{}]'.format(key)) _check_is_tensor_or_operation(metric_update, 'eval_metric_ops[{}]'.format(key)) # Validate export_outputs. if export_outputs is not None: if not isinstance(export_outputs, dict): raise TypeError('export_outputs must be dict, given: {}'.format( export_outputs)) for v in six.itervalues(export_outputs): if not isinstance(v, ExportOutput): raise TypeError( 'Values in export_outputs must be ExportOutput objects. ' 'Given: {}'.format(export_outputs)) # Note export_outputs is allowed to be empty. if len(export_outputs) == 1: (key, value), = export_outputs.items() if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: export_outputs[ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value if len(export_outputs) > 1: if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in export_outputs): raise ValueError( 'Multiple export_outputs were provided, but none of them is ' 'specified as the default. Do this by naming one of them with ' 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') # Validate that all tensors and ops are from the default graph. default_graph = ops.get_default_graph() # We enumerate possible error causes here to aid in debugging. 
error_message_template = ( '{0} with "{1}" must be from the default graph. ' 'Possible causes of this error include: \n\n' '1) {0} was created outside the context of the default graph.' '\n\n' '2) The object passed through to EstimatorSpec was not created ' 'in the most recent call to "model_fn".') if isinstance(predictions, dict): for key, value in six.iteritems(predictions): if value.graph is not default_graph: raise ValueError(error_message_template.format( 'prediction values', '{0}: {1}'.format(key, value.name))) elif predictions is not None: # 'predictions' must be a single Tensor. if predictions.graph is not default_graph: raise ValueError(error_message_template.format( 'prediction values', predictions.name)) if loss is not None and loss.graph is not default_graph: raise ValueError(error_message_template.format('loss', loss.name)) if train_op is not None and train_op.graph is not default_graph: raise ValueError(error_message_template.format('train_op', train_op.name)) for key, value in list(six.iteritems(eval_metric_ops)): values = nest.flatten(value) for value in values: if value.graph is not default_graph: raise ValueError(error_message_template.format( 'eval_metric_ops', '{0}: {1}'.format(key, value.name))) # Validate hooks. training_chief_hooks = tuple(training_chief_hooks or []) training_hooks = tuple(training_hooks or []) evaluation_hooks = tuple(evaluation_hooks or []) prediction_hooks = tuple(prediction_hooks or []) for hook in (training_hooks + training_chief_hooks + evaluation_hooks + prediction_hooks): if not isinstance(hook, session_run_hook.SessionRunHook): raise TypeError( 'All hooks must be SessionRunHook instances, given: {}'.format( hook)) scaffold = scaffold or monitored_session.Scaffold() # Validate scaffold. if not isinstance(scaffold, monitored_session.Scaffold): raise TypeError( 'scaffold must be tf.train.Scaffold. 
Given: {}'.format(scaffold)) return super(EstimatorSpec, cls).__new__( cls, mode=mode, predictions=predictions, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops, export_outputs=export_outputs, training_chief_hooks=training_chief_hooks, training_hooks=training_hooks, scaffold=scaffold, evaluation_hooks=evaluation_hooks, prediction_hooks=prediction_hooks) def _replace(self, **kwds): """Return a new EstimatorSpec replacing specified fields with new values.""" if 'mode' in kwds: if self.mode != kwds['mode']: raise ValueError('mode of EstimatorSpec cannot be changed.') new_fields = map(kwds.pop, self._fields, list(self)) return EstimatorSpec(*new_fields) def _check_is_tensor_or_operation(x, name): if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)): raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x)) def _check_is_tensor(x, tensor_name): """Returns `x` if it is a `Tensor`, raises TypeError otherwise.""" if not isinstance(x, ops.Tensor): raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x)) return x
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号