TensorFlow函数教程:tf.nn.weighted_cross_entropy_with_logits

tf.nn.weighted_cross_entropy_with_logits函数

tf.nn.weighted_cross_entropy_with_logits(
    targets,
    logits,
    pos_weight,
    name=None
)

定义在:tensorflow/python/ops/nn_impl.py。

计算加权交叉熵。

类似于sigmoid_cross_entropy_with_logits(),除了pos_weight,允许人们通过向上或向下加权相对于负误差的正误差的成本来权衡召回率和精确度。

通常的交叉熵成本定义为:

targets * -log(sigmoid(logits)) +
    (1 - targets) * -log(1 - sigmoid(logits))

值pos_weights > 1减少了假阴性计数,从而增加了召回率。相反设置pos_weights < 1会减少假阳性计数并提高精度。从一下内容可以看出pos_weight是作为损失表达式中的正目标项的乘法系数引入的:

targets * -log(sigmoid(logits)) * pos_weight +
    (1 - targets) * -log(1 - sigmoid(logits))

为了简便起见,让x = logits,z = targets,q = pos_weight。损失是:

  qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
= (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

设置l = (1 + (q - 1) * z),确保稳定性并避免溢出,使用一下内容来实现:

(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

logits和targets必须具有相同的类型和形状。

参数:

  • targets:一个Tensor,与logits具有相同的类型和形状。
  • logits:一个Tensor,类型为float32或float64。
  • pos_weight:正样本中使用的系数。
  • name:操作的名称(可选)。

返回:

与具有分量加权逻辑损失的logits具有相同形状的Tensor。

可能引发的异常:

  • ValueError:如果logits和targets没有相同的形状。
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号

意见反馈
返回顶部