diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py
index e2a1484508ad52d5d9bc4ec16a7768bb3a4e4208..d62e240f4871c199b6c1f5b1d7530be1443d8b49 100644
--- a/imperative/python/megengine/functional/loss.py
+++ b/imperative/python/megengine/functional/loss.py
@@ -6,8 +6,11 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import functools
+
 import numpy as np
 
+from ..core.tensor.array_method import _reduce
 from ..tensor import Tensor
 from .elemwise import abs, log
 from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
@@ -22,7 +25,26 @@ __all__ = [
 ]
 
 
-def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
+def _reduce_output(loss_fn):
+    r"""
+    Wrapper to apply canonical reductions to loss outputs.
+    """
+
+    @functools.wraps(loss_fn)
+    def reduced_loss_fn(*args, reduction="mean", **kwargs):
+        loss = loss_fn(*args, **kwargs)
+        if reduction == "none":
+            return loss
+        elif reduction in ("mean", "sum"):
+            return _reduce(reduction)(loss)
+        else:
+            raise ValueError("{} is not a valid value for reduction".format(reduction))
+
+    return reduced_loss_fn
+
+
+@_reduce_output
+def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
     r"""
     Calculates the mean absolute error (MAE) between
     each element in the pred :math:`x` and label :math:`y`.
@@ -43,6 +65,7 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
 
     :param pred: predicted result from model.
     :param label: ground truth to compare.
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Examples:
@@ -66,10 +89,11 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
 
     """
     diff = pred - label
-    return abs(diff).mean()
+    return abs(diff)
 
 
-def square_loss(pred: Tensor, label: Tensor) -> Tensor:
+@_reduce_output
+def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
     r"""
     Calculates the mean squared error (squared L2 norm) between
     each element in the pred :math:`x` and label :math:`y`.
@@ -90,6 +114,7 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
 
     :param pred: predicted result from model.
     :param label: ground truth to compare.
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Shape:
@@ -118,15 +143,17 @@
 
     """
     diff = pred - label
-    return (diff ** 2).mean()
+    return diff ** 2
 
 
+@_reduce_output
 def cross_entropy(
     pred: Tensor,
     label: Tensor,
     axis: int = 1,
     with_logits: bool = True,
     label_smooth: float = 0,
+    reduction: str = "mean",
 ) -> Tensor:
     r"""
     Computes the multi-class cross entropy loss (using logits by default).
@@ -148,6 +175,7 @@
     :param axis: an axis along which softmax will be applied. Default: 1
     :param with_logits: whether to apply softmax first. Default: True
     :param label_smooth: a label smoothing of parameter that can re-distribute target distribution. Default: 0
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Examples:
@@ -182,20 +210,21 @@ def cross_entropy(
     ls = label_smooth
 
     if with_logits:
-        logZ = logsumexp(pred, axis).mean()
-        primary_term = indexing_one_hot(pred, label, axis).mean()
+        logZ = logsumexp(pred, axis)
+        primary_term = indexing_one_hot(pred, label, axis)
     else:
         logZ = 0
-        primary_term = log(indexing_one_hot(pred, label, axis)).mean()
+        primary_term = log(indexing_one_hot(pred, label, axis))
     if ls is None or type(ls) in (int, float) and ls == 0:
         return logZ - primary_term
     if not with_logits:
         pred = log(pred)
-    return logZ - ls * pred.mean() - (1 - ls) * primary_term
+    return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term
 
 
+@_reduce_output
 def binary_cross_entropy(
-    pred: Tensor, label: Tensor, with_logits: bool = True
+    pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean",
 ) -> Tensor:
     r"""
     Computes the binary cross entropy loss (using logits by default).
@@ -206,6 +235,7 @@
     :param pred: `(N, *)`, where `*` means any number of additional dimensions.
     :param label: `(N, *)`, same shape as the input.
     :param with_logits: bool, whether to apply sigmoid first. Default: True
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Examples:
@@ -229,13 +259,16 @@
 
     """
     if not with_logits:
-        return -(label * log(pred) + (1 - label) * log(1 - pred)).mean()
+        return -(label * log(pred) + (1 - label) * log(1 - pred))
     # logsigmoid(pred) and logsigmoid(-pred) has common sub-expression
     # hopefully the backend would optimize this
-    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred)).mean()
+    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred))
 
 
-def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
+@_reduce_output
+def hinge_loss(
+    pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean"
+) -> Tensor:
     r"""
     Caculates the hinge loss which is often used in SVM.
 
@@ -246,6 +279,7 @@
     :param pred: input tensor representing the predicted probability, shape is `(N, C)`.
     :param label: input tensor representing the binary classification label, shape is `(N, C)`.
     :param norm: specify the norm to caculate the loss, should be "L1" or "L2".
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Examples:
@@ -272,6 +306,6 @@
     # Converts binary labels to -1/1 labels.
     loss = relu(1.0 - pred * label)
     if norm == "L1":
-        return loss.sum(axis=1).mean()
+        return loss.sum(axis=1)
     else:
-        return (loss ** 2).sum(axis=1).mean()
+        return (loss ** 2).sum(axis=1)
diff --git a/imperative/python/test/unit/functional/test_loss.py b/imperative/python/test/unit/functional/test_loss.py
index a1ee43e4fa1b54ce7993a755c86a2f313457543e..9a7255523a0248ca7ac605f65a77a0921a384bfb 100644
--- a/imperative/python/test/unit/functional/test_loss.py
+++ b/imperative/python/test/unit/functional/test_loss.py
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import pytest
 
 import megengine.functional as F
 from megengine import tensor
@@ -43,3 +44,38 @@ def test_cross_entropy():
     l_ref = ref(x, y)
     l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
     np.testing.assert_allclose(l.numpy(), l_ref)
+
+
+def test_cross_entropy_reduction():
+    logits = np.random.randn(16, 10)
+    label = np.random.randint(10, size=[16])
+    logits = tensor(logits, dtype="float32")
+    label = tensor(label, dtype="int32")
+
+    perm = np.random.permutation(16)
+    logits_perm = tensor(logits[perm], dtype="float32")
+    label_perm = tensor(label[perm], dtype="int32")
+
+    loss = F.nn.cross_entropy(logits, label, reduction="none")
+    loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
+    np.testing.assert_allclose(loss.numpy()[perm], loss_perm.numpy())
+
+    loss_sum = F.nn.cross_entropy(logits, label, reduction="sum")
+    np.testing.assert_allclose(loss.numpy().sum(), loss_sum.numpy(), rtol=2e-7)
+
+    loss_mean = F.nn.cross_entropy(logits, label, reduction="mean")
+    np.testing.assert_allclose(loss_mean.numpy(), loss_sum.numpy() / 16)
+
+    loss_ls = F.nn.cross_entropy(logits, label, reduction="mean", label_smooth=0.1)
+    loss_ls_none_reduce = F.nn.cross_entropy(
+        logits, label, reduction="none", label_smooth=0.1
+    )
+    np.testing.assert_allclose(
+        loss_ls.numpy(), loss_ls_none_reduce.numpy().mean(), rtol=2e-7
+    )
+
+    with pytest.raises(ValueError):
+        F.nn.cross_entropy(logits, label, reduction="MEAN")
+
+    with pytest.raises(ValueError):
+        F.nn.cross_entropy(logits, label, reduction="max")
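
Usage sketch (not part of the patch): the snippet below exercises the new `reduction` keyword through the public API, mirroring the calls added in the test; the batch size and tolerance are illustrative assumptions, everything else follows the signatures shown in the diff.

import numpy as np
import megengine.functional as F
from megengine import tensor

logits = tensor(np.random.randn(4, 10), dtype="float32")
label = tensor(np.random.randint(10, size=[4]), dtype="int32")

per_sample = F.nn.cross_entropy(logits, label, reduction="none")  # one loss per sample, shape (4,)
total = F.nn.cross_entropy(logits, label, reduction="sum")        # scalar sum over the batch
average = F.nn.cross_entropy(logits, label, reduction="mean")     # scalar mean (the default)

# "mean" is equivalent to averaging the unreduced loss.
np.testing.assert_allclose(average.numpy(), per_sample.numpy().mean(), rtol=2e-7)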
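
For reference, a minimal standalone sketch of the decorator pattern the patch introduces (re-implemented with numpy purely for illustration; the real `_reduce_output` lives in megengine/functional/loss.py and delegates the reduction to megengine's `_reduce` helper instead):

import functools

import numpy as np


def reduce_output(loss_fn):
    # Wrap an element-wise loss so callers can choose 'none' | 'mean' | 'sum'.
    @functools.wraps(loss_fn)
    def reduced_loss_fn(*args, reduction="mean", **kwargs):
        loss = loss_fn(*args, **kwargs)
        if reduction == "none":
            return loss
        elif reduction == "mean":
            return loss.mean()
        elif reduction == "sum":
            return loss.sum()
        raise ValueError("{} is not a valid value for reduction".format(reduction))

    return reduced_loss_fn


@reduce_output
def l1_loss(pred, label):
    # Element-wise loss only; the decorator applies the requested reduction.
    return np.abs(pred - label)


print(l1_loss(np.array([1.0, 2.0]), np.array([0.0, 0.0])))                   # 1.5 (default mean)
print(l1_loss(np.array([1.0, 2.0]), np.array([0.0, 0.0]), reduction="sum"))  # 3.0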