TensorFlow Training函数

由 Carrie 创建, 最后一次修改 2017-08-26

tf.train 提供了一组帮助训练模型的类和函数。

优化器

优化器基类提供了计算渐变的方法,并将渐变应用于变量。子类的集合实现了经典的优化算法,如 GradientDescent和Adagrad。

您永远不会实例化优化器类本身,而是实例化其中一个子类。

  • tf.train.Optimizer
  • tf.train.GradientDescentOptimizer
  • tf.train.AdadeltaOptimizer
  • tf.train.AdagradOptimizer
  • tf.train.AdagradDAOptimizer
  • tf.train.MomentumOptimizer
  • tf.train.AdamOptimizer
  • tf.train.FtrlOptimizer
  • tf.train.ProximalGradientDescentOptimizer
  • tf.train.ProximalAdagradOptimizer
  • tf.train.RMSPropOptimizer

梯度计算

TensorFlow 提供了计算给定 TensorFlow 计算图的导数的函数,并将运算添加到图中。优化器类自动在您的关系图上计算派生,但是新的优化或专家用户的创建者可以调用下面的低级函数。

  • tf.gradients
  • tf.AggregationMethod
  • tf.stop_gradient
  • tf.hessians

梯度剪辑

TensorFlow 提供了几种操作,您可以使用它们为您的图形添加剪切功能。您可以使用这些功能执行一般的数据剪辑,但它们对于处理已推翻或消失的渐变特别有用。

  • tf.clip_by_value
  • tf.clip_by_norm
  • tf.clip_by_average_norm
  • tf.clip_by_global_norm
  • tf.global_norm

降低学习率

  • tf.train.exponential_decay
  • tf.train.inverse_time_decay
  • tf.train.natural_exp_decay
  • tf.train.piecewise_constant
  • tf.train.polynomial_decay

移动平均线

一些训练算法,例如 GradientDescent 和动量,通常会在优化过程中保持变量的移动平均值而受益。使用移动平均值进行评估通常会显著改善结果。

  • tf.train.ExponentialMovingAverage

协调员和 QueueRunner

有关如何使用线程和队列的操作,请参见线程和队列。有关队列 API 的文档,请参见队列。

  • tf.train.Coordinator
  • tf.train.QueueRunner
  • tf.train.LooperThread
  • tf.train.add_queue_runner
  • tf.train.start_queue_runners

分布式执行

分布式执行
有关如何配置分布式 TensorFlow 程序的详细信息,请参阅分布式 TensorFlow。

  • tf.train.Server
  • tf.train.Supervisor
  • tf.train.SessionManager
  • tf.train.ClusterSpec
  • tf.train.replica_device_setter
  • tf.train.MonitoredTrainingSession
  • tf.train.MonitoredSession
  • tf.train.SingularMonitoredSession
  • tf.train.Scaffold
  • tf.train.SessionCreator
  • tf.train.ChiefSessionCreator
  • tf.train.WorkerSessionCreator

从事件文件中读取摘要

有关摘要、事件文件和 TensorBoard 中的可视化的概述,请参见摘要和 TensorBoard。

  • tf.train.summary_iterator

Training Hooks

Hooks 是在模型的训练/评估过程中运行的工具:

  • tf.train.SessionRunHook
  • tf.train.SessionRunArgs
  • tf.train.SessionRunContext
  • tf.train.SessionRunValues
  • tf.train.LoggingTensorHook
  • tf.train.StopAtStepHook
  • tf.train.CheckpointSaverHook
  • tf.train.NewCheckpointReader
  • tf.train.StepCounterHook
  • tf.train.NanLossDuringTrainingError
  • tf.train.NanTensorHook
  • tf.train.SummarySaverHook
  • tf.train.GlobalStepWaiterHook
  • tf.train.FinalOpsHook
  • tf.train.FeedFnHook

Training 工具

  • tf.train.global_step
  • tf.train.basic_train_loop
  • tf.train.get_global_step
  • tf.train.assert_global_step
  • tf.train.write_graph
以上内容是否对您有帮助:
二维码
建议反馈
二维码