diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py
index 40da78d8cf41cc3d9a6dc7957b58580b0dde21bf..67a296678a4acf4a72aaf239f99fc10b02d9ae24 100644
--- a/imperative/python/megengine/functional/loss.py
+++ b/imperative/python/megengine/functional/loss.py
@@ -176,30 +176,19 @@ def cross_entropy(
             "target_ndim={}".format(n0, n1)
         )
 
-    num_classes = pred.shape[axis]
-    no_label_smooth = (
-        label_smooth is None or type(label_smooth) in (int, float) and label_smooth == 0
-    )
+    ls = label_smooth
+    if with_logits:
+        logZ = logsumexp(pred, axis).mean()
+        primary_term = indexing_one_hot(pred, label, axis).mean()
+    else:
+        logZ = 0
+        primary_term = log(indexing_one_hot(pred, label, axis)).mean()
+    if ls is None or type(ls) in (int, float) and ls == 0:
+        return logZ - primary_term
 
     if not with_logits:
-        if no_label_smooth:
-            return -log(indexing_one_hot(pred, label, axis)).mean()
         pred = log(pred)
-        return (
-            label_smooth * pred.mean()
-            - (1 - label_smooth) * indexing_one_hot(pred, label, axis).mean()
-        )
-
-    # Denominator of the softmax
-    down = logsumexp(pred, axis=axis, keepdims=True)
-
-    up = indexing_one_hot(pred, label, axis)
-
-    if not no_label_smooth:
-        factor = label_smooth / num_classes
-        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
-
-    return (down - up).mean()
+    return logZ - ls * pred.mean() - (1 - ls) * primary_term
 
 
 def binary_cross_entropy(
diff --git a/imperative/python/test/unit/functional/test_loss.py b/imperative/python/test/unit/functional/test_loss.py
index 8bfd1cd5e24d83d27fc2c4604100af954f157668..7774a9f0a4701985b24c2e621369ef2996bd9139 100644
--- a/imperative/python/test/unit/functional/test_loss.py
+++ b/imperative/python/test/unit/functional/test_loss.py
@@ -13,15 +13,15 @@ from megengine import tensor
 
 
 def test_cross_entropy_with_logits():
-    data = tensor([1, 100]).astype(np.float32).reshape((1, 2))
-    label = tensor([1]).astype(np.int32)
+    data = tensor([[0, 50], [0, -150]]).astype(np.float32)
+    label = tensor([1, 0]).astype(np.int32)
     loss = F.nn.cross_entropy(data, label)
     np.testing.assert_allclose(loss.numpy(), 0.0)
 
-    label = tensor([0]).astype(np.int32)
+    label = tensor([0, 1]).astype(np.int32)
     loss = F.nn.cross_entropy(data, label)
-    np.testing.assert_allclose(loss.numpy(), 100 - 1)
+    np.testing.assert_allclose(loss.numpy(), 100)
 
-    label = np.array([1])
+    label = np.array([1, 0])
     loss = F.nn.cross_entropy(data, label)
     np.testing.assert_allclose(loss.numpy(), 0.0)
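
Note on the refactor (not part of the patch): with smoothing factor ls, the smoothed target puts weight 1 - ls on the labeled class and ls / num_classes on every class, so the with_logits loss collapses to logZ - ls * pred.mean() - (1 - ls) * primary_term, the single return value above; ls = 0 recovers the plain logZ - primary_term early return. Below is a minimal NumPy sketch of that equivalence; the helper names are illustrative only, not MegEngine APIs.

```python
import numpy as np

def smoothed_ce_reference(logits, label, ls):
    # Direct definition: target distribution is (1 - ls) * one_hot + ls / num_classes.
    n, c = logits.shape
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    target = (1 - ls) * np.eye(c)[label] + ls / c
    return -(target * log_prob).sum(axis=1).mean()

def smoothed_ce_refactored(logits, label, ls):
    # Same quantity in the closed form used by the patch (with_logits=True branch).
    n = logits.shape[0]
    logZ = np.log(np.exp(logits).sum(axis=1)).mean()          # mean log-partition
    primary_term = logits[np.arange(n), label].mean()         # mean logit of the true class
    return logZ - ls * logits.mean() - (1 - ls) * primary_term

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))
y = rng.integers(0, 5, size=4)
assert np.allclose(smoothed_ce_reference(x, y, 0.1),
                   smoothed_ce_refactored(x, y, 0.1))
```

The updated test exercises the same formula with two samples whose logits are far enough apart that the expected losses are effectively 0 and 100.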