TensorFlow函数:tf.estimator.RunConfig

由 Carrie 创建, 最后一次修改 2018-05-07

tf.estimator.RunConfig函数

RunConfig类

定义在:tensorflow/python/estimator/run_config.py.

该类指定Estimator运行的配置.

属性

  • cluster_spec
  • evaluation_master
  • global_id_in_cluster

    该global_id_in_cluster属性表示训练集群中的全局标识.

    训练集群中的所有全局ID都是从递增的连续整数序列中分配的,第一个ID是0.

    注意:任务ID(属性字段task_id)正在跟踪具有SAME任务类型的所有节点中的节点索引.例如,给定集群定义如下:
    cluster = {'chief': ['host0:2222'],
               'ps': ['host1:2222', 'host2:2222'],
               'worker': ['host3:2222', 'host4:2222', 'host5:2222']}

    具有任务类型worker的节点可以具有id 0,1,2.具有任务类型ps的节点可以具有id,0,1.因此,task_id不是唯一的,但pair(task_type,task_id)可以唯一确定集群中的节点.

    全局ID即该字段正在跟踪集群中所有节点之间的节点索引.它是唯一分配的.例如,对于上面给出的集群规范,全局id分配为:

    task_type  | task_id  |  global_id
    --------------------------------
    chief      | 0        |  0
    worker     | 0        |  1
    worker     | 1        |  2
    worker     | 2        |  3
    ps         | 0        |  4
    ps         | 1        |  5

    返回:

    一个整数ID.

  • is_chief
  • keep_checkpoint_every_n_hours
  • keep_checkpoint_max
  • log_step_count_steps
  • master
  • model_dir
  • num_ps_replicas
  • num_worker_replicas
  • save_checkpoints_secs
  • save_checkpoints_steps
  • save_summary_steps
  • service

    返回定义的平台(在TF_CONFIG中)服务字典.

  • session_config
  • task_id
  • task_type
  • tf_random_seed

方法

__init__

__init__(
    model_dir=None,
    tf_random_seed=None,
    save_summary_steps=100,
    save_checkpoints_steps=_USE_DEFAULT,
    save_checkpoints_secs=_USE_DEFAULT,
    session_config=None,
    keep_checkpoint_max=5,
    keep_checkpoint_every_n_hours=10000,
    log_step_count_steps=100
)

该方法用于构造一个RunConfig.

所有的分布式训练相关的属性cluster_spec,is_chief,master,num_worker_replicas,num_ps_replicas,task_id和task_type都是基于 TF_CONFIG 环境变量设置的,如果相关的信息存在.TF_CONFIG环境变量是具有属性JSON对象:cluster和task.

cluster是ClusterSpec的Python字典的JSON序列化版本,它将server_lib.py任务类型(通常是TaskType枚举之一)映射到任务地址列表.

task有两个属性:type和index,其中,type可以是cluster中任何类型的任务.当TF_CONFIG包含所述信息,则在该类上设置以下属性:

  • cluster_spec:该属性从TF_CONFIG['cluster']解析,默认为{},如果存在,则在cluster_spec的chief属性中必须有且仅有一个节点.
  • task_type:设置为TF_CONFIG['task']['type'];如果cluster_spec存在,则必须设置;如果cluster_spec没有设置,则必须是worker(默认值).
  • task_id:设置为TF_CONFIG['task']['index'];如果cluster_spec存在,必须设置;如果cluster_spec未设置,则必须为0(默认值).
  • master:master属性是通过在cluster_spec中查找task_type和task_id来确定的,默认为''.
  • num_ps_replicas:是通过计算cluster_spec的ps属性中列出的节点数来设置的,默认为0.
  • num_worker_replicas:是通过计算cluster_spec中的worker和chief属性中列出的节点数来设置的,默认为1.
  • is_chief:是基于task_type和cluster确定的.

有一个带有task_type作为计算器的特殊节点,它不是(训练)cluster_spec的一部分,它处理分布式计算作业.

non-chief节点的例子:

