Commit 3c49d1d3 authored by: Megvii Engine Team

feat(mge/functional): add hinge loss

GitOrigin-RevId: 64c89c1f8c4e4ecbaf6892f9570c5d9db0027a1d
Parent dd8f3ffc
@@ -43,6 +43,7 @@ from .loss import (
    binary_cross_entropy,
    cross_entropy,
    cross_entropy_with_softmax,
    hinge_loss,
    l1_loss,
    nll_loss,
    square_loss,
@@ -9,8 +9,9 @@
import megengine._internal as mgb
from ..core.tensor import Tensor
-from .elemwise import abs, equal, log, maximum, power
+from .elemwise import abs, equal, log, maximum, power, relu
from .nn import assert_equal, indexing_one_hot
from .tensor import where
from .utils import zero_grad
@@ -297,3 +298,45 @@ def nll_loss(
    loss = indexing_one_hot(pred, label, axis) * mask
    return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)

def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
    r"""
    Calculates the hinge loss, which is often used in SVMs.

    The hinge loss can be described as:

    .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(\max(0, 1 - x_{ij} \cdot y_{ij}))

    :param pred: The input tensor representing the predicted scores, shape is (N, C).
    :param label: The input tensor representing the binary classification label,
        with each element being -1 or 1; shape is (N, C).
    :param norm: Specify the norm to calculate the loss, should be "L1" or "L2".

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        label = tensor([[1, -1, -1], [-1, 1, 1]])
        loss = F.hinge_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [1.5]

    """
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # Labels are expected to be -1 or 1; the element-wise hinge
    # term is max(0, 1 - pred * label).
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1).mean()
    else:
        return (loss ** 2).sum(axis=1).mean()
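
For reference, the computation above is equivalent to the following NumPy sketch (a cross-check only, not part of the commit; the helper name `hinge_loss_ref` is hypothetical):

import numpy as np

def hinge_loss_ref(pred, label, norm="L1"):
    # Element-wise hinge term: max(0, 1 - pred * label), with labels in {-1, +1}.
    loss = np.maximum(0.0, 1.0 - pred * label)
    if norm == "L2":
        loss = loss ** 2
    # Sum over classes (axis 1), then average over the batch.
    return loss.sum(axis=1).mean()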
@@ -336,6 +336,32 @@ def test_binary_cross_entropy():
    opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)

def test_hinge_loss():
    np.random.seed(123)
    # case with L1 norm
    cases = []
    for shape in [(2, 2), (2, 3)]:
        data = np.random.uniform(size=shape).astype(np.float32)
        label = 2 * np.random.randint(0, 2, size=shape).astype(np.int32) - 1
        expect = np.clip(1 - data * label, 0, np.inf).sum(axis=1).mean()
        cases.append({"input": [data, label], "output": tensor(expect)})
    opr_test(cases, F.hinge_loss)

    # cases with L2 norm
    cases = []
    for shape in [(2, 2), (2, 3)]:
        data = np.random.uniform(size=shape).astype(np.float32)
        label = 2 * np.random.randint(0, 2, size=shape).astype(np.int32) - 1
        expect = (np.clip(1 - data * label, 0, np.inf) ** 2).sum(axis=1).mean()
        cases.append({"input": [data, label], "output": tensor(expect)})

    def hinge_loss_with_l2_norm(pred, label):
        return F.hinge_loss(pred, label, "L2")

    opr_test(cases, hinge_loss_with_l2_norm)
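
For a quick sanity check outside the test harness, the L2 variant of the docstring example can be verified by hand (a sketch; the expected value is computed directly from the formula above):

from megengine import tensor
import megengine.functional as F

pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
label = tensor([[1, -1, -1], [-1, 1, 1]])
# squared hinge terms: [[0.25, 0.25, 1.21], [0.16, 0.09, 0.04]]
# row sums are 1.71 and 0.29, whose mean is 1.0
print(F.hinge_loss(pred, label, norm="L2").numpy())  # expected: [1.]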
@pytest.mark.skip
def test_conv_bias():
    inp_scale = 0.01