TensorFlow定义生成常量的操作

2018-09-19 09:57 更新

#版权所有 2015 TensorFlow 作者. 保留所有权利.

#根据Apache许可证版本2.0(“许可证”)许可;

#除非符合许可证,否则您不得使用此文件.

#您可以获得许可证的副本

#     http://www.apache.org/licenses/LICENSE-2.0

#除非适用法律要求或书面同意,否则根据许可证分发的软件

#均在“按原样”基础上分发,

#无明示或暗示的任何种类的保证或条件.

#请参阅许可证,了解许可证下关于权限和

#限制的具体条款.

# ==============================================================================

"""Operations that generate constants.

See the @{$python/constant_op$constants guide}.

@@zeros
@@zeros_like
@@ones
@@ones_like
@@fill
@@constant
@@linspace
@@range
@@random_normal
@@truncated_normal
@@random_uniform
@@random_shuffle
@@random_crop
@@multinomial
@@random_gamma
@@random_poisson
@@set_random_seed
"""

# Must be separate from array_ops to avoid a cyclic dependency.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util


def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
    """Creates a constant tensor.

    The resulting tensor is populated with values of type `dtype`, as
    specified by arguments `value` and (optionally) `shape` (see examples
    below).

    The argument `value` can be a constant value, or a list of values of type
    `dtype`. If `value` is a list, then the length of the list must be less
    than or equal to the number of elements implied by the `shape` argument
    (if specified). In the case where the list length is less than the number
    of elements specified by `shape`, the last element in the list will be
    used to fill the remaining entries.

    The argument `shape` is optional. If present, it specifies the dimensions
    of the resulting tensor. If not present, the shape of `value` is used.

    If the argument `dtype` is not specified, then the type is inferred from
    the type of `value`.

    For example:

    ```python
    # Constant 1-D Tensor populated with value list.
    tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

    # Constant 2-D tensor populated with scalar value -1.
    tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                                 [-1. -1. -1.]]
    ```

    Args:
      value: A constant value (or list) of output type `dtype`.
      dtype: The type of the elements of the resulting tensor.
      shape: Optional dimensions of resulting tensor.
      name: Optional name for the tensor.
      verify_shape: Boolean that enables verification of a shape of values.

    Returns:
      A Constant Tensor.
    """
    g = ops.get_default_graph()
    # Serialize the value into a TensorProto wrapped in an AttrValue; the
    # proto's dtype field is then reused as the op's "dtype" attribute.
    tensor_value = attr_value_pb2.AttrValue()
    tensor_value.tensor.CopyFrom(
        tensor_util.make_tensor_proto(
            value, dtype=dtype, shape=shape, verify_shape=verify_shape))
    dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    # A "Const" op has no inputs and a single output, which is the tensor.
    const_tensor = g.create_op(
        "Const", [], [dtype_value.type],
        attrs={"value": tensor_value, "dtype": dtype_value},
        name=name).outputs[0]
    return const_tensor


def _constant_tensor_conversion_function(v, dtype=None, name=None,
                                         as_ref=False):
    """Converts `v` to a tensor by delegating to `constant`.

    Args:
      v: The value to convert.
      dtype: Optional element type forwarded to `constant`.
      name: Optional op name forwarded to `constant`.
      as_ref: Accepted for signature compatibility; ignored.

    Returns:
      The tensor returned by `constant(v, dtype=dtype, name=name)`.
    """
    _ = as_ref
    return constant(v, dtype=dtype, name=name)


# NOTE(review): the trailing integer is presumably the conversion priority
# (the generic `object` fallback uses 200, the specific types 100) -- confirm
# against ops.register_tensor_conversion_function.
ops.register_tensor_conversion_function(
    (list, tuple), _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    np.ndarray, _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    np.generic, _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    object, _constant_tensor_conversion_function, 200)


def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None,
                                             as_ref=False):
    """Converts a fully defined `TensorShape` to a constant tensor.

    Args:
      s: The shape to convert; must be fully defined.
      dtype: Optional; must be `dtypes.int32` or `dtypes.int64`. When omitted,
        int64 is chosen if any dimension is >= 2**31, otherwise int32.
      name: Optional op name; defaults to "shape_as_tensor".
      as_ref: Accepted for signature compatibility; ignored.

    Returns:
      A constant tensor built from `s.as_list()`.

    Raises:
      ValueError: If `s` is not fully defined, or if `dtype` is int32 and a
        dimension is >= 2**31.
      TypeError: If `dtype` is given and is neither int32 nor int64.
    """
    _ = as_ref
    if not s.is_fully_defined():
        raise ValueError(
            "Cannot convert a partially known TensorShape to a Tensor: %s" % s)
    s_list = s.as_list()
    # Holds the first dimension that does not fit in int32; 0 means none.
    int64_value = 0
    for dim in s_list:
        if dim >= 2**31:
            int64_value = dim
            break

    if dtype is not None:
        if dtype not in (dtypes.int32, dtypes.int64):
            raise TypeError(
                "Cannot convert a TensorShape to dtype: %s" % dtype)
        if dtype == dtypes.int32 and int64_value:
            raise ValueError("Cannot convert a TensorShape to dtype int32; "
                             "a dimension is too large (%s)" % int64_value)
    else:
        dtype = dtypes.int64 if int64_value else dtypes.int32
    if name is None:
        name = "shape_as_tensor"
    return constant(s_list, dtype=dtype, name=name)


ops.register_tensor_conversion_function(
    tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)


def _dimension_tensor_conversion_function(d, dtype=None, name=None,
                                          as_ref=False):
    """Converts a known `Dimension` to a scalar constant tensor.

    Args:
      d: The dimension to convert; `d.value` must not be None.
      dtype: Optional; must be `dtypes.int32` or `dtypes.int64`. Defaults to
        int32.
      name: Optional op name; defaults to "shape_as_tensor".
      as_ref: Accepted for signature compatibility; ignored.

    Returns:
      A constant tensor holding `d.value`.

    Raises:
      ValueError: If `d.value` is None.
      TypeError: If `dtype` is given and is neither int32 nor int64.
    """
    _ = as_ref
    if d.value is None:
        raise ValueError(
            "Cannot convert an unknown Dimension to a Tensor: %s" % d)
    if dtype is not None:
        if dtype not in (dtypes.int32, dtypes.int64):
            # The message names Dimension (it previously said TensorShape,
            # copied from the TensorShape converter).
            raise TypeError("Cannot convert a Dimension to dtype: %s" % dtype)
    else:
        dtype = dtypes.int32
    if name is None:
        # Default name left unchanged so generated op names stay the same.
        name = "shape_as_tensor"
    return constant(d.value, dtype=dtype, name=name)


ops.register_tensor_conversion_function(
    tensor_shape.Dimension, _dimension_tensor_conversion_function, 100)
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号