cluster = {'chief': ['host0:2222'],
           'ps': ['host1:2222', 'host2:2222'],
           'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
os.environ['TF_CONFIG'] = json.dumps(
    {'cluster': cluster,
     'task': {'type': 'worker', 'index': 1}})
config = ClusterConfig()
assert config.master == 'host4:2222'
assert config.task_id == 1
assert config.num_ps_replicas == 2
assert config.num_worker_replicas == 4
assert config.cluster_spec == server_lib.ClusterSpec(cluster)
assert config.task_type == 'worker'
assert not config.is_chief

chief的例子:

cluster = {'chief': ['host0:2222'],
           'ps': ['host1:2222', 'host2:2222'],
           'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
os.environ['TF_CONFIG'] = json.dumps(
    {'cluster': cluster,
     'task': {'type': 'chief', 'index': 0}})
config = ClusterConfig()
assert config.master == 'host0:2222'
assert config.task_id == 0
assert config.num_ps_replicas == 2
assert config.num_worker_replicas == 4
assert config.cluster_spec == server_lib.ClusterSpec(cluster)
assert config.task_type == 'chief'
assert config.is_chief

evaluator节点示例(evaluator不是训练集群的一部分):

cluster = {'chief': ['host0:2222'],
           'ps': ['host1:2222', 'host2:2222'],
           'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
os.environ['TF_CONFIG'] = json.dumps(
    {'cluster': cluster,
     'task': {'type': 'evaluator', 'index': 0}})
config = ClusterConfig()
assert config.master == ''
assert config.evaluator_master == ''
assert config.task_id == 0
assert config.num_ps_replicas == 0
assert config.num_worker_replicas == 0
assert config.cluster_spec == {}
assert config.task_type == 'evaluator'
assert not config.is_chief

注意:如果save_checkpoints_steps或save_checkpoints_secs已设置,keep_checkpoint_max可能需要进行相应调整,特别是在分布式训练中.例如,设置save_checkpoints_secs为60而不进行调整keep_checkpoint_max(默认为5)会导致检查点在5分钟后被垃圾收集的情况.在分布式训练中,计算作业异步启动,可能无法加载或由于竞争条件而找到检查点.

参数:

  • model_dir:保存模型参数,图表等的目录.如果有PathLike对象,路径将被解析;如果为None,则将使用Estimator设置的默认值.
  • tf_random_seed:TensorFlow初始化器的随机种子,设置此值可以实现重播之间的一致性.
  • save_summary_steps:每隔这么多步骤保存摘要.
  • save_checkpoints_steps:每隔这么多步骤保存检查点,不能用save_checkpoints_secs指定.
  • save_checkpoints_secs:每隔几秒钟保存检查点,不能用save_checkpoints_steps指定;如果save_checkpoints_steps和save_checkpoints_secs在构造函数中未设置,则默认设置为600秒;如果两个save_checkpoints_steps和save_checkpoints_secs为None,则检查站被禁用.
  • session_config:用于设置会话参数的ConfigProto,或None.
  • keep_checkpoint_max:要保留的最近检查点文件的最大数量.当新文件被创建时,旧文件被删除.如果为None或0,则保留所有检查点文件.默认为5(也就是保留5个最近的检查点文件.)
  • keep_checkpoint_every_n_hours:要保存的每个检查点之间的小时数;默认值10,000小时有效地禁用该功能.
  • log_step_count_steps:在培训期间将记录全局步骤/秒(global step/sec)的频率 (以全局步骤数表示).

可能引发的异常:

  • ValueError:如果同时设置save_checkpoints_steps和save_checkpoints_secs.

replace

replace(**kwargs)

返回RunConfig的新实例替换指定属性.

仅允许替换以下列表中的属性:

  • model_dir
  • tf_random_seed
  • save_summary_steps
  • save_checkpoints_steps
  • save_checkpoints_secs
  • session_config
  • keep_checkpoint_max
  • keep_checkpoint_every_n_hours
  • log_step_count_steps

另外,可以设置save_checkpoints_steps或者save_checkpoints_secs(不应该同时设置).

参数:

  • **kwargs:使用新值命名属性的关键字.

可能引发的异常:

  • ValueError:如果任何属性名kwargs不存在或不允许被替换,或同时设置save_checkpoints_steps和save_checkpoints_secs.

返回值:

一个RunConfig的新的实例.

以上内容是否对您有帮助:
二维码
建议反馈
二维码