tf.contrib.seq2seq.SampleEmbeddingHelper
class tf.contrib.seq2seq.SampleEmbeddingHelper
Defined in tensorflow/contrib/seq2seq/python/ops/helper.py.
A helper for use during inference.
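Uses sampling (from a distribution) instead of argmax. The decoder outputs are treated as logits, an id is sampled from the resulting categorical distribution, and that id is passed through an embedding layer to get the next input.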
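To make the contrast with greedy decoding concrete, here is a minimal, framework-free sketch of one inference step in plain Python. It is not the TensorFlow implementation: the names softmax, greedy_id, sampled_id, next_input and the toy EMBEDDING table are hypothetical, and the decoder outputs are assumed to be unnormalized logits over the vocabulary.

```python
import math
import random

# Hypothetical 4-token vocabulary with 2-dimensional embeddings.
EMBEDDING = [
    [0.0, 0.0],
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
]


def softmax(logits):
    """Convert logits to probabilities (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_id(logits):
    """Greedy decoding: pick the highest-scoring id, as argmax would."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sampled_id(logits, rng):
    """Sampling: draw one id from the categorical distribution softmax(logits)."""
    return rng.choices(range(len(logits)), weights=softmax(logits), k=1)[0]


def next_input(token_id):
    """Pass the chosen id through the embedding table to get the next decoder input."""
    return EMBEDDING[token_id]


rng = random.Random(42)  # fixed seed, analogous in spirit to the seed argument
logits = [0.5, 2.0, 1.0, -1.0]
print(greedy_id(logits), next_input(greedy_id(logits)))  # always id 1
tok = sampled_id(logits, rng)  # usually 1, but 0, 2, or 3 are possible
print(tok, next_input(tok))
```

Greedy decoding always returns the same id for the same logits, while sampling can return lower-scoring ids in proportion to their probability, which is what lets this helper produce varied outputs at inference time.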
Properties
batch_size
Methods
__init__
__init__(embedding, start_tokens, end_token, seed=None)
Initializer.
Args:
- embedding: A callable that takes a vector tensor of ids (here, the sampled ids), or the params argument for embedding_lookup. The returned tensor is passed to the decoder as its input.
- start_tokens: int32 vector shaped [batch_size], the start tokens.
- end_token: int32 scalar, the token that marks the end of decoding.
- seed: The sampling seed (optional).
Raises:
- ValueError: if start_tokens is not a 1-D tensor or end_token is not a scalar.
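The argument contract and error conditions above can be illustrated with a plain-Python stand-in. SampleHelperSketch and _is_int are hypothetical names, Python lists stand in for int32 tensors, and a nested list stands in for the embedding matrix; this is a sketch of the documented behavior, not TensorFlow's code.

```python
def _is_int(x):
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(x, int) and not isinstance(x, bool)


class SampleHelperSketch:
    """Hypothetical stand-in that mirrors the documented __init__ contract."""

    def __init__(self, embedding, start_tokens, end_token, seed=None):
        # embedding is either a callable mapping ids to inputs, or a table
        # (the analogue of the params argument to embedding_lookup).
        if callable(embedding):
            self._embedding_fn = embedding
        else:
            table = embedding
            self._embedding_fn = lambda ids: [table[i] for i in ids]
        # start_tokens must be a 1-D sequence of ints, one per batch entry.
        if not isinstance(start_tokens, (list, tuple)) or not all(
            _is_int(t) for t in start_tokens
        ):
            raise ValueError("start_tokens must be a vector")
        # end_token must be a single int.
        if not _is_int(end_token):
            raise ValueError("end_token must be a scalar")
        self._start_tokens = list(start_tokens)
        self._end_token = end_token
        self._seed = seed
        self.batch_size = len(self._start_tokens)
        # The embedded start tokens are the decoder's first inputs.
        self._start_inputs = self._embedding_fn(self._start_tokens)


helper = SampleHelperSketch([[0.0], [1.0], [2.0]], start_tokens=[0, 0], end_token=2, seed=1)
print(helper.batch_size, helper._start_inputs)
```

Accepting either a callable or a table mirrors the two forms the embedding argument is documented to take; the table branch simply wraps a lookup so the rest of the helper only ever calls a function.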
initialize
initialize(name=None)
next_inputs
next_inputs(time, outputs, state, sample_ids, name=None)
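next_inputs_fn inherited from GreedyEmbeddingHelper. Marks a sequence as finished when its sample id equals end_token, and passes sample_ids through the embedding to produce the next decoder inputs.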
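A minimal plain-Python sketch of this step, assuming sample_ids is a list with one id per batch entry. The function name next_inputs_sketch, its argument order, and the toy embedding table are hypothetical. In the sketch, state is passed through unchanged, and when every sequence has finished the start inputs are reused as a placeholder because the next input is no longer consumed.

```python
def next_inputs_sketch(sample_ids, state, end_token, embedding_fn, start_inputs):
    """Hypothetical sketch of the next_inputs step; returns (finished, next_inputs, state)."""
    # A sequence is marked finished when its sampled id equals end_token.
    finished = [sid == end_token for sid in sample_ids]
    if all(finished):
        # Every sequence is done, so the next input is never consumed;
        # reuse the start inputs as a placeholder.
        next_inputs = start_inputs
    else:
        # Otherwise embed the sampled ids to form the next decoder inputs.
        next_inputs = embedding_fn(sample_ids)
    # The decoder state is passed through unchanged.
    return finished, next_inputs, state


table = [[0.0], [1.0], [2.0]]  # hypothetical 3-token embedding table
embed = lambda ids: [table[i] for i in ids]
print(next_inputs_sketch([1, 2], "state", 2, embed, [[0.0], [0.0]]))
# ([False, True], [[1.0], [2.0]], 'state')
```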
sample
sample(time, outputs, state, name=None)
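sample for SampleEmbeddingHelper. Treats outputs as logits and draws one id per batch entry from the resulting categorical distribution, using seed.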
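The sampling step can be sketched in plain Python, assuming outputs is a list of logit rows shaped [batch_size, vocab_size] and that a seeded random.Random plays the role of seed. sample_sketch is a hypothetical name; TensorFlow draws from a categorical distribution in-graph rather than running Python code like this.

```python
import math
import random


def sample_sketch(outputs, rng):
    """Hypothetical sketch: draw one id per batch row, treating each row as logits."""
    sample_ids = []
    for logits in outputs:
        shift = max(logits)  # subtract the max for numerical stability
        weights = [math.exp(x - shift) for x in logits]
        # random.choices normalizes the relative weights, so this samples
        # from softmax(logits) without dividing by the sum explicitly.
        sample_ids.append(rng.choices(range(len(logits)), weights=weights, k=1)[0])
    return sample_ids


outputs = [[0.1, 1.5, 0.3], [2.0, 0.0, 0.5]]  # batch of 2, vocabulary of 3
print(sample_sketch(outputs, random.Random(1234)))
```

Re-creating the generator with the same seed reproduces the same ids, which is the kind of reproducibility the seed argument is meant to provide.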
© 2017 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/SampleEmbeddingHelper