contrib.training.rejection_sample
tf.contrib.training.rejection_sample
rejection_sample(tensors, accept_prob_fn, batch_size, queue_threads=1, enqueue_many=False, prebatch_capacity=16, prebatch_threads=1, runtime_checks=False, name=None)
Defined in tensorflow/contrib/training/python/training/sampling_ops.py.
See the guide: Training (contrib) > Online data resampling
Stochastically creates batches by rejection sampling.
Each list of non-batched tensors is evaluated by accept_prob_fn to produce a scalar tensor between 0 and 1. This tensor corresponds to the probability of being accepted. When batch_size tensor groups have been accepted, the batch queue will return a mini-batch.
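The accept-and-batch behaviour described above can be sketched in plain Python. This is only an illustration of the idea, not the TensorFlow implementation: the helper rejection_sample_batches, its seed argument, and the toy data are all hypothetical, and no queues or threads are involved.

```python
import random


def rejection_sample_batches(examples, accept_prob_fn, batch_size, seed=0):
    """Yield lists of `batch_size` accepted examples (hypothetical helper)."""
    rng = random.Random(seed)
    batch = []
    for example in examples:
        # accept_prob_fn is assumed to return a probability in [0, 1].
        p = accept_prob_fn(example)
        # Accept the example with probability p; otherwise drop it.
        if rng.random() < p:
            batch.append(example)
        # Emit a mini-batch once batch_size examples have been accepted.
        if len(batch) == batch_size:
            yield batch
            batch = []


# Toy data: even numbers are always accepted, odd numbers never.
batches = list(
    rejection_sample_batches(
        range(20), lambda x: 1.0 if x % 2 == 0 else 0.0, batch_size=4
    )
)
```

In this sketch a trailing, incomplete group of accepted examples is simply never emitted, which mirrors the statement that a mini-batch is returned only once batch_size groups have been accepted.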
Args:
- tensors: List of tensors for data. All tensors are either one item or a batch, according to enqueue_many.
- accept_prob_fn: A Python lambda that takes a non-batch tensor from each item in tensors, and produces a scalar tensor.
- batch_size: Size of batch to be returned.
- queue_threads: The number of threads for the queue that will hold the final batch.
- enqueue_many: Bool. If true, interpret input tensors as having a batch dimension.
- prebatch_capacity: Capacity for the large queue that is used to convert batched tensors to single examples.
- prebatch_threads: Number of threads for the large queue that is used to convert batched tensors to single examples.
- runtime_checks: Bool. If true, insert runtime checks on the output of accept_prob_fn. Using True might have a performance impact.
- name: Optional prefix for ops created by this function.
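As a rough picture of what enqueue_many changes, the following plain-Python sketch treats nested lists as stand-ins for tensors. The helper to_single_examples is hypothetical and does not exist in TensorFlow; it only illustrates the idea of interpreting the leading dimension as a batch and splitting it into single examples.

```python
def to_single_examples(tensors, enqueue_many):
    """Return a list of per-example tuples (hypothetical helper)."""
    if not enqueue_many:
        # Each entry in `tensors` already describes exactly one example.
        return [tuple(tensors)]
    # With enqueue_many, the leading dimension of every entry is the batch.
    batch_dims = {len(t) for t in tensors}
    if len(batch_dims) != 1:
        raise ValueError("batch dimensions do not match")
    # Pair up the i-th slice of every entry to form the i-th example.
    return list(zip(*tensors))


# One example: a 2-element data vector and a scalar label.
single = to_single_examples([[1.0, 2.0], 7], enqueue_many=False)

# A batch of two examples, split back into single examples.
many = to_single_examples([[[1.0, 2.0], [3.0, 4.0]], [7, 8]], enqueue_many=True)
```

Here single holds one (data, label) pair and many holds two, which is the form in which accept_prob_fn would see each example.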
Raises:
- ValueError: enqueue_many is True and labels doesn't have a batch dimension, or if enqueue_many is False and labels isn't a scalar.
- ValueError: enqueue_many is True, and batch dimension on data and labels don't match.
- ValueError: if a zero initial probability class has a nonzero target probability.
Returns:
A list of tensors of the same length as tensors, with batch dimension batch_size.
Example:

    # Get tensor for a single data and label example.
    data, label = data_provider.Get(['data', 'label'])

    # Get stratified batch according to data tensor.
    accept_prob_fn = lambda x: (tf.tanh(x[0]) + 1) / 2
    data_batch = tf.contrib.training.rejection_sample(
        [data, label], accept_prob_fn, 16)

    # Run batch through network.
    ...
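The acceptance function in the example rescales tanh so that its output can be read as a probability. A small plain-Python check of that mapping, using math.tanh on scalars in place of tf.tanh on tensors (the name accept_prob and the sample inputs are only for illustration):

```python
import math


def accept_prob(x):
    """Scalar analogue of (tf.tanh(x) + 1) / 2 from the example."""
    # tanh(x) lies in (-1, 1), so the result lies in (0, 1).
    return (math.tanh(x) + 1) / 2


# Larger inputs are accepted more often; an input of 0.0 maps to 0.5.
probs = [accept_prob(v) for v in (-3.0, 0.0, 3.0)]
```

Under this function, examples with strongly negative values are rarely kept and examples with strongly positive values are almost always kept.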
© 2017 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/training/rejection_sample