# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine.functional as F
from megengine.core import Tensor

from official.vision.detection import layers


def get_focal_loss(
    logits: Tensor,
    labels: Tensor,
    ignore_label: int = -1,
    background: int = 0,
    alpha: float = 0.5,
    gamma: float = 0,
    norm_type: str = "fg",
) -> Tensor:
    r"""Focal Loss for Dense Object Detection:
    <https://arxiv.org/pdf/1708.02002.pdf>

    .. math::

        FL(p_t) = -\alpha_t(1-p_t)^\gamma \log(p_t)

    Args:
        logits (Tensor):
            the predicted logits with the shape of :math:`(B, A, C)`
        labels (Tensor):
            the assigned labels of boxes with shape of :math:`(B, A)`
        ignore_label (int):
            the value of ignore class. Default: -1
        background (int):
            the value of background class. Default: 0
        alpha (float):
            parameter to mitigate class imbalance. Default: 0.5
        gamma (float):
            parameter to mitigate easy/hard loss imbalance. Default: 0
        norm_type (str): currently supports "fg" and "none":
            "fg": loss is normalized by the number of foreground samples
            "none": no normalization

    Returns:
        the calculated focal loss.
    """
    # class ids 1..C to compare against labels; id 0 is the background class
    class_range = F.arange(1, logits.shape[2] + 1)

    # (B, A) -> (B, A, 1) so labels broadcast against the (C,) class range
    labels = F.add_axis(labels, axis=2)
    scores = F.sigmoid(logits)
    # focal modulating terms: (1 - p)^gamma for positives, p^gamma for negatives
    pos_part = (1 - scores) ** gamma * layers.logsigmoid(logits)
    neg_part = scores ** gamma * layers.logsigmoid(-logits)

    pos_loss = -(labels == class_range) * pos_part * alpha
    neg_loss = (
        -(labels != class_range) * (labels != ignore_label) * neg_part * (1 - alpha)
    )
    loss = (pos_loss + neg_loss).sum()

    if norm_type == "fg":
        fg_mask = (labels != background) * (labels != ignore_label)
        return loss / F.maximum(fg_mask.sum(), 1)
    elif norm_type == "none":
        return loss
    else:
        raise NotImplementedError
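

# A minimal usage sketch, not part of the original module: the shapes, the
# label layout and the use of ``megengine.tensor`` are illustrative
# assumptions, with alpha/gamma set to the values used in the RetinaNet paper.
def _example_get_focal_loss() -> Tensor:
    import numpy as np
    import megengine as mge

    batch, anchors, classes = 2, 8, 80  # assumed sizes
    logits = mge.tensor(np.random.randn(batch, anchors, classes).astype("float32"))
    # labels take values in [0, classes]: 0 marks background, -1 would be ignored
    labels = mge.tensor(
        np.random.randint(0, classes + 1, size=(batch, anchors)).astype("int32")
    )
    return get_focal_loss(logits, labels, alpha=0.25, gamma=2.0)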


def get_smooth_l1_loss(
    pred_bbox: Tensor,
    gt_bbox: Tensor,
    labels: Tensor,
    beta: float = 1.0,
    background: int = 0,
    ignore_label: int = -1,
    norm_type: str = "fg",
) -> Tensor:
    r"""Smooth l1 loss used in RetinaNet.

    Args:
        pred_bbox (Tensor):
            the predicted bbox with the shape of :math:`(B, A, 4)`
        gt_bbox (Tensor):
            the ground-truth bbox with the shape of :math:`(B, A, 4)`
        labels (Tensor):
            the assigned labels of boxes with shape of :math:`(B, A)`
        beta (float):
            the parameter of smooth l1 loss. Default: 1.0
        background (int):
            the value of background class. Default: 0
        ignore_label (int):
            the value of ignore class. Default: -1
        norm_type (str): currently supports "fg", "all" and "none":
            "fg": loss is normalized by the number of foreground samples
            "all": loss is normalized by the number of all samples
            "none": no normalization
    Returns:
        the calculated smooth l1 loss.
    """
    pred_bbox = pred_bbox.reshape(-1, 4)
    gt_bbox = gt_bbox.reshape(-1, 4)
    labels = labels.reshape(-1)

    fg_mask = (labels != background) * (labels != ignore_label)

    loss = get_smooth_l1_base(pred_bbox, gt_bbox, beta)
    loss = (loss.sum(axis=1) * fg_mask).sum()
    if norm_type == "fg":
        loss = loss / F.maximum(fg_mask.sum(), 1)
    elif norm_type == "all":
        all_mask = labels != ignore_label
        loss = loss / F.maximum(all_mask.sum(), 1)
    elif norm_type == "none":
        return loss
    else:
        raise NotImplementedError

    return loss
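

# A minimal usage sketch, not part of the original module: the box values and
# shapes are made-up assumptions, and ``megengine.tensor`` usage is assumed.
def _example_get_smooth_l1_loss() -> Tensor:
    import numpy as np
    import megengine as mge

    batch, anchors = 2, 8  # assumed sizes
    pred_bbox = mge.tensor(np.random.randn(batch, anchors, 4).astype("float32"))
    gt_bbox = mge.tensor(np.random.randn(batch, anchors, 4).astype("float32"))
    # labels in {-1, 0, 1, 2}: only foreground anchors (label > 0) contribute
    labels = mge.tensor(
        np.random.randint(-1, 3, size=(batch, anchors)).astype("int32")
    )
    return get_smooth_l1_loss(pred_bbox, gt_bbox, labels, beta=1.0)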


def get_smooth_l1_base(pred_bbox: Tensor, gt_bbox: Tensor, beta: float) -> Tensor:
    r"""

    Args:
        pred_bbox (Tensor):
            the predicted bbox with the shape of :math:`(N, 4)`
        gt_bbox (Tensor):
            the ground-truth bbox with the shape of :math:`(N, 4)`
        beta (float):
            the parameter of smooth l1 loss.

    Returns:
        the element-wise smooth l1 values with the shape of :math:`(N, 4)`.
    """
    x = pred_bbox - gt_bbox
    abs_x = F.abs(x)
    if beta < 1e-5:
        loss = abs_x
    else:
        in_loss = 0.5 * x ** 2 / beta
        out_loss = abs_x - 0.5 * beta

        # FIXME: F.where cannot handle 0-shape tensor yet
        # loss = F.where(abs_x < beta, in_loss, out_loss)
        in_mask = abs_x < beta
        loss = in_loss * in_mask + out_loss * (1 - in_mask)
    return loss
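

# A small worked check, not part of the original module, assuming beta = 1:
# |x| < 1 takes the quadratic branch, e.g. x = 0.5 -> 0.5 * 0.25 = 0.125,
# while x = 2.0 takes the linear branch -> 2.0 - 0.5 = 1.5.
def _example_get_smooth_l1_base() -> Tensor:
    import numpy as np
    import megengine as mge

    pred = mge.tensor(np.array([[0.5, 2.0, 0.0, -3.0]], dtype="float32"))
    gt = mge.tensor(np.zeros((1, 4), dtype="float32"))
    # expected element-wise values: [0.125, 1.5, 0.0, 2.5]
    return get_smooth_l1_base(pred, gt, beta=1.0)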


def softmax_loss(scores: Tensor, labels: Tensor, ignore_label: int = -1) -> Tensor:
    r"""Softmax cross-entropy loss, averaged over the samples whose label is
    not ``ignore_label``.
    """
    # subtract the per-sample max before exponentiating for numerical
    # stability; zero_grad keeps the shift out of the backward pass
    max_scores = F.zero_grad(scores.max(axis=1, keepdims=True))
    scores -= max_scores
    log_prob = scores - F.log(F.exp(scores).sum(axis=1, keepdims=True))
    mask = labels != ignore_label
    # zero out ignored labels so they are valid indices, then mask out their
    # gathered log-probabilities
    vlabels = labels * mask
    loss = -(F.indexing_one_hot(log_prob, vlabels.astype("int32"), 1) * mask).sum()
    loss = loss / F.maximum(mask.sum(), 1)
    return loss
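

# A minimal usage sketch, not part of the original module: scores and labels
# are illustrative assumptions; the sample labeled -1 is excluded from the
# average.
def _example_softmax_loss() -> Tensor:
    import numpy as np
    import megengine as mge

    scores = mge.tensor(np.random.randn(4, 3).astype("float32"))
    labels = mge.tensor(np.array([0, 2, 1, -1], dtype="int32"))
    return softmax_loss(scores, labels, ignore_label=-1)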