TensorFlow函数教程:tf.nn.nce_loss

2020-10-20 15:17 更新

tf.nn.nce_loss函数

tf.nn.nce_loss(
    weights,
    biases,
    labels,
    inputs,
    num_sampled,
    num_classes,
    num_true=1,
    sampled_values=None,
    remove_accidental_hits=False,
    partition_strategy='mod',
    name='nce_loss'
)

定义在:tensorflow/python/ops/nn_impl.py.

请参阅指南:神经网络>候选采样

计算并返回噪声对比估计(NCE, Noise Contrastive Estimation)训练损失.

一个常见的用例是使用此方法进行训练,并计算完整的S形模型损失以进行评估或推断.在这种情况下,您必须将partition_strategy="div",使两个损失保持一致,如下例所示:

if mode == "train":
  loss = tf.nn.nce_loss(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      ...,
      partition_strategy="div")
elif mode == "eval":
  logits = tf.matmul(inputs, tf.transpose(weights))
  logits = tf.nn.bias_add(logits, biases)
  labels_one_hot = tf.one_hot(labels, n_classes)
  loss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=labels_one_hot,
      logits=logits)
  loss = tf.reduce_sum(loss, axis=1)
注意:默认情况下,它使用对数均匀(Zipfian)分布进行采样,因此必须按照频率递减的顺序对标签进行排序,以获得良好的结果.有关详细信息,请参阅tf.nn.log_uniform_candidate_sampler.
注意:在num_true> 1 的情况下,我们为每个目标类分配目标概率1/num_true,以便目标概率总和为每个示例1.
注意:每个示例允许目标类的变量数量是有用的.我们希望在将来的版本中提供此功能.现在,如果你有一个目标类的变量数量,你可以通过重复它们或通过填充其他未使用的类来将它们填充到一个常数. 

参数:

  • weights:一个Tensor,shape[num_classes, dim],或者是Tensor对象列表,其沿着维度0的连接具有shape [num_classes,dim].(可能是分区的)类嵌入.
  • biases:一个Tensor,shape[num_classes].类偏差.
  • labels:一个Tensor,类型为int64shape [batch_size, num_true].目标类.
  • inputs:一个Tensor,shape [batch_size, dim].输入网络的正向激活.
  • num_sampled:int,每批随机抽样的类数.
  • num_classes:int,可能的类数.
  • num_true:int,每个训练示例的目标类数.
  • sampled_values:由* _candidate_sampler函数返回的元组(sampled_candidates,true_expected_count,sampled_expected_count).(如果是None,我们默认为log_uniform_candidate_sampler)
  • remove_accidental_hits:bool.是否删除“意外命中”,其中采样类等于其中一个目标类.如果设置为True,则这是“Sampled Logistic”损失而不是NCE,我们正在学习生成对数赔率而不是对数概率.请参阅[候选采样算法参考](https://www.tensorflow.org/extras/candidate_sampling.pdf).默认值为False.
  • partition_strategy:指定分区策略的字符串,如果len(weights) > 1.目前"div""mod"受到支持.默认是"mod".更多详细信息,请参阅tf.nn.embedding_lookup
  • name:操作的名称(可选).

返回:

每个示例NCE损失的batch_size 1-D张量.


以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号