TensorFlow的estimator类函数:tf.estimator.Estimator

2018-04-27 09:55 更新

tf.estimator.Estimator函数

Estimator类

定义在:tensorflow/python/estimator/estimator.py

estimator类对TensorFlow模型进行训练和计算.

Estimator对象包装由model_fn指定的模型,其中,给定输入和其他一些参数,返回需要进行训练、计算,或预测的操作.

所有输出(检查点,事件文件等)都被写入model_dir或其子目录.如果model_dir未设置,则使用临时目录.

可以通过RunConfig对象(包含了有关执行环境的信息)传递config参数.它被传递给model_fn,如果model_fn有一个名为“config”的参数(和输入函数以相同的方式).如果该config参数未被传递,则由Estimator进行实例化.不传递配置意味着使用对本地执行有用的默认值.Estimator使配置对模型可用(例如,允许根据可用的工作人员数量进行专业化),并且还使用其一些字段来控制内部,特别是关于检查点.

该params参数包含hyperparameter,如果model_fn有一个名为“PARAMS”的参数,并且以相同的方式传递给输入函数,则将它传递给 model_fn.Estimator只是沿着参数传递,并不检查它.因此,params的结构完全取决于开发人员.

不能在子类中重写任何Estimator方法(其构造函数强制执行此操作).子类应使用model_fn来配置基类,并且可以添加实现专门功能的方法.

Eager兼容性

estimator与eager执行不兼容.

属性

  • config
  • model_dir
  • model_fn
    返回绑定到self.params的model_fn.
    返回:返回具有以下签名的model_fn: def model_fn(features, labels, mode, config)
  • params

方法

__init__

__init__(
    model_fn,
    model_dir=None,
    config=None,
    params=None,
    warm_start_from=None
)

构造一个Estimator实例.

请参阅Estimator了解更多信息.启动一个Estimator的方法如下所示:

estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    warm_start_from="/path/to/checkpoint/dir")

有关warm-start启动配置的更多详细信息,请参阅WarmStartSettings.

参数:

  • model_fn:模型函数,具有以下签名:
      ARGS.
    • features:这是从input_fn传递给train、evaluate和predict返回的第一个项目.这应该是一个相同的单一的Tensor或dict.
    • labels:这是从input_fn传递给train、evaluate和predict返回的第二个项目.这应该是相同的单个Tensor或dict(对于multi-head模型).如果模式是ModeKeys.PREDICT,则将传递labels=None.如果model_fn签名不接受mode,model_fn必须仍然能够处理labels=None.
    • mode:可选的.指定train、evaluate和predict.参考ModeKeys.
    • params:hyperparameters的可选字典.将在params参数中接收传递给Estimator的内容.这允许从hyperparameters调整来配置Estimator.
    • config:可选配置对象.将收到传递给Estimator的config参数或默认值config.允许根据配置(如num_ps_replicas或model_dir)更新您的model_fn中的内容.
    • 返回:EstimatorSpec
  • model_dir:保存模型参数、图形等的目录.这也可用于将目录中的检查点加载到Estimator中,以继续训练以前保存的模型.如果为PathLike对象,则路径将被解析.如果为None,则将使用config中的model_dir(如果设置的话).如果两者都设置,则它们必须相同.如果两者都是None,则会使用临时目录.
  • config:配置对象.
  • params:dict将传递到model_fn中的hyperparameters.key是参数的名称,value是基本的Python类型.
  • warm_start_from:可选的字符串文件路径,用于从warm-start的检查点;或tf.estimator.WarmStartSettings对象,用于完全配置warm-start.如果提供字符串文件路径而不是WarmStartSettings,则所有变量都是warm-start的,并且假定词汇表和张量名称未更改.

可能引发的异常:

  • RuntimeError:如果eager执行已启用.
  • ValueError:参数model_fn不匹配params.
  • ValueError:如果这是通过子类调用的,并且该类重写了Estimator的一个成员.

