Commit f7d8b516 authored by Megvii Engine Team

feat(mge/functional): add smooth l1 loss

GitOrigin-RevId: c1437788d732e55ca3f99557c8049d556f4d2b67
Parent 3c49d1d3
@@ -46,6 +46,7 @@ from .loss import (
    hinge_loss,
    l1_loss,
    nll_loss,
    smooth_l1_loss,
    square_loss,
    triplet_margin_loss,
)
@@ -340,3 +340,52 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
        return loss.sum(axis=1).mean()
    else:
        return (loss ** 2).sum(axis=1).mean()


def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the smooth l1 loss proposed in `Fast R-CNN paper by Ross Girshick`.

    The smooth l1 loss can be described as:

    .. math::

        \text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i}

    where :math:`l_{i}` is given by:

    .. math::

        l_{i} =
        \begin{cases}
        0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
        |x_i - y_i| - 0.5, & \text{otherwise }
        \end{cases}

    :param pred: the predicted result from the model.
    :param label: the ground truth to compare against.
    :return: the calculated smooth l1 loss.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]])
        loss = F.smooth_l1_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.5608334]

    """
    diff = abs(pred - label)
    l2_loss = 0.5 * (diff ** 2)
    l1_loss = diff - 0.5
    mask = diff < 1
    loss = where(mask, l2_loss, l1_loss)
    return loss.mean()
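For reference, below is a minimal standalone NumPy sketch of the same piecewise computation (plain numpy only; `np_smooth_l1` is a hypothetical helper name used here for illustration, not part of the MegEngine API). It reproduces the value shown in the docstring's testoutput:

import numpy as np

def np_smooth_l1(pred: np.ndarray, label: np.ndarray) -> float:
    # elementwise absolute difference between prediction and ground truth
    diff = np.abs(pred - label)
    # quadratic branch for |x - y| < 1, linear branch otherwise
    loss = np.where(diff < 1, 0.5 * diff ** 2, diff - 0.5)
    return loss.mean()

pred = np.array([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype=np.float32)
label = np.array([[0.4, 1.5, 1.2], [0.0, 0.1, 2.2]], dtype=np.float32)
print(np_smooth_l1(pred, label))  # ~0.5608334, matching the testoutput above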
@@ -362,6 +362,19 @@ def test_hinge_loss():
    opr_test(cases, hinge_loss_with_l2_norm)


def test_smooth_l1_loss():
    np.random.seed(123)
    cases = []
    for shape in [(2, 2), (2, 3)]:
        data = np.random.uniform(size=shape).astype(np.float32)
        label = np.random.uniform(size=shape).astype(np.float32)
        diff = np.abs(data - label)
        expect = np.where(diff < 1, 0.5 * diff ** 2, diff - 0.5).mean()
        cases.append({"input": [data, label], "output": tensor(expect)})
    opr_test(cases, F.smooth_l1_loss)


@pytest.mark.skip
def test_conv_bias():
    inp_scale = 0.01