TensorFlow服务时间内的接收设置

2018-10-09 18:31 更新

#版权所有2017 TensorFlow作者.保留所有权利.

#根据Apache许可证版本2.0(“许可证”)许可;

#除非符合许可证,否则您不得使用此文件.

#您可以获得许可证的副本

#http://www.apache.org/licenses/LICENSE-2.0

#除非适用法律要求或书面同意,否则根据许可证分发的软件

#均在“按原样”基础上分发,

#无明示或暗示的任何种类的保证或条件.

#查看有关权限的特定语言的许可证

#以及许可证下的限制.

# ============================================================================

""在服务时间接收输入的配置和实用程序""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import time

import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.util import compat

# Keys used when a bare (non-dict) value is passed to ServingInputReceiver.
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'


class ServingInputReceiver(collections.namedtuple(
        'ServingInputReceiver', ['features', 'receiver_tensors'])):
    """A return type for a serving_input_receiver_fn.

    The expected return values are:
      features: A dict of string to `Tensor` or `SparseTensor`, specifying
        the features to be passed to the model. A non-dict value is wrapped
        into a single-entry dict under the key 'feature'.
      receiver_tensors: a `Tensor`, or a dict of string to `Tensor`,
        specifying input nodes where this receiver expects to be fed.
        Typically, this is a single placeholder expecting serialized
        `tf.Example` protos. A non-dict value is wrapped into a single-entry
        dict under the key 'input'.
    """
    # TODO(soergel): add receiver_alternatives when supported in serving.

    def __new__(cls, features, receiver_tensors):
        """Validate both arguments and normalize each of them to a dict.

        Raises:
          ValueError: if either argument is None, if any dict key is not a
            string, if a feature value is neither a `Tensor` nor a
            `SparseTensor`, or if a receiver value is not a `Tensor`.
        """
        if features is None:
            raise ValueError('features must be defined.')
        if not isinstance(features, dict):
            features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
        for name, tensor in features.items():
            if not isinstance(name, six.string_types):
                raise ValueError(
                    'feature keys must be strings: {}.'.format(name))
            if not isinstance(
                    tensor, (ops.Tensor, sparse_tensor.SparseTensor)):
                raise ValueError(
                    'feature {} must be a Tensor or SparseTensor.'.format(
                        name))

        if receiver_tensors is None:
            raise ValueError('receiver_tensors must be defined.')
        if not isinstance(receiver_tensors, dict):
            receiver_tensors = {
                _SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
        for name, tensor in receiver_tensors.items():
            if not isinstance(name, six.string_types):
                raise ValueError(
                    'receiver_tensors keys must be strings: {}.'.format(name))
            # Unlike features, receivers may not be SparseTensors.
            if not isinstance(tensor, ops.Tensor):
                raise ValueError(
                    'receiver_tensor {} must be a Tensor.'.format(name))

        return super(ServingInputReceiver, cls).__new__(
            cls, features=features, receiver_tensors=receiver_tensors)


def build_parsing_serving_input_receiver_fn(feature_spec,
                                            default_batch_size=None):
    """Build a serving_input_receiver_fn expecting fed tf.Examples.

    Creates a serving_input_receiver_fn that expects a serialized tf.Example
    fed into a string placeholder. The function parses the tf.Example
    according to the provided feature_spec, and returns all parsed Tensors as
    features.

    Args:
      feature_spec: a dict of string to `VarLenFeature`/`FixedLenFeature`.
      default_batch_size: the number of query examples expected per batch.
        Leave unset for variable batch size (recommended).

    Returns:
      A serving_input_receiver_fn suitable for use in serving.
    """
    def serving_input_receiver_fn():
        """An input_fn that expects a serialized tf.Example."""
        serialized_tf_example = array_ops.placeholder(
            dtype=dtypes.string,
            shape=[default_batch_size],
            name='input_example_tensor')
        receiver_tensors = {'examples': serialized_tf_example}
        features = parsing_ops.parse_example(
            serialized_tf_example, feature_spec)
        return ServingInputReceiver(features, receiver_tensors)

    return serving_input_receiver_fn


def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
    """Build a serving_input_receiver_fn expecting feature Tensors.

    Creates an serving_input_receiver_fn that expects all features to be fed
    directly.

    Args:
      features: a dict of string to `Tensor`.
      default_batch_size: the number of query examples expected per batch.
        Leave unset for variable batch size (recommended).

    Returns:
      A serving_input_receiver_fn.
    """
    def serving_input_receiver_fn():
        """A serving_input_receiver_fn that expects features to be fed directly."""
        receiver_tensors = {}
        for name, t in features.items():
            # The leading dimension is replaced by the requested batch size.
            # NOTE(review): this indexes position 0, so it assumes every
            # feature has a known rank of at least 1 -- confirm with callers.
            shape_list = t.get_shape().as_list()
            shape_list[0] = default_batch_size
            shape = tensor_shape.TensorShape(shape_list)

            # Reuse the feature tensor name for the placeholder, excluding
            # the index.
            placeholder_name = t.name.split(':')[0]
            receiver_tensors[name] = array_ops.placeholder(
                dtype=t.dtype, shape=shape, name=placeholder_name)
        # TODO(b/34885899): remove the unnecessary copy
        # The features provided are simply the placeholders, but we
        # defensively copy the dict because it may be mutated.
        return ServingInputReceiver(receiver_tensors, receiver_tensors.copy())

    return serving_input_receiver_fn


### Below utilities are specific to SavedModel exports.
def build_all_signature_defs(receiver_tensors, export_outputs):
    """Build `SignatureDef`s for all export outputs.

    Args:
      receiver_tensors: a dict of receiver tensors, or a single value, which
        is wrapped into a one-entry dict under the key 'receiver'.
      export_outputs: a dict mapping signature keys to objects exposing
        `as_signature_def(receiver_tensors)`.

    Returns:
      A dict mapping each signature key to the result of calling
      `as_signature_def` on the corresponding export output. A falsy key
      (e.g. `None`) is stored under the literal string 'None'.

    Raises:
      ValueError: if `export_outputs` is not a dict (including `None`).
    """
    if not isinstance(receiver_tensors, dict):
        # NOTE(review): this fallback key is 'receiver', whereas
        # ServingInputReceiver wraps a bare tensor under 'input'. Kept as is
        # because the key ends up in the exported signature.
        receiver_tensors = {'receiver': receiver_tensors}
    # isinstance() is already False for None, so one check covers both cases.
    if not isinstance(export_outputs, dict):
        raise ValueError('export_outputs must be a dict.')

    return {
        '{}'.format(output_key or 'None'):
            export_output.as_signature_def(receiver_tensors)
        for output_key, export_output in export_outputs.items()
    }


def get_timestamped_export_dir(export_dir_base):
    """Builds a path to a new subdirectory within the base directory.

    Each export is written into a new subdirectory named using the current
    time, so that version numbers increase across multiple runs of the
    pipeline. The timestamp used is the number of seconds since epoch UTC,
    truncated to an integer; two calls within the same second therefore
    produce the same path.

    Args:
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.

    Returns:
      The full path of the new subdirectory (which is not actually created
      yet), built from the `compat.as_bytes` form of both path components.
    """
    export_timestamp = int(time.time())
    return os.path.join(
        compat.as_bytes(export_dir_base),
        compat.as_bytes(str(export_timestamp)))
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号