evaluate

evaluate(
    input_fn,
    steps=None,
    hooks=None,
    checkpoint_path=None,
    name=None
)

计算给定计算数据input_fn的模型.

对于每个步骤来说,调用input_fn返回一批数据.计算直到: -steps批处理被处理,或-input_fn引发输入结束异常(OutOfRangeError或StopIteration).

参数:

  • input_fn:构造用于计算的输入数据的函数.有关更多信息,请参阅TensorFlow入门.该函数应该构造并返回下列选项之一:
    • tf.data.Dataset对象:Dataset对象的输出必须是一个具有相同约束的元组(特征(features),标签(labels)),其约束条件与下面相同.
    • tuple (features, labels):其中features是Tensor或者名为Tensor的字符串特征的字典,而labels是Tensor或者名为Tensor的字符串标签的字典.这两个特征和标签都由model_fn消耗.他们应该满足model_fn对输入的期望.
  • steps:计算模型所需的步骤数.如果为None,则计算直到input_fn引发输入异常时结束.
  • hooks:SessionRunHook子类实例列表.用于计算调用中的回调.
  • checkpoint_path:计算特定检查点的路径.如果为None,则使用model_dir中的最新检查点.
  • name:需要使用的计算的名称,如果用户需要在不同的数据集上运行多个计算(如培训数据和测试数据).不同计算的度量标准保存在单独的文件夹中,并单独出现在tensorboard中.

返回值:

返回一个包含按name为键的model_fn中指定的计算指标的词典,以及包含执行此技术的全局步骤的值的条目global_step.

可能引发的异常:

  • ValueError:如果steps <= 0.
  • ValueError:如果没有模型被训练,名为model_dir,或者给定checkpoint_path是空的.

export_savedmodel

export_savedmodel(
    export_dir_base,
    serving_input_receiver_fn,
    assets_extra=None,
    as_text=False,
    checkpoint_path=None,
    strip_default_attrs=False
)

将推理图作为SavedModel导出到给定的目录中.

该方法通过首先调用serving_input_receiver_fn来获取特征Tensors来构建一个新图,然后调用这个Estimator的model_fn来基于这些特征生成模型图.它在新的会话中将给定的检查点恢复到该图中.最后它会在给定的export_dir_base下面创建一个时间戳导出目录,并在其中写入一个SavedModel,其中包含从此会话保存的单个MetaGraphDef.

导出的MetaGraphDef将为从model_fn返回的export_outputs字典的每个元素提供一个SignatureDef,该字典使用相同的key命名.其中一个key始终为signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,指示在服务请求未指定签名时将提供哪个签名.对于每个签名,输出由相应的ExportOutputs提供,并且输入始终是由serving_input_receiver_fn提供的输入接收器.

额外的资产可以通过assets_extra参数写入SavedModel.这应该是一个字典,其中每个key给出与assets.extra目录相关的目标路径(包括文件名).相应的值给出了要复制的源文件的完整路径.例如,在不重命名的情况下复制单个文件的简单情况被指定为{'my_asset_file.txt': '/path/to/my_asset_file.txt'}.

参数:

  • export_dir_base:包含一个目录的字符串,在该目录中创建包含导出的SavedModels的时间戳子目录.
  • serving_input_receiver_fn:一个不带参数并返回一个ServingInputReceiver的函数.
  • assets_extra:指定如何在导出的SavedModel中填充assets.extra目录的字典,如果不需要额外的资产,则为 None.
  • as_text:是否以文本格式编写SavedModel原型.
  • checkpoint_path:要导出的检查点路径.如果None(默认),则选择在模型目录中找到的最近检查点.
  • strip_default_attrs:布尔值.如果True,则将从NodeDefs中删除默认值属性.

返回值:

导出目录的字符串路径.

可能引发的异常:

  • ValueError:如果未提供serving_input_receiver_fn,则不提供export_outputs,或者找不到检查点.

