TensorFlow函数:tf.estimator.WarmStartSettings

由 Carrie 创建, 最后一次修改 2018-05-09

tf.estimator.WarmStartSettings函数

WarmStartSettings类

定义在:tensorflow/python/estimator/warm_starting_util.py.

在Estimators中进行warm-starting的设置.

示例:使用 DNNEstimator 罐头

emb_vocab_file = tf.feature_column.embedding_column(
    tf.feature_column.categorical_column_with_vocabulary_file(
        "sc_vocab_file", "new_vocab.txt", vocab_size=100),
    dimension=8)
emb_vocab_list = tf.feature_column.embedding_column(
    tf.feature_column.categorical_column_with_vocabulary_list(
        "sc_vocab_list", vocabulary_list=["a", "b"]),
    dimension=8)
estimator = tf.estimator.DNNClassifier(
  hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
  warm_start_from=ws)

其中ws可以定义为:

模型中warm-start的所有权重(输入层和隐藏权重).可以提供目录或特定的检查点(在前者的情况下,将使用最新的检查点):

ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")

仅warm-start启动嵌入(输入层)及其累加器变量:

ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
                       vars_to_warm_start=".*input_layer.*")

warm-start除优化器累加器变量(DNN默认为Adagrad)之外的所有内容:

ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
                       vars_to_warm_start="^(?!.*(Adagrad))")

warm-start所有权重,但与sc_vocab_file对应的嵌入参数与当前模型中使用的词汇不同:

vocab_info = ws_util.VocabInfo(
    new_vocab=sc_vocab_file.vocabulary_file,
    new_vocab_size=sc_vocab_file.vocabulary_size,
    num_oov_buckets=sc_vocab_file.num_oov_buckets,
    old_vocab="old_vocab.txt"
)
ws = WarmStartSettings(
    ckpt_to_initialize_from="/tmp",
    var_name_to_vocab_info={
        "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
    })

仅warm-start sc_vocab_file嵌入(并且没有其他变量),它们与当前模型中使用的词汇不同:

vocab_info = ws_util.VocabInfo(
    new_vocab=sc_vocab_file.vocabulary_file,
    new_vocab_size=sc_vocab_file.vocabulary_size,
    num_oov_buckets=sc_vocab_file.num_oov_buckets,
    old_vocab="old_vocab.txt"
)
ws = WarmStartSettings(
    ckpt_to_initialize_from="/tmp",
    vars_to_warm_start=None,
    var_name_to_vocab_info={
        "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
    })

对所有权重进行warm-start,但sc_vocab_file对应的参数与当前检查点中使用的词汇不同,只有100个项被使用:

vocab_info = ws_util.VocabInfo(
    new_vocab=sc_vocab_file.vocabulary_file,
    new_vocab_size=sc_vocab_file.vocabulary_size,
    num_oov_buckets=sc_vocab_file.num_oov_buckets,
    old_vocab="old_vocab.txt",
    old_vocab_size=100
)
ws = WarmStartSettings(
    ckpt_to_initialize_from="/tmp",
    var_name_to_vocab_info={
        "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
    })

warm-start所有权重,但sc_vocab_file对应的参数与当前检查点中使用的词汇不同,sc_vocab_list对应的参数与当前检查点有不同的名称:

vocab_info = ws_util.VocabInfo(
    new_vocab=sc_vocab_file.vocabulary_file,
    new_vocab_size=sc_vocab_file.vocabulary_size,
    num_oov_buckets=sc_vocab_file.num_oov_buckets,
    old_vocab="old_vocab.txt",
    old_vocab_size=100
)
ws = WarmStartSettings(
    ckpt_to_initialize_from="/tmp",
    var_name_to_vocab_info={
        "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
    },
    var_name_to_prev_var_name={
        "input_layer/sc_vocab_list_embedding/embedding_weights":
            "old_tensor_name"
    })

属性:

  • ckpt_to_initialize_from:[必需]一个字符串,用于指定具有检查点文件的目录或检查点的路径,以便从中启动模型参数.
  • vars_to_warm_start:[可选]一个正则表达式,用于捕获要启动哪个变量.默认为'.*',它会warm-start所有变量.如果None明确给出,只有var_name_to_vocab_info中指定的变量将被warm-start.
  • var_name_to_vocab_info:[可选]字典变量名称(字符串)的VocabInfo.变量名称应该是“完整的”变量,而不是分区的名称.如果没有明确提供,则假定该变量没有词汇表.
  • var_name_to_prev_var_name:[可选]将变量名称(字符串)指定为之前ckpt_to_initialize_from中训练的变量的名称.如果未明确提供,则假定变量的名称在前一个检查点和当前模型之间相同.

函数属性

  • ckpt_to_initialize_from

    字段编号0的别名

  • var_name_to_prev_var_name

    字段编号3的别名

  • var_name_to_vocab_info

    字段编号2的别名

  • vars_to_warm_start

    字段编号1的别名

函数方法

__new__

@staticmethod
__new__(
    cls,
    ckpt_to_initialize_from,
    vars_to_warm_start='.*',
    var_name_to_vocab_info=None,
    var_name_to_prev_var_name=None
)
以上内容是否对您有帮助:

二维码
建议反馈
二维码