# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register

from ..utils import decode_yolo, xywh2xyxy, iou_similarity
__all__ = ['YOLOv3Loss']


@register
class YOLOv3Loss(nn.Layer):
    """Training loss layer for YOLOv3-style detection heads.

    For every detection level the layer produces a localization term
    (``loss_loc``), an objectness term (``loss_obj``) and a classification
    term (``loss_cls``).  When the optional ``iou_loss`` / ``iou_aware_loss``
    callables are supplied, ``loss_iou`` / ``loss_iou_aware`` are added too.
    ``forward`` sums each term over all levels and also reports their total
    under the key ``'loss'``.
    """

    # Constructor arguments named here are presumably resolved by the
    # ``@register`` machinery (``__inject__``: built sub-modules,
    # ``__shared__``: values shared across the config) -- TODO confirm
    # against ppdet.core.workspace.
    __inject__ = ['iou_loss', 'iou_aware_loss']
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 ignore_thresh=0.7,
                 label_smooth=False,
                 downsample=[32, 16, 8],
                 scale_x_y=1.,
                 iou_loss=None,
                 iou_aware_loss=None):
        """Store the loss hyper-parameters and the optional IoU sub-losses.

        Args:
            num_classes (int): Number of classes; used by ``cls_loss`` to
                derive the label-smoothing amount.
            ignore_thresh (float): In ``obj_loss``, negative locations whose
                best IoU with the ground truth exceeds this value are
                excluded from the objectness loss.
            label_smooth (bool): Whether ``cls_loss`` smooths the class
                targets.
            downsample (list): One entry per detection level; ``forward``
                zips it with the level inputs and passes each entry on to
                the per-level loss.
            scale_x_y (float): Scale factor kept on the instance;
                presumably the value meant for the ``scale`` argument of
                ``yolov3_loss``.
            iou_loss: Optional callable invoked by ``yolov3_loss`` as
                ``iou_loss(box, tbox, anchor, downsample, scale)``.
            iou_aware_loss: Optional callable invoked by ``yolov3_loss`` as
                ``iou_aware_loss(ioup, box, tbox, anchor, downsample,
                scale)``.
        """
        super(YOLOv3Loss, self).__init__()
        self.num_classes = num_classes
        self.ignore_thresh = ignore_thresh
        self.label_smooth = label_smooth
        # Copy the sequence: the signature default is a single list object
        # shared by every call, so storing it directly would let a mutation
        # on one instance leak into all later instances.  The default in the
        # signature itself is left untouched for backward compatibility.
        self.downsample = list(downsample)
        self.scale_x_y = scale_x_y
        self.iou_loss = iou_loss
        self.iou_aware_loss = iou_aware_loss

    def obj_loss(self, pbox, gbox, pobj, tobj, anchor, downsample):
        """Return the unreduced objectness loss for one detection level.

        Predicted boxes are decoded with ``decode_yolo``, converted to corner
        format and compared with the ground-truth boxes.  A location
        contributes as a positive weighted by ``tobj``; it contributes as a
        negative only when ``tobj`` is not positive and its best IoU does not
        exceed ``self.ignore_thresh``.

        Args:
            pbox: Predicted boxes; at least 4-D with the batch size first.
            gbox: Ground-truth boxes in xywh format (they are passed through
                ``xywh2xyxy``).
            pobj: Objectness logits, flattened to ``(batch, -1)``.
            tobj: Objectness targets, flattened to ``(batch, -1)``.
            anchor: Anchors forwarded to ``decode_yolo``.
            downsample: Downsample value forwarded to ``decode_yolo``.

        Returns:
            Element-wise loss (positive part plus negative part); the caller
            is responsible for reducing it.
        """
        # Only the batch size is needed.  The 4-way unpack is kept so that
        # inputs with fewer than four dimensions still fail at this point.
        batch_size, _, _, _ = pbox.shape[:4]

        decoded = decode_yolo(pbox, anchor, downsample)
        decoded = decoded.reshape((batch_size, -1, 4))
        pred_xyxy = xywh2xyxy(decoded)
        gt_xyxy = xywh2xyxy(gbox)
        dtype = pred_xyxy.dtype

        # The IoU only builds masks, so it is kept out of the gradient graph.
        overlaps = iou_similarity(pred_xyxy, gt_xyxy)
        overlaps.stop_gradient = True
        best_overlap = overlaps.max(2)  # [N, M1] per the original note
        below_thresh = paddle.cast(
            best_overlap <= self.ignore_thresh, dtype=dtype)
        below_thresh.stop_gradient = True

        flat_pobj = pobj.reshape((batch_size, -1))
        flat_tobj = tobj.reshape((batch_size, -1))
        is_positive = paddle.cast(flat_tobj > 0, dtype=dtype)
        is_positive.stop_gradient = True

        bce = F.binary_cross_entropy_with_logits(
            flat_pobj, is_positive, reduction='none')
        positive_part = bce * flat_tobj
        negative_part = bce * (1 - is_positive) * below_thresh
        return positive_part + negative_part

    def cls_loss(self, pcls, tcls):
        """Return the unreduced classification loss.

        When ``self.label_smooth`` is set, the targets are replaced by
        ``1 - delta`` where ``tcls > 0`` and by ``delta`` elsewhere, with
        ``delta = min(1 / num_classes, 1 / 40)``.

        Args:
            pcls: Class logits.
            tcls: Class targets (1 for positive, 0 for negative).

        Returns:
            Element-wise binary cross-entropy-with-logits between ``pcls``
            and the (possibly smoothed) targets.
        """
        target = tcls
        if self.label_smooth:
            delta = min(1. / self.num_classes, 1. / 40)
            is_positive = paddle.cast(tcls > 0., dtype=tcls.dtype)
            is_negative = paddle.cast(tcls <= 0., dtype=tcls.dtype)
            target = (1 - delta) * is_positive + delta * is_negative

        return F.binary_cross_entropy_with_logits(
            pcls, target, reduction='none')

    def yolov3_loss(self, x, t, gt_box, anchor, downsample, scale=1.,
                    eps=1e-10):
        """Compute the loss terms for a single detection level.

        Args:
            x: Raw head output, unpacked as ``(b, c, h, w)``.  ``c`` is split
                into ``len(anchor)`` groups of ``c // len(anchor)`` channels
                laid out as ``[xy(2), wh(2), obj(1), cls...]``, or
                ``[xy(2), wh(2), obj(1), ioup(1), cls...]`` when
                ``self.iou_aware_loss`` is set.
            t: Target tensor for this level.  It is transposed with
                ``(0, 3, 4, 1, 2)`` and its last axis is then read as
                ``[txy(2), twh(2), tscale(1), tobj(1), tcls...]``.
            gt_box: Ground-truth boxes, forwarded to ``obj_loss``.
            anchor: Anchors of this level.  ``len(anchor)`` gives the number
                of anchors; the value is also forwarded to the sub-losses.
            downsample: Downsample value forwarded to the sub-losses.
            scale: Multiplier applied to ``sigmoid(xy)``; the result is
                shifted by ``-0.5 * (scale - 1)``.
            eps: Tolerance used to decide whether ``scale`` equals 1.

        Returns:
            dict: ``loss_loc``, ``loss_obj`` and ``loss_cls``, plus
            ``loss_iou`` / ``loss_iou_aware`` when the matching sub-loss is
            configured.  Every term is summed over all axes but the first
            and then averaged.
        """
        na = len(anchor)
        b, c, h, w = x.shape
        no = c // na
        # (b, c, h, w) -> (b, na, no, h, w) -> (b, h, w, na, no)
        x = x.reshape((b, na, no, h, w)).transpose((0, 3, 4, 1, 2))

        xy, wh, obj = x[:, :, :, :, 0:2], x[:, :, :, :, 2:4], x[:, :, :, :, 4:5]
        # Same `is not None` test as the branch that reads `ioup` further
        # down, so `ioup` is bound exactly when it is consumed.
        if self.iou_aware_loss is not None:
            ioup, pcls = x[:, :, :, :, 5:6], x[:, :, :, :, 6:]
        else:
            pcls = x[:, :, :, :, 5:]

        t = t.transpose((0, 3, 4, 1, 2))
        txy, twh, tscale = t[:, :, :, :, 0:2], t[:, :, :, :, 2:4], t[:, :, :, :,
                                                                     4:5]
        tobj, tcls = t[:, :, :, :, 5:6], t[:, :, :, :, 6:]

        # Per-location weight of the localization terms.
        tscale_obj = tscale * tobj
        loss = dict()
        if abs(scale - 1.) < eps:
            # scale == 1: cross-entropy directly on the xy logits.
            loss_xy = tscale_obj * F.binary_cross_entropy_with_logits(
                xy, txy, reduction='none')
        else:
            # scale != 1: L1 distance on the scaled, shifted sigmoid.
            xy = scale * F.sigmoid(xy) - 0.5 * (scale - 1.)
            loss_xy = tscale_obj * paddle.abs(xy - txy)

        loss_xy = loss_xy.sum([1, 2, 3, 4]).mean()
        loss_wh = tscale_obj * paddle.abs(wh - twh)
        loss_wh = loss_wh.sum([1, 2, 3, 4]).mean()

        loss['loss_loc'] = loss_xy + loss_wh

        # Overwrite the xy channels of `x` in place with the scaled sigmoid
        # so that `box` below carries them.  This must stay after the xy/wh
        # losses above, which work on the slices taken earlier.
        x[:, :, :, :, 0:2] = scale * F.sigmoid(x[:, :, :, :, 0:2]) - 0.5 * (
            scale - 1.)
        box, tbox = x[:, :, :, :, 0:4], t[:, :, :, :, 0:4]
        if self.iou_loss is not None:
            # box and tbox will not change though they are modified in self.iou_loss function, so no need to clone
            loss_iou = self.iou_loss(box, tbox, anchor, downsample, scale)
            loss_iou = loss_iou * tscale_obj.reshape((b, -1))
            loss_iou = loss_iou.sum(-1).mean()
            loss['loss_iou'] = loss_iou

        if self.iou_aware_loss is not None:
            # box and tbox will not change though they are modified in self.iou_aware_loss function, so no need to clone
            loss_iou_aware = self.iou_aware_loss(ioup, box, tbox, anchor,
                                                 downsample, scale)
            loss_iou_aware = loss_iou_aware * tobj.reshape((b, -1))
            loss_iou_aware = loss_iou_aware.sum(-1).mean()
            loss['loss_iou_aware'] = loss_iou_aware

        loss_obj = self.obj_loss(box, gt_box, obj, tobj, anchor, downsample)
        loss_obj = loss_obj.sum(-1).mean()
        loss['loss_obj'] = loss_obj
        # Classification loss only counts at locations with an object target.
        loss_cls = self.cls_loss(pcls, tcls) * tobj
        loss_cls = loss_cls.sum([1, 2, 3, 4]).mean()
        loss['loss_cls'] = loss_cls
        return loss

    def forward(self, inputs, targets, anchors):
        """Accumulate the YOLOv3 loss terms over all detection levels.

        Args:
            inputs: Sequence of head outputs, one per detection level.
            targets: Mapping that holds ``'target{i}'`` for every level
                index ``i`` and the ground-truth boxes under ``'gt_bbox'``.
            anchors: Sequence of per-level anchors, zipped with ``inputs``.

        Returns:
            dict: Every loss term returned by ``yolov3_loss`` summed across
            the levels, plus ``'loss'`` holding the sum of all those terms.
        """
        num_levels = len(inputs)
        gt_targets = [
            targets['target{}'.format(i)] for i in range(num_levels)
        ]
        gt_box = targets['gt_bbox']
        yolo_losses = dict()
        for x, t, anchor, downsample in zip(inputs, gt_targets, anchors,
                                            self.downsample):
            # Forward the configured xy scale; without it `yolov3_loss`
            # always ran with its default scale of 1 and `scale_x_y` was
            # silently ignored.
            yolo_loss = self.yolov3_loss(x, t, gt_box, anchor, downsample,
                                         self.scale_x_y)
            for k, v in yolo_loss.items():
                if k in yolo_losses:
                    yolo_losses[k] += v
                else:
                    yolo_losses[k] = v

        # Total over all accumulated terms (0 when there are no levels).
        yolo_losses['loss'] = sum(yolo_losses.values())
        return yolo_losses