Unverified Commit 7ecbc465 authored by FlyingQianMM, committed by GitHub

reimplement paddle.nn.functional.sigmoid_focal_loss (#27748)

* reimplement paddle.nn.functional.sigmoid_focal_loss. test=develop

* fix reduction error message. test=develop

* fix exp. test=develop

* reset the shape of logit. test=develop

* delete disable_static in example. test=develop
Parent ec7d11a4
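
For context, the reimplemented functional API is invoked as in the short sketch below, adapted from the docstring example added in this commit (the printed value is the one given in that example):

```python
import paddle

logit = paddle.to_tensor([[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]], dtype='float32')
label = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32')

# Normalize by the number of positive (foreground) labels, as in the docstring example.
one = paddle.to_tensor([1.], dtype='float32')
fg_label = paddle.greater_equal(label, one)
fg_num = paddle.reduce_sum(paddle.cast(fg_label, dtype='float32'))

output = paddle.nn.functional.sigmoid_focal_loss(logit, label, normalizer=fg_num)
print(output.numpy())  # [0.65782464]
```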
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy as np
import unittest
from op_test import OpTest
from test_sigmoid_focal_loss_op import sigmoid_focal_loss_forward


def call_sfl_functional(logit,
                        label,
                        normalizer,
                        alpha=0.25,
                        gamma=2.0,
                        reduction='sum'):
    res = paddle.nn.functional.sigmoid_focal_loss(
        logit, label, normalizer, alpha=alpha, gamma=gamma, reduction=reduction)
    return res


def test_static(place,
                logit_np,
                label_np,
                normalizer_np,
                alpha=0.25,
                gamma=2.0,
                reduction='sum'):
    paddle.enable_static()
    prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(prog, startup_prog):
        logit = paddle.data(name='logit', shape=logit_np.shape, dtype='float64')
        label = paddle.data(name='label', shape=label_np.shape, dtype='float64')
        feed_dict = {"logit": logit_np, "label": label_np}

        normalizer = None
        if normalizer_np is not None:
            normalizer = paddle.data(
                name='normalizer', shape=normalizer_np.shape, dtype='float64')
            feed_dict["normalizer"] = normalizer_np

        res = call_sfl_functional(logit, label, normalizer, alpha, gamma,
                                  reduction)
        exe = paddle.static.Executor(place)
        static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
    return static_result


def test_dygraph(place,
                 logit_np,
                 label_np,
                 normalizer_np,
                 alpha=0.25,
                 gamma=2.0,
                 reduction='sum'):
    paddle.disable_static()
    logit = paddle.to_tensor(logit_np)
    label = paddle.to_tensor(label_np)
    normalizer = None
    if normalizer_np is not None:
        normalizer = paddle.to_tensor(normalizer_np)
    dy_res = call_sfl_functional(logit, label, normalizer, alpha, gamma,
                                 reduction)
    dy_result = dy_res.numpy()
    paddle.enable_static()
    return dy_result


def calc_sigmoid_focal_loss(logit_np,
                            label_np,
                            normalizer_np,
                            alpha=0.25,
                            gamma=2.0,
                            reduction='sum'):
    loss = np.maximum(
        logit_np,
        0) - logit_np * label_np + np.log(1 + np.exp(-np.abs(logit_np)))
    pred = 1 / (1 + np.exp(-logit_np))
    p_t = pred * label_np + (1 - pred) * (1 - label_np)

    if alpha is not None:
        alpha_t = alpha * label_np + (1 - alpha) * (1 - label_np)
        loss = alpha_t * loss

    if gamma is not None:
        loss = loss * ((1 - p_t)**gamma)

    if normalizer_np is not None:
        loss = loss / normalizer_np

    if reduction == 'mean':
        loss = np.mean(loss)
    elif reduction == 'sum':
        loss = np.sum(loss)
    return loss


class TestSigmoidFocalLoss(unittest.TestCase):
    def test_SigmoidFocalLoss(self):
        logit_np = np.random.uniform(
            0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64)
        label_np = np.random.randint(
            0, 2, size=(2, 3, 4, 10)).astype(np.float64)
        normalizer_nps = [
            np.asarray(
                [np.sum(label_np > 0)], dtype=label_np.dtype), None
        ]
        places = [fluid.CPUPlace()]
        if fluid.core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        reductions = ['sum', 'mean', 'none']
        alphas = [0.25, 0.5]
        gammas = [3, 0.]
        for place in places:
            for reduction in reductions:
                for alpha in alphas:
                    for gamma in gammas:
                        for normalizer_np in normalizer_nps:
                            static_result = test_static(place, logit_np,
                                                        label_np, normalizer_np,
                                                        alpha, gamma, reduction)
                            dy_result = test_dygraph(place, logit_np, label_np,
                                                     normalizer_np, alpha,
                                                     gamma, reduction)
                            expected = calc_sigmoid_focal_loss(
                                logit_np, label_np, normalizer_np, alpha,
                                gamma, reduction)
                            self.assertTrue(
                                np.allclose(static_result, expected))
                            self.assertTrue(
                                np.allclose(static_result, dy_result))
                            self.assertTrue(np.allclose(dy_result, expected))

    def test_SigmoidFocalLoss_error(self):
        paddle.disable_static()
        logit = paddle.to_tensor([[0.97], [0.91], [0.03]], dtype='float32')
        label = paddle.to_tensor([[1.0], [1.0], [0.0]], dtype='float32')
        self.assertRaises(
            ValueError,
            paddle.nn.functional.sigmoid_focal_loss,
            logit=logit,
            label=label,
            normalizer=None,
            reduction="unsupport reduction")
        paddle.enable_static()


if __name__ == "__main__":
    unittest.main()
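
As an aside on the numpy reference above: `calc_sigmoid_focal_loss` computes binary cross-entropy with logits in the numerically stable form `max(x, 0) - x * y + log(1 + exp(-|x|))` rather than evaluating `log(sigmoid(x))` directly. A minimal sketch (my own illustration, not part of this commit) checking that the stable form matches the naive definition:

```python
import numpy as np

x = np.array([-3.0, 0.5, 4.0])   # logits
y = np.array([0.0, 1.0, 1.0])    # binary labels

# Stable form used by calc_sigmoid_focal_loss above.
stable = np.maximum(x, 0) - x * y + np.log(1 + np.exp(-np.abs(x)))

# Naive form: -y*log(p) - (1-y)*log(1-p) with p = sigmoid(x).
p = 1 / (1 + np.exp(-x))
naive = -y * np.log(p) - (1 - y) * np.log(1 - p)

print(np.allclose(stable, naive))  # True
```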
@@ -32,7 +32,6 @@ from ...fluid.layers import npair_loss #DEFINE_ALIAS
from ...fluid.layers import rank_loss #DEFINE_ALIAS
from ...fluid.layers import reshape
from ...fluid.layers import sigmoid_cross_entropy_with_logits #DEFINE_ALIAS
from ...fluid.layers import sigmoid_focal_loss #DEFINE_ALIAS
from ...fluid.layers import smooth_l1 #DEFINE_ALIAS
from ...fluid.layers import softmax_with_cross_entropy #DEFINE_ALIAS
from ...fluid.layers import square_error_cost #DEFINE_ALIAS
@@ -1151,3 +1150,165 @@ def cross_entropy(input,
    out = reshape(out, shape=out_shape)
    return out


def sigmoid_focal_loss(logit,
                       label,
                       normalizer=None,
                       alpha=0.25,
                       gamma=2.0,
                       reduction='sum',
                       name=None):
    """
    `Focal Loss <https://arxiv.org/abs/1708.02002>`_ is proposed to address the
    foreground-background class imbalance for classification tasks. It down-weights
    easily-classified examples and thus focuses training on hard examples. For example,
    it is used in one-stage object detection where the foreground-background class
    imbalance is extremely high.

    This operator measures the focal loss as follows:

    .. math::
           Out = -Labels * alpha * {(1 - \\sigma(Logit))}^{gamma}\\log(\\sigma(Logit)) - (1 - Labels) * (1 - alpha) * {\\sigma(Logit)}^{gamma}\\log(1 - \\sigma(Logit))

    We know that :math:`\\sigma(Logit) = \\frac{1}{1 + \\exp(-Logit)}`.

    Then, if :attr:`normalizer` is not None, this operator divides the loss `Out`
    by the normalizer tensor:

    .. math::
           Out = \\frac{Out}{normalizer}

    Finally, this operator applies a reduce operation to the loss.
    If :attr:`reduction` is set to ``'none'``, the operator returns the original loss `Out`.
    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target ``label`` is 0 for the negative class and 1 for the positive class.

    Args:
        logit (Tensor): The input logit tensor. The shape is [N, *], where N is batch_size,
            `*` means any number of additional dimensions. The ``logit`` is usually the
            output of a convolution layer. Available dtypes are float32, float64.
        label (Tensor): The target label tensor with the same shape as
            ``logit``. Its values should be numbers between 0 and 1.
            Available dtypes are float32, float64.
        normalizer (Tensor, optional): The number that normalizes the focal loss. It has to be
            a 1-D Tensor whose shape is `[1, ]`. The data type is float32, float64.
            For object detection tasks, it is the number of positive samples.
            If set to None, the focal loss will not be normalized. Default is None.
        alpha (int|float, optional): Hyper-parameter to balance the positive and negative examples;
            it should be between 0 and 1. Default value is set to 0.25.
        gamma (int|float, optional): Hyper-parameter to modulate the easy and hard examples.
            Default value is set to 2.0.
        reduction (str, optional): Indicates how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'sum'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, if :attr:`reduction` is ``'mean'`` or ``'sum'``, the out shape is :math:`[1]`,
        otherwise the shape is the same as ``logit``. It has the same dtype as the ``logit`` tensor.

    Examples:

        .. code-block:: python

            import paddle

            logit = paddle.to_tensor([[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]], dtype='float32')
            label = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32')
            one = paddle.to_tensor([1.], dtype='float32')
            fg_label = paddle.greater_equal(label, one)
            fg_num = paddle.reduce_sum(paddle.cast(fg_label, dtype='float32'))
            output = paddle.nn.functional.sigmoid_focal_loss(logit, label, normalizer=fg_num)
            print(output.numpy())  # [0.65782464]

    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in sigmoid_focal_loss "
            "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
            % reduction)

    if normalizer is not None:
        fluid.data_feeder.check_variable_and_dtype(normalizer, 'normalizer',
                                                   ['float32', 'float64'],
                                                   'sigmoid_focal_loss')
        normalizer_shape = list(normalizer.shape)
        normalizer_dims = len(normalizer_shape)
        if normalizer_dims > 1:
            raise ValueError(
                "Expected one dimension of normalizer in sigmoid_focal_loss but got {}.".
                format(normalizer_dims))

    if in_dygraph_mode():
        one = _varbase_creator(dtype=logit.dtype)
        core.ops.fill_constant(one, 'value',
                               float(1.0), 'force_cpu', False, 'dtype',
                               one.dtype, 'str_value', '1.0', 'shape',
                               logit.shape)
        loss = core.ops.sigmoid_cross_entropy_with_logits(logit, label)
        pred = core.ops.sigmoid(logit)
        p_t = core.ops.elementwise_add(
            core.ops.elementwise_mul(pred, label),
            core.ops.elementwise_mul(
                core.ops.elementwise_sub(one, pred),
                core.ops.elementwise_sub(one, label)))

        alpha = fluid.dygraph.base.to_variable([alpha], dtype=loss.dtype)
        alpha_t = core.ops.elementwise_add(
            core.ops.elementwise_mul(alpha, label),
            core.ops.elementwise_mul(
                core.ops.elementwise_sub(one, alpha),
                core.ops.elementwise_sub(one, label)))
        loss = core.ops.elementwise_mul(alpha_t, loss)

        gamma = fluid.dygraph.base.to_variable([gamma], dtype=loss.dtype)
        gamma_t = core.ops.elementwise_pow(
            core.ops.elementwise_sub(one, p_t), gamma)
        loss = core.ops.elementwise_mul(gamma_t, loss)

        if normalizer is not None:
            loss = core.ops.elementwise_div(loss, normalizer)

        if reduction == "sum":
            return core.ops.reduce_sum(loss, 'reduce_all', True)
        elif reduction == "mean":
            return core.ops.mean(loss)
        return loss

    fluid.data_feeder.check_variable_and_dtype(
        logit, 'logit', ['float32', 'float64'], 'sigmoid_focal_loss')
    fluid.data_feeder.check_variable_and_dtype(
        label, 'label', ['float32', 'float64'], 'sigmoid_focal_loss')

    bce_name = None
    if reduction == 'none' and normalizer is None:
        bce_name = name
    loss = paddle.nn.functional.binary_cross_entropy_with_logits(
        logit, label, reduction='none', name=bce_name)

    pred = fluid.layers.sigmoid(logit)
    p_t = pred * label + (1 - pred) * (1 - label)

    alpha_t = alpha * label + (1 - alpha) * (1 - label)
    loss = paddle.multiply(alpha_t, loss)

    gamma_t = paddle.pow((1 - p_t), gamma)
    loss = paddle.multiply(gamma_t, loss)

    if normalizer is not None:
        normalizer_name = name if reduction == 'none' else None
        loss = paddle.divide(loss, normalizer, name=normalizer_name)

    if reduction == 'mean':
        loss = paddle.mean(loss, name=name)
    elif reduction == 'sum':
        loss = paddle.sum(loss, name=name)
    return loss
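
Finally, a minimal sketch of the three ``reduction`` modes on the docstring's example inputs (a hand-written illustration, not part of the commit; the defaults ``alpha=0.25``, ``gamma=2.0`` and no ``normalizer`` are used):

```python
import paddle

logit = paddle.to_tensor([[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]], dtype='float32')
label = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32')

# 'none' keeps the per-element loss.
per_elem = paddle.nn.functional.sigmoid_focal_loss(logit, label, reduction='none')
# 'mean' and 'sum' reduce the loss to a single value.
mean_loss = paddle.nn.functional.sigmoid_focal_loss(logit, label, reduction='mean')
sum_loss = paddle.nn.functional.sigmoid_focal_loss(logit, label, reduction='sum')

# per-element: same shape as logit; 'mean'/'sum': a single-value tensor.
print(per_elem.shape, mean_loss.shape, sum_loss.shape)
```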