Commit cb8cdf73 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Fix label_smoothing for sigmoid_cross_entropy_loss, and test for softmax_cross_entropy_loss with label_smoothing.

The sigmoid cross entropy with label smoothing was broken, and, worse, the tests for both it and the softmax cross entropy with label smoothing were broken. I've fixed both issues here and added comments walking through the two examples in the tests, so as not to inadvertently check in broken tests again.
Change: 127100213
Parent a9e87682
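The commit fixes two different smoothing targets: sigmoid labels are smoothed towards 1/2, while one-hot softmax labels are smoothed towards 1/num_classes. As a quick illustration, here is a minimal standalone sketch (plain Python, not part of this commit; the helper names are hypothetical) of the two transforms:

# Minimal sketch of the two label-smoothing rules this commit fixes and tests.
# Not part of the commit; helper names are hypothetical.

def smooth_sigmoid_labels(labels, label_smoothing):
  # Each independent sigmoid label is pulled towards 1/2:
  #   new_label = label * (1 - L) + 0.5 * L
  return [z * (1 - label_smoothing) + 0.5 * label_smoothing for z in labels]

def smooth_softmax_labels(onehot_labels, label_smoothing):
  # A one-hot row is pulled towards the uniform distribution 1/num_classes:
  #   new_label = label * (1 - L) + L / num_classes
  n = len(onehot_labels)
  return [z * (1 - label_smoothing) + label_smoothing / n for z in onehot_labels]

print(smooth_sigmoid_labels([1, 0, 1], 0.1))   # [0.95, 0.05, 0.95]
print(smooth_softmax_labels([1, 0, 0], 0.1))   # [0.9333..., 0.0333..., 0.0333...]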
@@ -282,6 +282,15 @@ def sigmoid_cross_entropy(logits, multi_class_labels, weight=1.0,
                           label_smoothing=0, scope=None):
   """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits.
 
+  `weight` acts as a coefficient for the loss. If a scalar is provided,
+  then the loss is simply scaled by the given value. If `weight` is a
+  tensor of size [`batch_size`], then the loss weights apply to each
+  corresponding sample.
+
+  If `label_smoothing` is nonzero, smooth the labels towards 1/2:
+      new_multiclass_labels = multiclass_labels * (1 - label_smoothing)
+                              + 0.5 * label_smoothing
+
   Args:
     logits: [batch_size, num_classes] logits outputs of the network .
     multi_class_labels: [batch_size, num_classes] target labels in (0, 1).
@@ -292,56 +301,46 @@ def sigmoid_cross_entropy(logits, multi_class_labels, weight=1.0,
   Returns:
     A scalar `Tensor` representing the loss value.
 
+  Raises:
+    ValueError: If the shape of `predictions` doesn't match that of `targets` or
+      if the shape of `weight` is invalid or if `weight` is None.
   """
   with ops.op_scope([logits, multi_class_labels],
                     scope, "sigmoid_cross_entropy_loss"):
-    return _cross_entropy(logits, multi_class_labels, weight,
-                          label_smoothing,
-                          activation_fn=nn.sigmoid_cross_entropy_with_logits)
-
-
-def softmax_cross_entropy(logits, onehot_labels, weight=1.0,
-                          label_smoothing=0, scope=None):
-  """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.
-
-  It can scale the loss by weight factor, and smooth the labels.
-
-  Args:
-    logits: [batch_size, num_classes] logits outputs of the network .
-    onehot_labels: [batch_size, num_classes] target one_hot_encoded labels.
-    weight: Coefficients for the loss. The tensor must be a scalar or a tensor
-      of shape [batch_size].
-    label_smoothing: If greater than 0 then smooth the labels.
-    scope: the scope for the operations performed in computing the loss.
-
-  Returns:
-    A scalar `Tensor` representing the loss value.
-  """
-  with ops.op_scope([logits, onehot_labels],
-                    scope, "softmax_cross_entropy_loss"):
-    return _cross_entropy(logits, onehot_labels, weight,
-                          label_smoothing,
-                          activation_fn=nn.softmax_cross_entropy_with_logits)
+    logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
 
+    multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)
 
-def _cross_entropy(logits, onehot_labels, weight, label_smoothing,
-                   activation_fn):
-  """Adds a CrossEntropyLoss to the losses collection.
+    if label_smoothing > 0:
+      multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
+                            0.5 * label_smoothing)
+
+    losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
+                                                  name="xentropy")
+    return _compute_weighted_loss(losses, weight)
+
+
+def softmax_cross_entropy(logits, onehot_labels, weight=1.0,
+                          label_smoothing=0, scope=None):
+  """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.
 
   `weight` acts as a coefficient for the loss. If a scalar is provided,
   then the loss is simply scaled by the given value. If `weight` is a
   tensor of size [`batch_size`], then the loss weights apply to each
   corresponding sample.
 
+  If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes:
+      new_onehot_labels = onehot_labels * (1 - label_smoothing)
+                          + label_smoothing / num_classes
+
   Args:
     logits: [batch_size, num_classes] logits outputs of the network .
     onehot_labels: [batch_size, num_classes] target one_hot_encoded labels.
-    weight: Coefficients for the loss. If the activation is SIGMOID, then the
-      weight shape must be one of [1], [batch_size] or logits.shape().
-      Otherwise, the weight shape must be either [1] or [batch_size].
+    weight: Coefficients for the loss. The tensor must be a scalar or a tensor
+      of shape [batch_size].
     label_smoothing: If greater than 0 then smooth the labels.
-    activation_fn: The activation function to use. The method must take three
-      arguments, the logits, the labels, and an operation name.
+    scope: the scope for the operations performed in computing the loss.
 
   Returns:
     A scalar `Tensor` representing the loss value.
@@ -350,20 +349,21 @@ def _cross_entropy(logits, onehot_labels, weight, label_smoothing,
   Raises:
     ValueError: If the shape of `predictions` doesn't match that of `targets` or
      if the shape of `weight` is invalid or if `weight` is None.
   """
-  logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
-  if weight is None:
-    raise ValueError("`weight` cannot be None")
-  onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
-  if label_smoothing > 0:
-    num_classes = onehot_labels.get_shape()[1].value
-    smooth_positives = 1.0 - label_smoothing
-    smooth_negatives = label_smoothing / num_classes
-    onehot_labels = onehot_labels * smooth_positives + smooth_negatives
-  losses = activation_fn(logits, onehot_labels, name="xentropy")
-  return _compute_weighted_loss(losses, weight)
+  with ops.op_scope([logits, onehot_labels],
+                    scope, "softmax_cross_entropy_loss"):
+    logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
+
+    onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
+
+    if label_smoothing > 0:
+      num_classes = math_ops.to_float(array_ops.shape(onehot_labels)[1])
+      smooth_positives = 1.0 - label_smoothing
+      smooth_negatives = label_smoothing / num_classes
+      onehot_labels = onehot_labels * smooth_positives + smooth_negatives
+
+    losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
+                                                  name="xentropy")
+    return _compute_weighted_loss(losses, weight)
 
 
 def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None):
...
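As a sanity check on the arithmetic used by the updated sigmoid test below, here is a standalone sketch (plain Python rather than TensorFlow, not part of this commit) that reproduces the expected value (100 + 50*L) / 3 from the stable per-element formulation max(x, 0) - x*z + log(1 + exp(-|x|)):

# Standalone check of the sigmoid test's expected value; not part of the commit.
import math

def sigmoid_xent(x, z):
  # Stable elementwise sigmoid cross entropy: max(x, 0) - x*z + log(1 + exp(-|x|)).
  return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))

L = 0.1
logits = [100.0, -100.0, -100.0]
labels = [1.0, 0.0, 1.0]
smoothed = [z * (1 - L) + 0.5 * L for z in labels]   # [0.95, 0.05, 0.95]
loss = sum(sigmoid_xent(x, z) for x, z in zip(logits, smoothed)) / len(logits)
print(loss)                        # ~35.0
print((100.0 + 50.0 * L) / 3.0)    # 35.0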
@@ -215,13 +215,23 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase):
   def testSoftmaxLabelSmoothing(self):
     with self.test_session():
+      # Softmax Cross Entropy Loss is:
+      #   -\sum_i p_i \log q_i
+      # where for a softmax activation
+      # \log q_i = x_i - \log \sum_j \exp x_j
+      #          = x_i - x_max - \log \sum_j \exp (x_j - x_max)
+      # For our activations, [100, -100, -100] the log partition function becomes
+      # \log ( exp(0) + exp(-200) + exp(-200) ) = 0
+      # so our log softmaxes become: [0, -200, -200]
+      # so our cross entropy loss is:
+      # -(1 - L + L/n) * 0 + 400 * L/n = 400 L/n
       logits = tf.constant([[100.0, -100.0, -100.0]])
       labels = tf.constant([[1, 0, 0]])
       label_smoothing = 0.1
-      loss = tf.contrib.losses.sigmoid_cross_entropy(
+      loss = tf.contrib.losses.softmax_cross_entropy(
           logits, labels, label_smoothing=label_smoothing)
-      self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
-      expected_value = 400.0 * label_smoothing / 9.0
+      self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
+      expected_value = 400.0 * label_smoothing / 3.0
       self.assertAlmostEqual(loss.eval(), expected_value, 3)
@@ -283,14 +293,39 @@ class SigmoidCrossEntropyLossTest(tf.test.TestCase):
   def testSigmoidLabelSmoothingCorrect(self):
     with self.test_session():
       logits = tf.constant([[100.0, -100.0, -100.0]])
-      labels = tf.constant([[1, 0, 0]])
+      labels = tf.constant([[1, 0, 1]])
+      # Sigmoid cross entropy loss is:
+      #   max(x,0) - x*z + log(1 + exp(-abs(x)))
+      # The new labels are:
+      #    z' = z * (1 - L) + 0.5 L
+      #    1 -> 1 - 0.5 L
+      #    0 -> 0.5 L
+      # here we expect:
+      # 1/3 * (100 - 100 * (1 - 0.5 L)  + 0
+      #        + 0  + 100 * (0.5 L)     + 0
+      #        + 0  + 100 * (1 - 0.5 L) + 0)
+      # = 1/3 * (100 + 50 L)
       label_smoothing = 0.1
       loss = tf.contrib.losses.sigmoid_cross_entropy(
           logits, labels, label_smoothing=label_smoothing)
       self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
-      expected_value = 400.0 * label_smoothing / 9.0
+      expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
       self.assertAlmostEqual(loss.eval(), expected_value, 3)
 
+  def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
+    with self.test_session():
+      label_smoothing = 0.1
+      sigmoid_logits = tf.constant([[100.0, -100.0, -100.0]])
+      sigmoid_labels = tf.constant([[1, 0, 1]])
+      sigmoid_loss = tf.contrib.losses.sigmoid_cross_entropy(
+          sigmoid_logits, sigmoid_labels, label_smoothing=label_smoothing)
+
+      softmax_logits = tf.constant([[0.0, 100.0], [100.0, 0.0], [100.0, 0.0]])
+      softmax_labels = tf.constant([[0, 1], [1, 0], [0, 1]])
+      softmax_loss = tf.contrib.losses.softmax_cross_entropy(
+          softmax_logits, softmax_labels, label_smoothing=label_smoothing)
+      self.assertAlmostEqual(sigmoid_loss.eval(), softmax_loss.eval(), 3)
+
 
 class LogLossTest(tf.test.TestCase):
...
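Similarly, a standalone sketch (plain Python, not part of this commit) that reproduces the softmax test's expected value 400 * L / 3 by computing the log-softmax against the maximum logit, exactly as the added test comment walks through:

# Standalone check of the softmax test's expected value; not part of the commit.
import math

L = 0.1
logits = [100.0, -100.0, -100.0]
onehot = [1.0, 0.0, 0.0]
n = len(onehot)
smoothed = [z * (1 - L) + L / n for z in onehot]
x_max = max(logits)
log_partition = math.log(sum(math.exp(x - x_max) for x in logits))   # ~0
log_softmax = [x - x_max - log_partition for x in logits]            # [0, -200, -200]
loss = -sum(p * q for p, q in zip(smoothed, log_softmax))
print(loss)                  # ~13.333
print(400.0 * L / 3.0)       # 13.333...

The new equivalence test works because a sigmoid unit with logit x and label z behaves like a two-class softmax over logits [0, x] (the test shifts some rows by a constant, which softmax ignores), so the two smoothed losses should agree.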