tf.estimator.EstimatorSpec
class tf.estimator.EstimatorSpec
Defined in tensorflow/python/estimator/model_fn.py.
Ops and objects returned from a model_fn and passed to Estimator.
EstimatorSpec fully defines the model to be run by Estimator.
Properties
eval_metric_ops
Alias for field number 3
export_outputs
Alias for field number 4
loss
Alias for field number 1
predictions
Alias for field number 0
scaffold
Alias for field number 7
train_op
Alias for field number 2
training_chief_hooks
Alias for field number 5
training_hooks
Alias for field number 6
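The "Alias for field number N" entries above are the wording that Python named tuples generate for their fields: each property is a named view of one tuple position. A minimal sketch of that behavior, using a plain collections.namedtuple with the field order listed above. SpecSketch is a hypothetical stand-in for illustration, not the TensorFlow class, and the string values are placeholders for real tensors and ops.

```python
import collections

# Hypothetical stand-in with the field order documented above;
# this is not the TensorFlow class itself.
SpecSketch = collections.namedtuple(
    "SpecSketch",
    ["predictions", "loss", "train_op", "eval_metric_ops",
     "export_outputs", "training_chief_hooks", "training_hooks", "scaffold"],
)

# Placeholder strings stand in for tensors and ops.
spec = SpecSketch(
    predictions="preds", loss="loss", train_op="train",
    eval_metric_ops=None, export_outputs=None,
    training_chief_hooks=None, training_hooks=None, scaffold=None,
)

# Each property is an alias for a tuple position.
assert spec.predictions is spec[0]
assert spec.loss is spec[1]
assert spec.train_op is spec[2]

# Map each field name to its position in the tuple.
field_numbers = {name: i for i, name in enumerate(SpecSketch._fields)}
print(field_numbers)
```

Because the fields are positional, a spec can also be unpacked or indexed like any tuple, although access by name is easier to read.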
Methods
__new__
__new__(
    cls,
    mode,
    predictions=None,
    loss=None,
    train_op=None,
    eval_metric_ops=None,
    export_outputs=None,
    training_chief_hooks=None,
    training_hooks=None,
    scaffold=None
)
Creates a validated EstimatorSpec instance.
Depending on the value of mode, different arguments are required. Namely:
- For mode == ModeKeys.TRAIN: required fields are loss and train_op.
- For mode == ModeKeys.EVAL: required field is loss.
- For mode == ModeKeys.PREDICT: required field is predictions.
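The per-mode requirements above can be sketched as a small lookup table. This is only an illustration of the rule, not the validation code EstimatorSpec actually runs. REQUIRED_FIELDS and missing_required are hypothetical names, and the mode strings "train", "eval" and "infer" are assumed here to stand in for the ModeKeys constants.

```python
# Required fields per mode, as listed above. The mode strings are assumed
# stand-ins for ModeKeys.TRAIN, ModeKeys.EVAL and ModeKeys.PREDICT.
REQUIRED_FIELDS = {
    "train": ("loss", "train_op"),
    "eval": ("loss",),
    "infer": ("predictions",),
}


def missing_required(mode, **fields):
    """Return the required field names that are None or absent for `mode`."""
    if mode not in REQUIRED_FIELDS:
        raise ValueError("Unknown mode: {!r}".format(mode))
    return [name for name in REQUIRED_FIELDS[mode]
            if fields.get(name) is None]


# Placeholder values stand in for real tensors and ops.
print(missing_required("train", loss=0.5))
print(missing_required("eval", loss=0.5, train_op=None))
print(missing_required("infer", loss=0.5))
```

An empty list means every field required for that mode was supplied; extra fields are simply not checked, which mirrors the note that arguments not needed for a mode are ignored.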
model_fn can populate all arguments regardless of mode. In this case, some arguments will be ignored by Estimator. For example, train_op will be ignored in eval and infer modes. Example:
def my_model_fn(mode, features, labels):
  predictions = ...
  loss = ...
  train_op = ...
  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op)
Alternatively, model_fn can just populate the arguments appropriate to the given mode. Example:
def my_model_fn(mode, features, labels):
  if (mode == tf.estimator.ModeKeys.TRAIN or
      mode == tf.estimator.ModeKeys.EVAL):
    loss = ...
  else:
    loss = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = ...
  else:
    train_op = None
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = ...
  else:
    predictions = None

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op)
Args:
- mode: A ModeKeys. Specifies if this is training, evaluation or prediction.
- predictions: Predictions Tensor or dict of Tensor.
- loss: Training loss Tensor. Must be either scalar, or with shape [1].
- train_op: Op for the training step.
- eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple.
- export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:
  - name: An arbitrary name for this output.
  - output: An ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput.
  Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
- training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.
- training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.
- scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.
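The naming rule for export_outputs in multi-headed models can be sketched as a simple check. This is a hedged illustration of the rule as stated above, not the library's own validation. check_multi_head_keys is a hypothetical helper, DEFAULT_KEY is assumed to equal the value of signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, and the string values are placeholders for ExportOutput objects.

```python
# Assumed value of signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
DEFAULT_KEY = "serving_default"


def check_multi_head_keys(export_outputs):
    """Raise ValueError if a multi-entry dict lacks the default serving key."""
    if not isinstance(export_outputs, dict):
        raise TypeError("export_outputs must be a dict of {name: output}")
    if len(export_outputs) > 1 and DEFAULT_KEY not in export_outputs:
        raise ValueError(
            "Multi-headed models need one entry named {!r}".format(DEFAULT_KEY))
    return True


# Placeholder strings stand in for ExportOutput objects.
single_head = {"scores": "predict_output"}
multi_head = {DEFAULT_KEY: "classification_output",
              "regression": "regression_output"}

print(check_multi_head_keys(single_head))
print(check_multi_head_keys(multi_head))
```

A dict with two or more heads and no entry under the default key would raise ValueError in this sketch, which is the case the argument description warns about.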
Returns:
A validated EstimatorSpec object.
Raises:
- ValueError: If validation fails.
- TypeError: If any of the arguments is not the expected type.
© 2017 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec