Commit aa632305 authored by Megvii Engine Team

docs(functional): replace loss function testcode with doctest format

GitOrigin-RevId: 98224f0e5f04a2a54eb3ddbb0c011bcb20c7f2e0
Parent e8d0f9db
@@ -66,29 +66,27 @@ def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
     Args:
         pred: predicted result from model.
         label: ground truth to compare.
-        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
+        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.
     Returns:
         loss value.
-    Examples:
-        .. testcode::
-            import numpy as np
-            import megengine as mge
-            import megengine.functional as F
-            ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
-            tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
-            loss = F.nn.l1_loss(ipt, tgt)
-            print(loss.numpy())
-        Outputs:
-        .. testoutput::
-            2.75
+    Shape:
+        * ``pred``: :math:`(N, *)` where :math:`*` means any number of additional
+          dimensions.
+        * ``label``: :math:`(N, *)`. Same shape as ``pred``.
+    Examples:
+        >>> pred = Tensor([3, 3, 3, 3])
+        >>> label = Tensor([2, 8, 6, 1])
+        >>> F.nn.l1_loss(pred, label)
+        Tensor(2.75, device=xpux:0)
+        >>> F.nn.l1_loss(pred, label, reduction="none")
+        Tensor([1 5 3 2], dtype=int32, device=xpux:0)
+        >>> F.nn.l1_loss(pred, label, reduction="sum")
+        Tensor(11, dtype=int32, device=xpux:0)
     """
     diff = pred - label
     return abs(diff)
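The doctest numbers above follow directly from the element-wise L1 distance; as a quick cross-check, a minimal NumPy sketch of the same computation (illustrative only):

    # Reproduce the l1_loss doctest values with plain NumPy.
    import numpy as np

    pred = np.array([3, 3, 3, 3], dtype=np.float32)
    label = np.array([2, 8, 6, 1], dtype=np.float32)

    diff = np.abs(pred - label)
    print(diff)         # [1. 5. 3. 2.]  -> reduction="none"
    print(diff.mean())  # 2.75           -> default reduction="mean"
    print(diff.sum())   # 11.0           -> reduction="sum"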
@@ -118,34 +116,27 @@ def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
     Args:
         pred: predicted result from model.
         label: ground truth to compare.
-        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
+        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.
     Returns:
         loss value.
     Shape:
-        * pred: :math:`(N, *)` where :math:`*` means any number of additional
+        * ``pred``: :math:`(N, *)` where :math:`*` means any number of additional
           dimensions.
-        * label: :math:`(N, *)`. Same shape as ``pred``.
+        * ``label``: :math:`(N, *)`. Same shape as ``pred``.
     Examples:
-        .. testcode::
-            import numpy as np
-            import megengine as mge
-            import megengine.functional as F
-            ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
-            tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
-            loss = F.nn.square_loss(ipt, tgt)
-            print(loss.numpy())
-        Outputs:
-        .. testoutput::
-            9.75
+        >>> pred = Tensor([3, 3, 3, 3])
+        >>> label = Tensor([2, 8, 6, 1])
+        >>> F.nn.square_loss(pred, label)
+        Tensor(9.75, device=xpux:0)
+        >>> F.nn.square_loss(pred, label, reduction="none")
+        Tensor([ 1. 25.  9.  4.], device=xpux:0)
+        >>> F.nn.square_loss(pred, label, reduction="sum")
+        Tensor(39.0, device=xpux:0)
     """
     diff = pred - label
     return diff ** 2
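The square_loss doctest values can be checked the same way, this time with the element-wise squared difference (a NumPy sketch for illustration):

    # Reproduce the square_loss doctest values with plain NumPy.
    import numpy as np

    pred = np.array([3, 3, 3, 3], dtype=np.float32)
    label = np.array([2, 8, 6, 1], dtype=np.float32)

    sq = (pred - label) ** 2
    print(sq)         # [ 1. 25.  9.  4.]  -> reduction="none"
    print(sq.mean())  # 9.75               -> reduction="mean"
    print(sq.sum())   # 39.0               -> reduction="sum"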
@@ -162,11 +153,6 @@ def cross_entropy(
 ) -> Tensor:
     r"""Computes the multi-class cross entropy loss (using logits by default).
-    By default(``with_logitis`` is True), ``pred`` is assumed to be logits,
-    class probabilities are given by softmax.
-    It has better numerical stability compared with sequential calls to :func:`~.softmax` and :func:`~.cross_entropy`.
     When using label smoothing, the label distribution is as follows:
     .. math:: y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K
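A tiny worked example of that formula, with K = 4 classes and alpha = 0.1 chosen purely for illustration:

    # Label smoothing: y_k^LS = y_k * (1 - alpha) + alpha / K applied to a one-hot target.
    import numpy as np

    K = 4                                 # number of classes (illustrative)
    alpha = 0.1                           # label_smooth (illustrative)
    y = np.array([0.0, 1.0, 0.0, 0.0])    # one-hot target for class 1

    y_ls = y * (1 - alpha) + alpha / K
    print(y_ls)        # [0.025 0.925 0.025 0.025], still sums to 1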
@@ -175,36 +161,39 @@ def cross_entropy(
     k is the index of label distribution. :math:`\alpha` is ``label_smooth`` and :math:`K` is the number of classes.
     Args:
-        pred: input tensor representing the predicted probability.
+        pred: input tensor representing the predicted value.
         label: input tensor representing the classification label.
         axis: an axis along which softmax will be applied. Default: 1
         with_logits: whether to apply softmax first. Default: True
         label_smooth: a label smoothing of parameter that can re-distribute target distribution. Default: 0
-        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
+        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.
     Returns:
         loss value.
     Examples:
-        .. testcode::
-            import numpy as np
-            from megengine import tensor
-            import megengine.functional as F
-            data_shape = (1, 2)
-            label_shape = (1, )
-            pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape))
-            label = tensor(np.ones(label_shape, dtype=np.int32))
-            loss = F.nn.cross_entropy(pred, label)
-            print(loss.numpy().round(decimals=4))
-        Outputs:
-        .. testoutput::
-            0.6931
+        By default(``with_logitis`` is True), ``pred`` is assumed to be logits,
+        class probabilities are given by softmax.
+        It has better numerical stability compared with sequential calls to
+        :func:`~.softmax` and :func:`~.cross_entropy`.
+        >>> pred = Tensor([[0., 1.], [0.3, 0.7], [0.7, 0.3]])
+        >>> label = Tensor([1., 1., 1.])
+        >>> F.nn.cross_entropy(pred, label)  # doctest: +SKIP
+        Tensor(0.57976407, device=xpux:0)
+        >>> F.nn.cross_entropy(pred, label, reduction="none")
+        Tensor([0.3133 0.513  0.913 ], device=xpux:0)
+        If the ``pred`` value has been probabilities, set ``with_logits`` to False:
+        >>> pred = Tensor([[0., 1.], [0.3, 0.7], [0.7, 0.3]])
+        >>> label = Tensor([1., 1., 1.])
+        >>> F.nn.cross_entropy(pred, label, with_logits=False)  # doctest: +SKIP
+        Tensor(0.5202159, device=xpux:0)
+        >>> F.nn.cross_entropy(pred, label, with_logits=False, reduction="none")
+        Tensor([0.     0.3567 1.204 ], device=xpux:0)
     """
     n0 = pred.ndim
     n1 = label.ndim
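The with_logits doctest values above correspond to a softmax along axis 1 followed by the negative log-probability of the target class; a minimal NumPy sketch of that computation:

    # Reproduce the cross_entropy (with_logits=True) doctest values.
    import numpy as np

    pred = np.array([[0.0, 1.0], [0.3, 0.7], [0.7, 0.3]])
    label = np.array([1, 1, 1])

    prob = np.exp(pred) / np.exp(pred).sum(axis=1, keepdims=True)  # softmax over axis 1
    nll = -np.log(prob[np.arange(len(label)), label])              # per-sample loss
    print(nll.round(4))   # [0.3133 0.513  0.913 ]  -> reduction="none"
    print(nll.mean())     # ~0.5798                 -> reduction="mean"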
@@ -234,36 +223,38 @@ def binary_cross_entropy(
 ) -> Tensor:
     r"""Computes the binary cross entropy loss (using logits by default).
-    By default(``with_logitis`` is True), ``pred`` is assumed to be logits,
-    class probabilities are given by sigmoid.
     Args:
         pred: `(N, *)`, where `*` means any number of additional dimensions.
         label: `(N, *)`, same shape as the input.
         with_logits: bool, whether to apply sigmoid first. Default: True
-        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
+        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.
     Returns:
         loss value.
     Examples:
-        .. testcode::
-            import numpy as np
-            from megengine import tensor
-            import megengine.functional as F
-            pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2))
-            label = tensor(np.ones((1, 2), dtype=np.float32))
-            loss = F.nn.binary_cross_entropy(pred, label)
-            print(loss.numpy().round(decimals=4))
-        Outputs:
-        .. testoutput::
-            0.6931
+        By default(``with_logitis`` is True), ``pred`` is assumed to be logits,
+        class probabilities are given by softmax.
+        It has better numerical stability compared with sequential calls to
+        :func:`~.sigmoid` and :func:`~.binary_cross_entropy`.
+        >>> pred = Tensor([0.9, 0.7, 0.3])
+        >>> label = Tensor([1., 1., 1.])
+        >>> F.nn.binary_cross_entropy(pred, label)
+        Tensor(0.4328984, device=xpux:0)
+        >>> F.nn.binary_cross_entropy(pred, label, reduction="none")
+        Tensor([0.3412 0.4032 0.5544], device=xpux:0)
+        If the ``pred`` value has been probabilities, set ``with_logits`` to False:
+        >>> pred = Tensor([0.9, 0.7, 0.3])
+        >>> label = Tensor([1., 1., 1.])
+        >>> F.nn.binary_cross_entropy(pred, label, with_logits=False)
+        Tensor(0.5553361, device=xpux:0)
+        >>> F.nn.binary_cross_entropy(pred, label, with_logits=False, reduction="none")
+        Tensor([0.1054 0.3567 1.204 ], device=xpux:0)
     """
     if not with_logits:
         return -(label * log(pred) + (1 - label) * log(1 - pred))
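Likewise, the binary with_logits case reduces to a sigmoid followed by the usual binary cross entropy; a NumPy sketch that reproduces the doctest values:

    # Reproduce the binary_cross_entropy (with_logits=True) doctest values.
    import numpy as np

    pred = np.array([0.9, 0.7, 0.3])
    label = np.array([1.0, 1.0, 1.0])

    p = 1.0 / (1.0 + np.exp(-pred))                            # sigmoid on the logits
    bce = -(label * np.log(p) + (1 - label) * np.log(1 - p))   # per-element loss
    print(bce.round(4))  # [0.3412 0.4032 0.5544]  -> reduction="none"
    print(bce.mean())    # ~0.4329                 -> reduction="mean"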
@@ -292,22 +283,15 @@ def hinge_loss(
         loss value.
     Examples:
-        .. testcode::
-            from megengine import tensor
-            import megengine.functional as F
-            pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
-            label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
-            loss = F.nn.hinge_loss(pred, label)
-            print(loss.numpy())
-        Outputs:
-        .. testoutput::
-            1.5
+        >>> pred = Tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
+        >>> label = Tensor([[1, -1, -1], [-1, 1, 1]])
+        >>> F.nn.hinge_loss(pred, label)
+        Tensor(1.5, device=xpux:0)
+        >>> F.nn.hinge_loss(pred, label, reduction="none")
+        Tensor([2.1 0.9], device=xpux:0)
+        >>> F.nn.hinge_loss(pred, label, reduction="sum")
+        Tensor(3.0, device=xpux:0)
     """
     norm = norm.upper()
     assert norm in ["L1", "L2"], "norm must be L1 or L2"
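The hinge doctest values follow from summing max(0, 1 - pred * label) over each row (assuming the L1 norm, which the example appears to use); a NumPy sketch for illustration:

    # Reproduce the hinge_loss doctest values (L1 hinge, reduced per sample).
    import numpy as np

    pred = np.array([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
    label = np.array([[1, -1, -1], [-1, 1, 1]], dtype=np.float64)

    per_sample = np.maximum(0.0, 1.0 - pred * label).sum(axis=1)
    print(per_sample)         # [2.1 0.9]  -> reduction="none"
    print(per_sample.mean())  # 1.5        -> reduction="mean"
    print(per_sample.sum())   # 3.0        -> reduction="sum"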
@@ -381,23 +365,12 @@ def ctc_loss(
     Examples:
-        .. testcode::
-            from megengine import tensor
-            import megengine.functional as F
-            pred = tensor([[[0.0614, 0.9386],[0.8812, 0.1188]],[[0.699, 0.301 ],[0.2572, 0.7428]]])
-            pred_length = tensor([2,2])
-            label = tensor([1,1])
-            label_lengths = tensor([1,1])
-            loss = F.nn.ctc_loss(pred, pred_length, label, label_lengths)
-            print(loss.numpy())
-        Outputs:
-        .. testoutput::
-            0.1504417
+        >>> pred = Tensor([[[0.0614, 0.9386],[0.8812, 0.1188]],[[0.699, 0.301 ],[0.2572, 0.7428]]])
+        >>> pred_lengths = Tensor([2, 2])
+        >>> label = Tensor([1, 1])
+        >>> label_lengths = Tensor([1, 1])
+        >>> F.nn.ctc_loss(pred, pred_lengths, label, label_lengths)
+        Tensor(0.1504417, device=xpux:0)
     """
     T, N, C = pred.shape
...
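For this two-timestep toy input, the CTC doctest value can be checked by enumerating the alignments that collapse to the single-token label; the sketch below assumes blank index 0 and that ``pred`` holds per-timestep class probabilities of shape (T, N, C), which is consistent with the example above:

    # Brute-force check of the ctc_loss doctest value for label [1] with T=2.
    import numpy as np

    pred = np.array([[[0.0614, 0.9386], [0.8812, 0.1188]],
                     [[0.699,  0.301 ], [0.2572, 0.7428]]])   # (T=2, N=2, C=2)

    losses = []
    for n in range(pred.shape[1]):
        p = pred[:, n, :]                    # (T, C) for one batch element
        paths = [(0, 1), (1, 0), (1, 1)]     # alignments collapsing to "1" (0 = blank)
        total = sum(p[0, a] * p[1, b] for a, b in paths)
        losses.append(-np.log(total))
    print(np.mean(losses))                   # ~0.1504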