TensorFlow函数教程:tf.nn.learned_unigram_candidate_sampler

由 Carrie 创建, 最后一次修改 2019-01-07

tf.nn.learned_unigram_candidate_sampler函数

tf.nn.learned_unigram_candidate_sampler(
    true_classes,
    num_true,
    num_sampled,
    unique,
    range_max,
    seed=None,
    name=None
)

定义在:tensorflow/python/ops/candidate_sampling_ops.py.

请参阅指南:神经网络>候选采样

从训练期间学习的分布中抽取一组类样本.

该操作从整数范围[0, range_max)中随机采样一个采样类(sampled_candidates)的张量.

sampled_candidates的元素是在没有替换 (如果unique=True) 或替换 (如果unique=False) 的基础分布中绘制的.

该操作的基本分布在训练期间即时构建.这是迄今为止在训练期间看到的目标类别的单一分布.[0,range_max]中的每个整数都以权重1开始,并且每次被视为目标类时都会增加1.基本分布不会保存到检查点,因此在重新加载模型时会重置它.

此外,此操作返回张量true_expected_count并sampled_expected_count,表示每个目标类(true_classes)和采样类(sampled_candidates)预期在平均张量的采样类中出现的次数.如果unique=True,则这些是拒绝后概率,我们大致计算它们.

参数:

  • true_classes:一个Tensor,类型为int64,shape为[batch_size, num_true].目标类.
  • num_true:int,每个训练示例的目标类数.
  • num_sampled:int,随机采样的类数.
  • unique:bool,确定批处理中的所有采样类是否都是唯一的.
  • range_max:int,可能的类数.
  • seed:int,特定于操作的seed.默认值为0.
  • name:操作的名称(可选).

返回:

  • sampled_candidates:类型为int64和shape为[num_sampled]的Tensor.采样类.
  • true_expected_count:类型为float的Tensor.shape与true_classes相同.每个true_classes的采样分布下的预期计数.
  • sampled_expected_count:类型为float的Tensor.shape与sampled_candidates相同.每个sampled_candidates的采样分布下的预期计数.
以上内容是否对您有帮助:
二维码
建议反馈
二维码