TensorFlow函数:tf.estimator.EstimatorSpec

2018-04-27 14:06 更新

tf.estimator.EstimatorSpec函数

EstimatorSpec类

定义在:tensorflow/python/estimator/model_fn.py.

从model_fn返回的操作和对象并传递给Estimator.

EstimatorSpec完全定义了由Estimator运行的模型.

属性

  • eval_metric_ops
    字段号4的别名
  • evaluation_hooks
    字段号9的别名
  • export_outputs
    字段号5的别名
  • loss
    字段号2的别名
  • mode
    字段号0的别名
  • prediction_hooks
    字段号10的别名
  • predictions
    字段号1的别名
  • scaffold
    字段号8的别名
  • train_op
    字段号3的别名
  • training_chief_hooks
    字段号6的别名
  • training_hooks
    字段号7的别名

方法

__new__

@ staticmethod 
__new__ ( 
    cls , 
    mode , 
    predictions = None , 
    loss = None , 
    train_op = None , 
    eval_metric_ops = None , 
    export_outputs = None , 
    training_chief_hooks = None , 
    training_hooks = None , 
    scaffold = None , 
    evaluation_hooks = None , 
    prediction_hooks= 无
)

创建一个已经验证的EstimatorSpec实例.

根据mode的值的不同,需要不同的参数,即:

  • 对于mode == ModeKeys.TRAIN:必填字段是loss和train_op.
  • 对于mode == ModeKeys.EVAL:必填字段是loss.
  • 为mode == ModeKeys.PREDICT:必填字段是predictions.

model_fn可以填充独立于模式的所有参数.在这种情况下,Estimator将忽略某些参数.在eval和infer模式中,train_op将被忽略.例子如下:

def my_model_fn(mode, features, labels):
  predictions = ...
  loss = ...
  train_op = ...
  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op)

或者,model_fn可以填充适合给定模式的参数.例:

def my_model_fn(mode, features, labels):
  if (mode == tf.estimator.ModeKeys.TRAIN or
      mode == tf.estimator.ModeKeys.EVAL):
    loss = ...
  else:
    loss = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = ...
  else:
    train_op = None
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = ...
  else:
    predictions = None

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op)

函数参数:

  • mode:一个ModeKeys,指定是training(训练)、evaluation(计算)还是prediction(预测).
  • predictions:预测Tensor或字典Tensor.
  • loss:训练损失Tensor,必须是标量或形状[1].
  • train_op:适用于训练的步骤.
  • eval_metric_ops:按名称键入的度量结果字典.字典的值是调用度量函数的结果,即(metric_tensor, update_op)元组.应该在没有任何状态影响的情况下进行metric_tensor计算(通常是基于变量的纯计算结果).例如,它不应该触发update_op或需要任何输入提取.
  • export_outputs:描述要在服务期间导出到SavedModel并使用的输出签名.在字典{name: output}中:name:此输出的任意名称.output:一个ExportOutput对象,如ClassificationOutput,RegressionOutput或PredictOutput.Single-headed模型只需要在本字典中指定一个条目.Multi-headed模型应为每个头指定一个条目,其中之一必须使用signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY进行命名.
  • training_chief_hooks:在训练期间可以在主要工作人员中运行的tf.train.SessionRunHook对象的迭代.
  • training_hooks:在训练过程中可以对所有工作人员运行的tf.train.SessionRunHook对象.
  • scaffold:可用于设置初始化,保护程序等用于训练的tf.train.Scaffold对象.
  • evaluation_hooks:评估期间要运行的tf.train.SessionRunHook对象的可迭代性.
  • prediction_hooks:在预测期间可以运行的tf.train.SessionRunHook对象的可迭代性.

返回值:

一个经过验证的EstimatorSpec对象.

可能引发的异常:

  • ValueError:如果验证失败,则会引发此异常.
  • TypeError:如果任何参数不是预期的类型.
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号