# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools

import numpy as np

from ..core.tensor.array_method import _reduce
from ..tensor import Tensor
from .elemwise import abs, log
from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
from .tensor import where

__all__ = [
    "l1_loss",
    "square_loss",
    "cross_entropy",
    "binary_cross_entropy",
    "hinge_loss",
]


def _reduce_output(loss_fn):
    r"""
    Wrapper to apply canonical reductions to loss outputs.
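
    A minimal sketch of how ``reduction`` is handled (values follow from the
    :func:`~.l1_loss` example below; ``reduction`` is consumed by this wrapper
    and not forwarded to the wrapped loss function):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        print(F.nn.l1_loss(ipt, tgt, reduction="none").numpy())
        print(F.nn.l1_loss(ipt, tgt, reduction="sum").numpy())

    Outputs:

    .. testoutput::

        [1. 5. 3. 2.]
        11.0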
    """

    @functools.wraps(loss_fn)
    def reduced_loss_fn(*args, reduction="mean", **kwargs):
        loss = loss_fn(*args, **kwargs)
        if reduction == "none":
            return loss
        elif reduction in ("mean", "sum"):
            return _reduce(reduction)(loss)
        else:
            raise ValueError("{} is not a valid value for reduction".format(reduction))

    return reduced_loss_fn


@_reduce_output
def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""
    Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math:: \ell(x,y) = mean\left(L \right)

    where

    .. math::

        L = \{l_1,\dots,l_N\}, \quad
        l_n = \left| x_n - y_n \right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: predicted result from model.
    :param label: ground truth to compare.
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.nn.l1_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        2.75

    """
    diff = pred - label
    return abs(diff)


@_reduce_output
def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""
    Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math:: \ell(x, y) = mean\left( L \right)

    where

    .. math::

        L = \{l_1,\dots,l_N\}, \quad
        l_n = \left( x_n - y_n \right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: predicted result from model.
    :param label: ground truth to compare.
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions.
        - label: :math:`(N, *)`. Same shape as ``pred``.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.nn.square_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        9.75

    """
    diff = pred - label
    return diff ** 2


@_reduce_output
def cross_entropy(
    pred: Tensor,
    label: Tensor,
    axis: int = 1,
    with_logits: bool = True,
    label_smooth: float = 0,
    reduction: str = "mean",
) -> Tensor:
    r"""
    Computes the multi-class cross entropy loss (using logits by default).

    By default (``with_logits`` is True), ``pred`` is assumed to be logits;
    class probabilities are given by softmax.

    It has better numerical stability compared with sequential calls to :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K

    where :math:`y^{LS}` and :math:`y` are the new and the original label distributions, respectively;
    :math:`k` is the index of the label distribution, :math:`\alpha` is ``label_smooth`` and :math:`K` is the number of classes.

    :param pred: input tensor representing the predicted probability.
    :param label: input tensor representing the classification label.
    :param axis: an axis along which softmax will be applied. Default: 1
    :param with_logits: whether to apply softmax first. Default: True
    :param label_smooth: a label smoothing parameter that re-distributes the target distribution. Default: 0
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (1, 2)
        label_shape = (1, )
        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape))
        label = tensor(np.ones(label_shape, dtype=np.int32))
        loss = F.nn.cross_entropy(pred, label)
        print(loss.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        0.6931
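
    A sketch of the label-smoothing path; the value below is worked out by hand
    from the formula above (``label_smooth=0.1``), so treat it as illustrative:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        pred = tensor(np.array([[1.0, 3.0]], dtype=np.float32))
        label = tensor(np.ones((1,), dtype=np.int32))
        loss = F.nn.cross_entropy(pred, label, label_smooth=0.1)
        print(loss.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        0.2269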

    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    ls = label_smooth

    if with_logits:
        logZ = logsumexp(pred, axis)
        primary_term = indexing_one_hot(pred, label, axis)
    else:
        logZ = 0
        primary_term = log(indexing_one_hot(pred, label, axis))
    if ls is None or type(ls) in (int, float) and ls == 0:
        return logZ - primary_term
    if not with_logits:
        pred = log(pred)
    return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term


@_reduce_output
def binary_cross_entropy(
    pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean",
) -> Tensor:
    r"""
    Computes the binary cross entropy loss (using logits by default).

    By default (``with_logits`` is True), ``pred`` is assumed to be logits;
    class probabilities are given by sigmoid.

    :param pred: `(N, *)`, where `*` means any number of additional dimensions.
    :param label: `(N, *)`, same shape as the input.
    :param with_logits: bool, whether to apply sigmoid first. Default: True
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2))
        label = tensor(np.ones((1, 2), dtype=np.float32))
        loss = F.nn.binary_cross_entropy(pred, label)
        print(loss.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        0.6931
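
    A sketch of the same computation on pre-computed probabilities
    (``with_logits=False``); since ``sigmoid(0) = 0.5``, it reproduces the value above:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        pred = tensor(np.full((1, 2), 0.5, dtype=np.float32))
        label = tensor(np.ones((1, 2), dtype=np.float32))
        loss = F.nn.binary_cross_entropy(pred, label, with_logits=False)
        print(loss.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        0.6931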

    """
    if not with_logits:
        return -(label * log(pred) + (1 - label) * log(1 - pred))
    # logsigmoid(pred) and logsigmoid(-pred) share a common sub-expression;
    # hopefully the backend will optimize it away
    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred))


@_reduce_output
def hinge_loss(
    pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean"
) -> Tensor:
    r"""
    Calculates the hinge loss, which is often used in SVMs.

    The hinge loss can be described as:

    .. math:: loss(x, y) = \frac{1}{N}\sum_{i}\sum_{j}\max(0, 1 - x_{ij} y_{ij})

    :param pred: input tensor representing the predicted probability, shape is `(N, C)`.
    :param label: input tensor representing the binary classification label, shape is `(N, C)`.
    :param norm: specify the norm to calculate the loss, should be "L1" or "L2".
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
        label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
        loss = F.nn.hinge_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

301
        1.5
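
    A sketch of the squared variant (``norm="L2"``), which sums the squared hinge
    terms per sample; the value below is worked out by hand from the formula above:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
        label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
        loss = F.nn.hinge_loss(pred, label, norm="L2")
        print(loss.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        1.0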

    """
    norm = norm.upper()
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # hinge term per element: max(0, 1 - pred * label); label is expected to be -1 or 1
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1)
    else:
        return (loss ** 2).sum(axis=1)