diff --git a/configs/dcn/yolov3_r50vd_dcn_iouloss_obj365_pretrained_coco.yml b/configs/dcn/yolov3_r50vd_dcn_iouloss_obj365_pretrained_coco.yml new file mode 100755 index 0000000000000000000000000000000000000000..bc067249c725ede6efd07948c99478ad6f1e2f7f --- /dev/null +++ b/configs/dcn/yolov3_r50vd_dcn_iouloss_obj365_pretrained_coco.yml @@ -0,0 +1,75 @@ +architecture: YOLOv3 +use_gpu: true +max_iters: 55000 +log_smooth_window: 20 +save_dir: output +snapshot_iter: 10000 +metric: COCO +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_obj365_pretrained.tar +weights: output/yolov3_r50vd_dcn_iouloss_obj365_pretrained_coco/model_final +num_classes: 80 +use_fine_grained_loss: true + +YOLOv3: + backbone: ResNet + yolo_head: YOLOv3Head + use_fine_grained_loss: true + +ResNet: + norm_type: sync_bn + freeze_at: 0 + freeze_norm: false + norm_decay: 0. + depth: 50 + feature_maps: [3, 4, 5] + variant: d + dcn_v2_stages: [5] + +YOLOv3Head: + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchors: [[10, 13], [16, 30], [33, 23], + [30, 61], [62, 45], [59, 119], + [116, 90], [156, 198], [373, 326]] + norm_decay: 0. + yolo_loss: YOLOv3Loss + nms: + background_label: -1 + keep_top_k: 100 + nms_threshold: 0.45 + nms_top_k: 1000 + normalized: false + score_threshold: 0.01 + +YOLOv3Loss: + batch_size: 8 + ignore_thresh: 0.7 + label_smooth: false + use_fine_grained_loss: true + iou_loss: IouLoss + +IouLoss: + loss_weight: 2.5 + max_height: 608 + max_width: 608 + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 40000 + - 50000 + - !LinearWarmup + start_factor: 0. + steps: 4000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +_READER_: '../yolov3_reader.yml' diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py index 02e28ec5a8efe015ae3848ed46be464f14d1f14c..3179d3b994b297bc14864f2305acb2c2fd7038e9 100644 --- a/ppdet/modeling/losses/__init__.py +++ b/ppdet/modeling/losses/__init__.py @@ -18,8 +18,10 @@ from . import yolo_loss from . import smooth_l1_loss from . import giou_loss from . import diou_loss +from . import iou_loss from .yolo_loss import * from .smooth_l1_loss import * from .giou_loss import * from .diou_loss import * +from .iou_loss import * diff --git a/ppdet/modeling/losses/iou_loss.py b/ppdet/modeling/losses/iou_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..253d00b092c32781b4228a8b298d54ebc1c1e7a4 --- /dev/null +++ b/ppdet/modeling/losses/iou_loss.py @@ -0,0 +1,172 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import NumpyArrayInitializer + +from paddle import fluid +from ppdet.core.workspace import register, serializable + +__all__ = ['IouLoss'] + + +@register +@serializable +class IouLoss(object): + """ + iou loss, see https://arxiv.org/abs/1908.03851 + loss = 1.0 - iou * iou + Args: + loss_weight (float): iou loss weight, default is 2.5 + max_height (int): max height of input to support random shape input + max_width (int): max width of input to support random shape input + """ + def __init__(self, + loss_weight=2.5, + max_height=608, + max_width=608): + self._loss_weight = loss_weight + self._MAX_HI = max_height + self._MAX_WI = max_width + + def __call__(self, x, y, w, h, tx, ty, tw, th, + anchors, downsample_ratio, batch_size, eps=1.e-10): + ''' + Args: + x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h + tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h + anchors ([float]): list of anchors for current output layer + downsample_ratio (float): the downsample ratio for current output layer + batch_size (int): training batch size + eps (float): the decimal to prevent the denominator eqaul zero + ''' + x1, y1, x2, y2 = self._bbox_transform(x, y, w, h, anchors, + downsample_ratio, batch_size, False) + x1g, y1g, x2g, y2g = self._bbox_transform(tx, ty, tw, th, + anchors, downsample_ratio, batch_size, True) + + x2 = fluid.layers.elementwise_max(x1, x2) + y2 = fluid.layers.elementwise_max(y1, y2) + + xkis1 = fluid.layers.elementwise_max(x1, x1g) + ykis1 = fluid.layers.elementwise_max(y1, y1g) + xkis2 = fluid.layers.elementwise_min(x2, x2g) + ykis2 = fluid.layers.elementwise_min(y2, y2g) + + xc1 = fluid.layers.elementwise_min(x1, x1g) + yc1 = fluid.layers.elementwise_min(y1, y1g) + xc2 = fluid.layers.elementwise_max(x2, x2g) + yc2 = fluid.layers.elementwise_max(y2, y2g) + + intsctk = (xkis2 - xkis1) * (ykis2 - ykis1) + intsctk = intsctk * fluid.layers.greater_than( + xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1) + unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + eps + iouk = intsctk / unionk + loss_iou = 1. - iouk * iouk + loss_iou = loss_iou * self._loss_weight + + return loss_iou + + def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio, batch_size, is_gt): + grid_x = int(self._MAX_WI / downsample_ratio) + grid_y = int(self._MAX_HI / downsample_ratio) + an_num = len(anchors) // 2 + + shape_fmp = fluid.layers.shape(dcx) + shape_fmp.stop_gradient = True + # generate the grid_w x grid_h center of feature map + idx_i = np.array([[i for i in range(grid_x)]]) + idx_j = np.array([[j for j in range(grid_y)]]).transpose() + gi_np = np.repeat(idx_i, grid_y, axis=0) + gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x]) + gi_np = np.tile(gi_np, reps=[batch_size, an_num, 1, 1]) + gj_np = np.repeat(idx_j, grid_x, axis=1) + gj_np = np.reshape(gj_np, newshape=[1, 1, grid_y, grid_x]) + gj_np = np.tile(gj_np, reps=[batch_size, an_num, 1, 1]) + gi_max = self._create_tensor_from_numpy(gi_np.astype(np.float32)) + gi = fluid.layers.crop(x=gi_max, shape=dcx) + gi.stop_gradient = True + gj_max = self._create_tensor_from_numpy(gj_np.astype(np.float32)) + gj = fluid.layers.crop(x=gj_max, shape=dcx) + gj.stop_gradient = True + + grid_x_act = fluid.layers.cast(shape_fmp[3], dtype="float32") + grid_x_act.stop_gradient = True + grid_y_act = fluid.layers.cast(shape_fmp[2], dtype="float32") + grid_y_act.stop_gradient = True + if is_gt: + cx = fluid.layers.elementwise_add(dcx, gi) / grid_x_act + cx.gradient = True + cy = fluid.layers.elementwise_add(dcy, gj) / grid_y_act + cy.gradient = True + else: + dcx_sig = fluid.layers.sigmoid(dcx) + cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act + dcy_sig = fluid.layers.sigmoid(dcy) + cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act + + anchor_w_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 0] + anchor_w_np = np.array(anchor_w_) + anchor_w_np = np.reshape(anchor_w_np, newshape=[1, an_num, 1, 1]) + anchor_w_np = np.tile(anchor_w_np, reps=[batch_size, 1, grid_y, grid_x]) + anchor_w_max = self._create_tensor_from_numpy(anchor_w_np.astype(np.float32)) + anchor_w = fluid.layers.crop(x=anchor_w_max, shape=dcx) + anchor_w.stop_gradient = True + anchor_h_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 1] + anchor_h_np = np.array(anchor_h_) + anchor_h_np = np.reshape(anchor_h_np, newshape=[1, an_num, 1, 1]) + anchor_h_np = np.tile(anchor_h_np, reps=[batch_size, 1, grid_y, grid_x]) + anchor_h_max = self._create_tensor_from_numpy(anchor_h_np.astype(np.float32)) + anchor_h = fluid.layers.crop(x=anchor_h_max, shape=dcx) + anchor_h.stop_gradient = True + # e^tw e^th + exp_dw = fluid.layers.exp(dw) + exp_dh = fluid.layers.exp(dh) + pw = fluid.layers.elementwise_mul(exp_dw, anchor_w) / \ + (grid_x_act * downsample_ratio) + ph = fluid.layers.elementwise_mul(exp_dh, anchor_h) / \ + (grid_y_act * downsample_ratio) + if is_gt: + exp_dw.stop_gradient = True + exp_dh.stop_gradient = True + pw.stop_gradient = True + ph.stop_gradient = True + + + x1 = cx - 0.5 * pw + y1 = cy - 0.5 * ph + x2 = cx + 0.5 * pw + y2 = cy + 0.5 * ph + if is_gt: + x1.stop_gradient = True + y1.stop_gradient = True + x2.stop_gradient = True + y2.stop_gradient = True + + return x1, y1, x2, y2 + + def _create_tensor_from_numpy(self, numpy_array): + paddle_array = fluid.layers.create_parameter( + attr=ParamAttr(), + shape=numpy_array.shape, + dtype=numpy_array.dtype, + default_initializer=NumpyArrayInitializer(numpy_array)) + paddle_array.stop_gradient = True + return paddle_array + diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py index ad7e766a2fbc503934a395be0c0dec02eb14eed6..a0ffdfcfc6e31d7c0f49b36e0e5ad33628ca2cc0 100644 --- a/ppdet/modeling/losses/yolo_loss.py +++ b/ppdet/modeling/losses/yolo_loss.py @@ -34,17 +34,20 @@ class YOLOv3Loss(object): use_fine_grained_loss (bool): whether use fine grained YOLOv3 loss instead of fluid.layers.yolov3_loss """ + __inject__ = ['iou_loss'] __shared__ = ['use_fine_grained_loss'] def __init__(self, batch_size=8, ignore_thresh=0.7, label_smooth=True, - use_fine_grained_loss=False): + use_fine_grained_loss=False, + iou_loss=None): self._batch_size = batch_size self._ignore_thresh = ignore_thresh self._label_smooth = label_smooth self._use_fine_grained_loss = use_fine_grained_loss + self._iou_loss = iou_loss def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors, anchor_masks, mask_anchors, num_classes, prefix_name): @@ -104,7 +107,10 @@ class YOLOv3Loss(object): "YOLOv3 output layer number not equal target number" downsample = 32 - loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], [] + if self._iou_loss is None: + loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], [] + else: + loss_xys, loss_whs, loss_ious, loss_objs, loss_clss = [], [], [], [], [] for i, (output, target, anchors) in enumerate(zip(outputs, targets, mask_anchors)): an_num = len(anchors) // 2 @@ -124,6 +130,12 @@ class YOLOv3Loss(object): loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3]) loss_h = fluid.layers.abs(h - th) * tscale_tobj loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3]) + if self._iou_loss is not None: + loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, + anchors, downsample, self._batch_size) + loss_iou = loss_iou * tscale_tobj + loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3]) + loss_ious.append(fluid.layers.reduce_mean(loss_iou)) loss_obj_pos, loss_obj_neg = self._calc_obj_loss( output, obj, tobj, gt_box, self._batch_size, anchors, @@ -140,13 +152,15 @@ class YOLOv3Loss(object): loss_clss.append(fluid.layers.reduce_mean(loss_cls)) downsample //= 2 - - return { + losses_all = { "loss_xy": fluid.layers.sum(loss_xys), "loss_wh": fluid.layers.sum(loss_whs), "loss_obj": fluid.layers.sum(loss_objs), "loss_cls": fluid.layers.sum(loss_clss), } + if self._iou_loss is not None: + losses_all["loss_iou"] = fluid.layers.sum(loss_ious) + return losses_all def _split_output(self, output, an_num, num_classes): """ diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index dffa1a6706e58d03b29452b7aead86973a3e63eb..1a4b5f4ec66c9e62ba580e093c2a359bcb11bb43 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -155,14 +155,12 @@ def DropBlock(input, block_size, keep_prob, is_test): mask = 1.0 - mask_flag elem_numel = fluid.layers.reduce_prod(input_shape) - elem_numel = fluid.layers.cast(elem_numel, dtype="float32") - elem_numel_tmp = fluid.layers.reshape(elem_numel, [1, 1, 1, 1]) - elem_numel_m = fluid.layers.expand_as(elem_numel_tmp, input) + elem_numel_m = fluid.layers.cast(elem_numel, dtype="float32") + elem_numel_m.stop_gradient = True elem_sum = fluid.layers.reduce_sum(mask) - elem_sum_tmp = fluid.layers.cast(elem_sum, dtype="float32") - elem_sum_tmp = fluid.layers.reshape(elem_sum_tmp, [1, 1, 1, 1]) - elem_sum_m = fluid.layers.expand_as(elem_sum_tmp, input) + elem_sum_m = fluid.layers.cast(elem_sum, dtype="float32") + elem_sum_m.stop_gradient = True output = input * mask * elem_numel_m / elem_sum_m return output