get_variable_names

get_variable_names()

返回此模型中所有变量名称的列表.

返回值:

返回名字列表.

可能引发的异常:

  • ValueError:如果Estimator尚未产生检查点.

get_variable_value

get_variable_value(name)

返回由名称给出的变量的值.

参数:

  • name:字符串或字符串列表,张量的名称.

返回值:

Numpy数组 - 张量的值.

可能引发的异常:

  • ValueError:如果Estimator尚未产生检查点.

latest_checkpoint

latest_checkpoint()

查找model_dir中最新保存的检查点文件的文件名.

返回值:

返回最新检查点的完整路径或None(未找到检查点).

predict

predict(
    input_fn,
    predict_keys=None,
    hooks=None,
    checkpoint_path=None
)

对给定的features产生预测.

参数:

  • input_fn:构造特征的函数.预测继续,直到input_fn引发输入端异常(OutOfRangeError或StopIteration).有关更多信息,请参阅TensorFlow入门.该函数应该构造并返回下列之一:
    • tf.data.Dataset对象:Dataset对象的输出必须具有与下面相同的约束.
    • features:一个Tensor或者名为Tensor的字符串特征的字典.feature被model_fn消耗.他们应该满足model_fn对输入的期望.
    • 一个元组,在这种情况下,第一个项被提取为feature.
  • predict_keys:str列表,要预测的键名称.如果EstimatorSpec.predictions是字典,则使用该方法.如果使用predict_keys,则剩余的预测将从字典中过滤.如果None,则返回全部.
  • hooks:SessionRunHook子类实例列表.用于预测调用中的回调.
  • checkpoint_path:要预测的特定检查点的路径.如果为None,则使用model_dir中的最新的检查点.

返回值:

predictions张量的计算值.

可能引发的异常:

  • ValueError:在model_dir中找不到训练有素的模型.
  • ValueError:如果批次的预测长度不相同.
  • ValueError:如果predict_keys和predictions之间有冲突.例如,如果predict_keys不是None,但EstimatorSpec.predictions不是一个dict.

train

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

训练给定训练数据input_fn的模型.

参数:

  • input_fn:提供作为minibatches培训的输入数据的函数.有关更多信息,请参阅TensorFlow入门.该函数应该构造并返回下列之一:
    • tf.data.Dataset对象:Dataset对象的输出必须是一个具有相同约束的元组(特征,标签)((features, labels)),其约束条件与下面相同.
    • tuple (features, labels):其中features是一个Tensor或者名为Tensor的字符串特征的字典,labels是一个Tensor或者名为Tensor的字符串标签的字典.这两个特征和标签都由model_fn消耗.他们应该满足model_fn对输入的期望.
  • hooks:SessionRunHook子类实例列表.用于训练循环内的回调.
  • steps:训练模型的步骤数.如果为None,则永远训练或训练直到input_fn产生OutOfRange错误或StopIteration异常.“steps”逐步运作.如果您调用两次train(steps=10),则训练总共发生20个步骤.如果OutOfRange或StopIteration发生在中间,训练在20步之前停止.如果你不想有增量行为,请改为设置.如果设置max_steps,max_steps必须None.
  • max_steps:训练模型的总步骤数.如果为None,则永远训练或训练直到input_fn产生OutOfRange错误或StopIteration异常.如果设置,steps必须None.如果OutOfRange或StopIteration发生在中间,训练在max_steps步骤之前停止.两次调用train(steps=100)意味着200次训练迭代.另一方面,两次调用train(max_steps=100)意味着第二次调用将不会做任何迭代,因为第一次调用完成了所有100个步骤.
  • saving_listeners:CheckpointSaverListener对象列表.用于在检查点节省之前或之后立即执行的回调.

可能引发的异常:

  • ValueError:如果steps和max_steps都不是None.
  • ValueError:如果steps或max_steps其中之一小于等于0.
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号