ctfocal_loss.py 2.3 KB
Newer Older
F
Feng Ni 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
M
Manuel Garcia 已提交
20

F
Feng Ni 已提交
21 22 23 24 25 26 27 28 29
from ppdet.core.workspace import register, serializable

__all__ = ['CTFocalLoss']


@register
@serializable
class CTFocalLoss(object):
    """
    CTFocalLoss: CornerNet & CenterNet Focal Loss.

    Args:
        loss_weight (float): scalar multiplied onto the final loss.
        gamma (float): focusing exponent applied to ``1 - pred`` for
            positives and to ``pred`` for negatives.
        beta (float): exponent applied to ``1 - target`` to down-weight
            negative locations whose (Gaussian-smoothed) target is close
            to 1. Defaults to 4.0, the value that was previously
            hard-coded, so existing configs behave identically.
    """

    def __init__(self, loss_weight=1., gamma=2.0, beta=4.0):
        self.loss_weight = loss_weight
        self.gamma = gamma
        self.beta = beta

    def __call__(self, pred, target):
        """
        Calculate the loss.

        Args:
            pred (Tensor): heatmap prediction. ``log(pred)`` and
                ``log(1 - pred)`` are taken without clamping, so values
                are presumably kept strictly inside (0, 1) by the caller
                (e.g. a clipped sigmoid) — confirm at the call site.
            target (Tensor): target heatmap, same shape as ``pred``.

        Return:
            ct_focal_loss (Tensor): Focal Loss used in CornerNet & CenterNet,
                scaled by ``loss_weight``. Note that the values in target
                are in [0, 1] since a Gaussian is used to reduce the
                punishment; locations with target exactly 1 are positives
                and everything in [0, 1) is treated as a negative example.
        """
        # Positives are locations whose target is exactly 1 (Gaussian peaks).
        fg_map = paddle.cast(target == 1, 'float32')
        fg_map.stop_gradient = True
        bg_map = paddle.cast(target < 1, 'float32')
        bg_map.stop_gradient = True

        # Negatives near a peak (target close to 1) get a small weight
        # (1 - target) ** beta, so they are penalized less.
        neg_weights = paddle.pow(1 - target, self.beta) * bg_map
        pos_loss = 0 - paddle.log(pred) * paddle.pow(1 - pred,
                                                     self.gamma) * fg_map
        neg_loss = 0 - paddle.log(1 - pred) * paddle.pow(
            pred, self.gamma) * neg_weights
        pos_loss = paddle.sum(pos_loss)
        neg_loss = paddle.sum(neg_loss)

        # Normalize by the number of positives; when there are none, divide
        # by 1 instead of 0 (fg_num == 0 casts to 1.0 only in that case).
        fg_num = paddle.sum(fg_map)
        ct_focal_loss = (pos_loss + neg_loss) / (
            fg_num + paddle.cast(fg_num == 0, 'float32'))
        return ct_focal_loss * self.loss_weight