# yolo_loss.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from ppdet.core.workspace import register
# `Sequence` moved to `collections.abc` in Python 3.3 (and was removed from
# `collections` in 3.10); fall back to the old location only when the new
# module is unavailable (Python 2), instead of swallowing every exception.
try:
    from collections.abc import Sequence
except ImportError:
    from collections import Sequence

__all__ = ['YOLOv3Loss']


@register
class YOLOv3Loss(object):
    """
    Combined loss for YOLOv3 network

    Args:
        batch_size (int): training batch size
        ignore_thresh (float): threshold to ignore confidence loss
        label_smooth (bool): whether to use label smoothing
        use_fine_grained_loss (bool): whether use fine grained YOLOv3 loss
                                      instead of fluid.layers.yolov3_loss
        iou_loss (object|None): optional injected IoU loss callable; only
                                used by the fine grained loss
        iou_aware_loss (object|None): optional injected IoU-aware loss
                                      callable; only used by the fine
                                      grained loss
        downsample (list[int]): downsample ratio of each output layer,
                                indexed in the same order as `outputs`
        scale_x_y (float|Sequence[float]): x/y scale factor, either one
                                           value shared by all output
                                           layers or one value per layer
        match_score (bool): fine grained loss only; if True, the negative
                            objectness loss is additionally skipped for
                            predictions whose max class prob exceeds 0.25
    """
    __inject__ = ['iou_loss', 'iou_aware_loss']
    __shared__ = ['use_fine_grained_loss']

    def __init__(self,
                 batch_size=8,
                 ignore_thresh=0.7,
                 label_smooth=True,
                 use_fine_grained_loss=False,
                 iou_loss=None,
                 iou_aware_loss=None,
                 downsample=[32, 16, 8],
                 scale_x_y=1.,
                 match_score=False):
        self._batch_size = batch_size
        self._ignore_thresh = ignore_thresh
        self._label_smooth = label_smooth
        self._use_fine_grained_loss = use_fine_grained_loss
        self._iou_loss = iou_loss
        self._iou_aware_loss = iou_aware_loss
        # NOTE: `downsample` is only ever indexed, never mutated, so sharing
        # the default list between instances is harmless.
        self.downsample = downsample
        self.scale_x_y = scale_x_y
        self.match_score = match_score

    def _get_scale_x_y(self, i):
        """Return the scale_x_y value for output layer ``i``.

        ``self.scale_x_y`` may be a single number (shared by all layers) or
        a sequence holding one value per output layer.
        """
        if isinstance(self.scale_x_y, Sequence):
            return self.scale_x_y[i]
        return self.scale_x_y

    def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
                 anchor_masks, mask_anchors, num_classes, prefix_name):
        """
        Build the YOLOv3 loss for all output layers.

        Args:
            outputs ([Variables]): output feature map of each YOLO layer
            gt_box (Variable): ground-truth boxes
            gt_label (Variable): ground-truth labels (yolov3_loss path only)
            gt_score (Variable): ground-truth scores (yolov3_loss path only)
            targets ([Variables]): per-layer targets (fine grained path only)
            anchors (list): all anchors (yolov3_loss path only)
            anchor_masks (list): anchor indices used by each output layer
            mask_anchors ([[float]]): anchors of each output layer
                                      (fine grained path only)
            num_classes (int): class num of dataset
            prefix_name (str): prefix for the yolov3_loss op names

        Returns:
            dict: the fine grained loss dict (see `_get_fine_grained_loss`)
                  when `use_fine_grained_loss` is set, otherwise
                  {'loss': sum of the per-layer mean yolov3_loss}.
        """
        if self._use_fine_grained_loss:
            return self._get_fine_grained_loss(
                outputs, targets, gt_box, self._batch_size, num_classes,
                mask_anchors, self._ignore_thresh)

        losses = []
        for i, output in enumerate(outputs):
            loss = fluid.layers.yolov3_loss(
                x=output,
                gt_box=gt_box,
                gt_label=gt_label,
                gt_score=gt_score,
                anchors=anchors,
                anchor_mask=anchor_masks[i],
                class_num=num_classes,
                ignore_thresh=self._ignore_thresh,
                downsample_ratio=self.downsample[i],
                use_label_smooth=self._label_smooth,
                scale_x_y=self._get_scale_x_y(i),
                name=prefix_name + "yolo_loss" + str(i))
            losses.append(fluid.layers.reduce_mean(loss))

        return {'loss': sum(losses)}

    def _get_fine_grained_loss(self, outputs, targets, gt_box, batch_size,
                               num_classes, mask_anchors, ignore_thresh):
        """
        Calculate fine grained YOLOv3 loss

        Args:
            outputs ([Variables]): List of Variables, output of backbone stages
            targets ([Variables]): List of Variables, The targets for yolo
                                   loss calculation.
            gt_box (Variable): The ground-truth bounding boxes.
            batch_size (int): The training batch size
            num_classes (int): class num of dataset
            mask_anchors ([[float]]): list of anchors in each output layer
            ignore_thresh (float): if a predicted bbox overlaps any gt_box
                                   with IoU greater than ignore_thresh, its
                                   negative objectness loss is ignored.

        Returns:
            Type: dict
                loss_xy (Variable): YOLOv3 (x, y) coordinates loss
                loss_wh (Variable): YOLOv3 (w, h) coordinates loss
                loss_obj (Variable): YOLOv3 objectness score loss
                loss_cls (Variable): YOLOv3 classification loss
                loss_iou (Variable): only present when iou_loss is set
                loss_iou_aware (Variable): only present when iou_aware_loss
                                           is set

        """

        assert len(outputs) == len(targets), \
            "YOLOv3 output layer number not equal target number"

        loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
        # Only filled (and only reported) when the matching loss is set.
        loss_ious, loss_iou_awares = [], []
        for i, (output, target,
                anchors) in enumerate(zip(outputs, targets, mask_anchors)):
            downsample = self.downsample[i]
            an_num = len(anchors) // 2
            if self._iou_aware_loss is not None:
                # Strip the leading IoU-prediction channels so `output` has
                # the plain YOLOv3 channel layout expected below.
                ioup, output = self._split_ioup(output, an_num, num_classes)
            x, y, w, h, obj, cls = self._split_output(output, an_num,
                                                      num_classes)
            tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)

            # Box-regression terms are weighted per cell by tscale * tobj;
            # cells without an object (tobj == 0) contribute nothing.
            tscale_tobj = tscale * tobj
            loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
                x, tx) * tscale_tobj
            loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
            loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
                y, ty) * tscale_tobj
            loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
            # NOTE: we refined loss function of (w, h) as L1Loss
            loss_w = fluid.layers.abs(w - tw) * tscale_tobj
            loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
            loss_h = fluid.layers.abs(h - th) * tscale_tobj
            loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
            if self._iou_loss is not None:
                loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors,
                                          downsample, batch_size)
                loss_iou = loss_iou * tscale_tobj
                loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
                loss_ious.append(fluid.layers.reduce_mean(loss_iou))

            if self._iou_aware_loss is not None:
                loss_iou_aware = self._iou_aware_loss(
                    ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
                    batch_size)
                # NOTE: weighted by tobj only (no tscale factor), unlike the
                # box-regression terms above.
                loss_iou_aware = loss_iou_aware * tobj
                loss_iou_aware = fluid.layers.reduce_sum(
                    loss_iou_aware, dim=[1, 2, 3])
                loss_iou_awares.append(fluid.layers.reduce_mean(loss_iou_aware))

            loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
                output, obj, tobj, gt_box, batch_size, anchors, num_classes,
                downsample, ignore_thresh, self._get_scale_x_y(i))

            loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls)
            loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
            loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])

            loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
            loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
            loss_objs.append(
                fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
            loss_clss.append(fluid.layers.reduce_mean(loss_cls))

        losses_all = {
            "loss_xy": fluid.layers.sum(loss_xys),
            "loss_wh": fluid.layers.sum(loss_whs),
            "loss_obj": fluid.layers.sum(loss_objs),
            "loss_cls": fluid.layers.sum(loss_clss),
        }
        if self._iou_loss is not None:
            losses_all["loss_iou"] = fluid.layers.sum(loss_ious)
        if self._iou_aware_loss is not None:
            losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares)
        return losses_all

    def _split_ioup(self, output, an_num, num_classes):
        """
        Split output feature map to output, predicted iou
        along channel dimension

        The first `an_num` channels are the IoU predictions (returned after
        a sigmoid); channels [an_num, an_num * (num_classes + 6)) are the
        regular YOLOv3 output.
        """
        ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
        ioup = fluid.layers.sigmoid(ioup)
        oriout = fluid.layers.slice(
            output,
            axes=[1],
            starts=[an_num],
            ends=[an_num * (num_classes + 6)])
        return (ioup, oriout)

    def _split_output(self, output, an_num, num_classes):
        """
        Split output feature map to x, y, w, h, objectness, classification
        along channel dimension

        Each anchor owns a group of 5 + num_classes consecutive channels:
        x, y, w, h, obj followed by the class logits. The returned `cls` is
        the class logits stacked per anchor and moved to the last axis.
        """
        channel_stride = 5 + num_classes
        # Channel k of every anchor group, for k = 0..4 -> x, y, w, h, obj.
        x, y, w, h, obj = [
            fluid.layers.strided_slice(
                output,
                axes=[1],
                starts=[k],
                ends=[output.shape[1]],
                strides=[channel_stride]) for k in range(5)
        ]
        clss = []
        stride = output.shape[1] // an_num
        for m in range(an_num):
            clss.append(
                fluid.layers.slice(
                    output,
                    axes=[1],
                    starts=[stride * m + 5],
                    ends=[stride * m + 5 + num_classes]))
        cls = fluid.layers.transpose(
            fluid.layers.stack(
                clss, axis=1), perm=[0, 1, 3, 4, 2])

        return (x, y, w, h, obj, cls)

    def _split_target(self, target):
        """
        split target to x, y, w, h, objectness, classification
        along dimension 2

        target is in shape [N, an_num, 6 + class_num, H, W]
        """
        tx = target[:, :, 0, :, :]
        ty = target[:, :, 1, :, :]
        tw = target[:, :, 2, :, :]
        th = target[:, :, 3, :, :]

        tscale = target[:, :, 4, :, :]
        tobj = target[:, :, 5, :, :]

        # Move the class axis last to line up with `cls` from _split_output.
        tcls = fluid.layers.transpose(
            target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
        tcls.stop_gradient = True

        return (tx, ty, tw, th, tscale, tobj, tcls)

    @staticmethod
    def _box_xywh2xyxy(box):
        """Convert [M, 4] boxes from (cx, cy, w, h) to (x1, y1, x2, y2)."""
        x = box[:, 0]
        y = box[:, 1]
        w = box[:, 2]
        h = box[:, 3]
        return fluid.layers.stack(
            [
                x - w / 2.,
                y - h / 2.,
                x + w / 2.,
                y + h / 2.,
            ], axis=1)

    def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
                       num_classes, downsample, ignore_thresh, scale_x_y):
        """
        Compute the positive and negative objectness losses of one layer.

        Returns:
            (loss_obj_pos, loss_obj_neg): per-sample sums over
            dims [1, 2, 3] of the objectness cross entropy.
        """
        # A prediction bbox overlap any gt_bbox over ignore_thresh,
        # objectness loss will be ignored, process as follows:

        # 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
        # NOTE: img_size is set as 1.0 to get normalized pred bbox
        bbox, prob = fluid.layers.yolo_box(
            x=output,
            img_size=fluid.layers.ones(
                shape=[batch_size, 2], dtype="int32"),
            anchors=anchors,
            class_num=num_classes,
            conf_thresh=0.,
            downsample_ratio=downsample,
            clip_bbox=False,
            scale_x_y=scale_x_y)

        # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
        #    and gt bbox in each sample
        if batch_size > 1:
            preds = fluid.layers.split(bbox, batch_size, dim=0)
            gts = fluid.layers.split(gt_box, batch_size, dim=0)
        else:
            preds = [bbox]
            gts = [gt_box]
        ious = []
        for pred, gt in zip(preds, gts):
            pred = fluid.layers.squeeze(pred, axes=[0])
            gt = self._box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
            ious.append(fluid.layers.iou_similarity(pred, gt))

        iou = fluid.layers.stack(ious, axis=0)
        # 3. Get iou_mask by IoU between gt bbox and prediction bbox,
        #    Get obj_mask by tobj(holds gt_score), calculate objectness loss

        max_iou = fluid.layers.reduce_max(iou, dim=-1)
        iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
        if self.match_score:
            # Also skip negatives whose best class probability is > 0.25.
            max_prob = fluid.layers.reduce_max(prob, dim=-1)
            iou_mask = iou_mask * fluid.layers.cast(
                max_prob <= 0.25, dtype="float32")
        output_shape = fluid.layers.shape(output)
        an_num = len(anchors) // 2
        iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
                                                   output_shape[3]))
        iou_mask.stop_gradient = True

        # NOTE: tobj holds gt_score, obj_mask holds object existence mask
        obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True

        # For positive objectness grids, objectness loss should be calculated
        # For negative objectness grids, objectness loss is calculated only iou_mask == 1.0
        loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj, obj_mask)
        loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
        loss_obj_neg = fluid.layers.reduce_sum(
            loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])

        return loss_obj_pos, loss_obj_neg