From 7ecbc465c1d5bcacbd0b1fab91ae4db1292fe934 Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Fri, 9 Oct 2020 13:21:11 +0800 Subject: [PATCH] reimplement paddle.nn.functional.sigmoid_focal_loss (#27748) * reimplement paddle.nn.functional.sigmoid_focal_loss. test=develop * fix reduction error message. test=develop * fix exp. test=develop * reset the shape of logit. test=develop * delete disable_static in example. test=develop --- .../unittests/test_sigmoid_focal_loss.py | 165 ++++++++++++++++++ python/paddle/nn/functional/loss.py | 163 ++++++++++++++++- 2 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py new file mode 100644 index 00000000000..71e119739e7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py @@ -0,0 +1,165 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import unittest +from op_test import OpTest +from test_sigmoid_focal_loss_op import sigmoid_focal_loss_forward + + +def call_sfl_functional(logit, + label, + normalizer, + alpha=0.25, + gamma=2.0, + reduction='sum'): + res = paddle.nn.functional.sigmoid_focal_loss( + logit, label, normalizer, alpha=alpha, gamma=gamma, reduction=reduction) + return res + + +def test_static(place, + logit_np, + label_np, + normalizer_np, + alpha=0.25, + gamma=2.0, + reduction='sum'): + paddle.enable_static() + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): + logit = paddle.data(name='logit', shape=logit_np.shape, dtype='float64') + label = paddle.data(name='label', shape=label_np.shape, dtype='float64') + feed_dict = {"logit": logit_np, "label": label_np} + + normalizer = None + if normalizer_np is not None: + normalizer = paddle.data( + name='normalizer', shape=normalizer_np.shape, dtype='float64') + feed_dict["normalizer"] = normalizer_np + + res = call_sfl_functional(logit, label, normalizer, alpha, gamma, + reduction) + exe = paddle.static.Executor(place) + static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) + return static_result + + +def test_dygraph(place, + logit_np, + label_np, + normalizer_np, + alpha=0.25, + gamma=2.0, + reduction='sum'): + paddle.disable_static() + logit = paddle.to_tensor(logit_np) + label = paddle.to_tensor(label_np) + normalizer = None + if normalizer_np is not None: + normalizer = paddle.to_tensor(normalizer_np) + dy_res = call_sfl_functional(logit, label, normalizer, alpha, gamma, + reduction) + dy_result = dy_res.numpy() + paddle.enable_static() + return dy_result + + +def calc_sigmoid_focal_loss(logit_np, + label_np, + normalizer_np, + alpha=0.25, + gamma=2.0, + reduction='sum'): + + loss = np.maximum( + logit_np, + 0) - logit_np * label_np + np.log(1 + np.exp(-np.abs(logit_np))) + + pred = 1 / (1 + np.exp(-logit_np)) + p_t = pred * label_np + (1 - pred) * (1 - label_np) + + if alpha is not None: + alpha_t = alpha * label_np + (1 - alpha) * (1 - label_np) + loss = alpha_t * loss + + if gamma is not None: + loss = loss * ((1 - p_t)**gamma) + + if normalizer_np is not None: + loss = loss / normalizer_np + + if reduction == 'mean': + loss = np.mean(loss) + elif reduction == 'sum': + loss = np.sum(loss) + + return loss + + +class TestSigmoidFocalLoss(unittest.TestCase): + def test_SigmoidFocalLoss(self): + logit_np = np.random.uniform( + 0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64) + label_np = np.random.randint( + 0, 2, size=(2, 3, 4, 10)).astype(np.float64) + normalizer_nps = [ + np.asarray( + [np.sum(label_np > 0)], dtype=label_np.dtype), None + ] + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + reductions = ['sum', 'mean', 'none'] + alphas = [0.25, 0.5] + gammas = [3, 0.] + for place in places: + for reduction in reductions: + for alpha in alphas: + for gamma in gammas: + for normalizer_np in normalizer_nps: + static_result = test_static(place, logit_np, + label_np, normalizer_np, + alpha, gamma, reduction) + dy_result = test_dygraph(place, logit_np, label_np, + normalizer_np, alpha, + gamma, reduction) + expected = calc_sigmoid_focal_loss( + logit_np, label_np, normalizer_np, alpha, gamma, + reduction) + self.assertTrue( + np.allclose(static_result, expected)) + self.assertTrue( + np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_SigmoidFocalLoss_error(self): + paddle.disable_static() + logit = paddle.to_tensor([[0.97], [0.91], [0.03]], dtype='float32') + label = paddle.to_tensor([[1.0], [1.0], [0.0]], dtype='float32') + self.assertRaises( + ValueError, + paddle.nn.functional.sigmoid_focal_loss, + logit=logit, + label=label, + normalizer=None, + reduction="unsupport reduction") + paddle.enable_static() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 05daf24ca24..c4b5606dddc 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -32,7 +32,6 @@ from ...fluid.layers import npair_loss #DEFINE_ALIAS from ...fluid.layers import rank_loss #DEFINE_ALIAS from ...fluid.layers import reshape from ...fluid.layers import sigmoid_cross_entropy_with_logits #DEFINE_ALIAS -from ...fluid.layers import sigmoid_focal_loss #DEFINE_ALIAS from ...fluid.layers import smooth_l1 #DEFINE_ALIAS from ...fluid.layers import softmax_with_cross_entropy #DEFINE_ALIAS from ...fluid.layers import square_error_cost #DEFINE_ALIAS @@ -1151,3 +1150,165 @@ def cross_entropy(input, out = reshape(out, shape=out_shape) return out + + +def sigmoid_focal_loss(logit, + label, + normalizer=None, + alpha=0.25, + gamma=2.0, + reduction='sum', + name=None): + """ + `Focal Loss `_ is proposed to address the + foreground-background class imbalance for classification tasks. It down-weights + easily-classified examples and thus focuses training on hard examples. For example, + it is used in one-stage object detection where the foreground-background class + imbalance is extremely high. + + This operator measures focal loss function as follows: + + .. math:: + Out = -Labels * alpha * {(1 - \\sigma(Logit))}^{gamma}\\log(\\sigma(Logit)) - (1 - Labels) * (1 - alpha) * {\\sigma(Logit)}^{gamma}\\log(1 - \\sigma(Logit)) + + We know that :math:`\\sigma(Logit) = \\frac{1}{1 + \\exp(-Logit)}`. + + Then, if :attr:`normalizer` is not None, this operator divides the + normalizer tensor on the loss `Out`: + + .. math:: + Out = \\frac{Out}{normalizer} + + Finally, this operator applies reduce operation on the loss. + If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`. + If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`. + If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`. + + Note that the target ``label`` is 0 for the negative class and is 1 for the positive class. + + Args: + logit (Tensor): The input logit tensor. The shape is [N, *], where N is batch_size, + `*` means any number of additional dimensions. The ``logit`` is usually the + output of a convolution layer. Available dtype is float32, float64. + label (Tensor): The target label tensor with the same shape as + ``logit``. The target label whose value should be numbers between 0 and 1. + Available dtype is float32, float64. + normalizer (Tensor, optional): The number normalizes the focal loss. It has to be + a 1-D Tensor whose shape is `[1, ]`. The data type is float32, float64. + For object detection task, it is the the number of positive samples. + If set to None, the focal loss will not be normalized. Default is None. + alpha(int|float, optional): Hyper-parameter to balance the positive and negative example, + it should be between 0 and 1. Default value is set to 0.25. + gamma(int|float, optional): Hyper-parameter to modulate the easy and hard examples. + Default value is set to 2.0. + reduction (str, optional): Indicate how to average the loss by batch_size, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default is ``'sum'``. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, if :attr:`reduction` is ``'mean'`` or ``'sum'``, the out shape is :math:`[1]`, otherwise the shape is the same as ``logit``. The same dtype as ``logit`` tensor. + + Examples: + + .. code-block:: python + + import paddle + + logit = paddle.to_tensor([[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]], dtype='float32') + label = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32') + one = paddle.to_tensor([1.], dtype='float32') + fg_label = paddle.greater_equal(label, one) + fg_num = paddle.reduce_sum(paddle.cast(fg_label, dtype='float32')) + output = paddle.nn.functional.sigmoid_focal_loss(logit, label, normalizer=fg_num) + print(output.numpy()) # [0.65782464] + + """ + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in sigmoid_focal_loss " + "should be 'sum', 'mean' or 'none', but received %s, which is not allowed." + % reduction) + + if normalizer is not None: + fluid.data_feeder.check_variable_and_dtype(normalizer, 'normalizer', + ['float32', 'float64'], + 'sigmoid_focal_loss') + normalizer_shape = list(normalizer.shape) + normalizer_dims = len(normalizer_shape) + if normalizer_dims > 1: + raise ValueError( + "Expected one dimension of normalizer in sigmoid_focal_loss but got {}.". + format(normalizer_dims)) + + if in_dygraph_mode(): + one = _varbase_creator(dtype=logit.dtype) + core.ops.fill_constant(one, 'value', + float(1.0), 'force_cpu', False, 'dtype', + one.dtype, 'str_value', '1.0', 'shape', + logit.shape) + loss = core.ops.sigmoid_cross_entropy_with_logits(logit, label) + pred = core.ops.sigmoid(logit) + p_t = core.ops.elementwise_add( + core.ops.elementwise_mul(pred, label), + core.ops.elementwise_mul( + core.ops.elementwise_sub(one, pred), + core.ops.elementwise_sub(one, label))) + + alpha = fluid.dygraph.base.to_variable([alpha], dtype=loss.dtype) + alpha_t = core.ops.elementwise_add( + core.ops.elementwise_mul(alpha, label), + core.ops.elementwise_mul( + core.ops.elementwise_sub(one, alpha), + core.ops.elementwise_sub(one, label))) + loss = core.ops.elementwise_mul(alpha_t, loss) + + gamma = fluid.dygraph.base.to_variable([gamma], dtype=loss.dtype) + gamma_t = core.ops.elementwise_pow( + core.ops.elementwise_sub(one, p_t), gamma) + loss = core.ops.elementwise_mul(gamma_t, loss) + + if normalizer is not None: + loss = core.ops.elementwise_div(loss, normalizer) + + if reduction == "sum": + return core.ops.reduce_sum(loss, 'reduce_all', True) + elif reduction == "mean": + return core.ops.mean(loss) + + return loss + + fluid.data_feeder.check_variable_and_dtype( + logit, 'logit', ['float32', 'float64'], 'sigmoid_focal_loss') + fluid.data_feeder.check_variable_and_dtype( + label, 'label', ['float32', 'float64'], 'sigmoid_focal_loss') + + bce_name = None + if reduction == 'none' and normalizer is None: + bce_name = name + loss = paddle.nn.functional.binary_cross_entropy_with_logits( + logit, label, reduction='none', name=bce_name) + + pred = fluid.layers.sigmoid(logit) + p_t = pred * label + (1 - pred) * (1 - label) + + alpha_t = alpha * label + (1 - alpha) * (1 - label) + loss = paddle.multiply(alpha_t, loss) + + gamma_t = paddle.pow((1 - p_t), gamma) + loss = paddle.multiply(gamma_t, loss) + + if normalizer is not None: + normalizer_name = name if reduction == 'none' else None + loss = paddle.divide(loss, normalizer, name=normalizer_name) + + if reduction == 'mean': + loss = paddle.mean(loss, name=name) + elif reduction == 'sum': + loss = paddle.sum(loss, name=name) + + return loss -- GitLab