From 6663bce86de453abd2165ebb0f9030b8a68201c8 Mon Sep 17 00:00:00 2001
From: lxastro
Date: Mon, 13 Apr 2020 18:17:24 +0800
Subject: [PATCH] add iou aware module (#432)

* add iou aware module

* iouaware model zoo

* fix copyright
---
 ...dcn_db_iouaware_obj365_pretrained_coco.yml | 84 +++++++++++++++++++
 docs/MODEL_ZOO.md                             |  2 +-
 docs/MODEL_ZOO_cn.md                          |  2 +-
 docs/featured_model/YOLOv3_ENHANCEMENT.md     |  5 +-
 ppdet/modeling/anchor_heads/iou_aware.py      | 83 ++++++++++++++++++
 ppdet/modeling/anchor_heads/yolo_head.py      | 15 +++-
 ppdet/modeling/losses/__init__.py             |  2 +
 ppdet/modeling/losses/iou_aware_loss.py       | 75 +++++++++++++++++
 ppdet/modeling/losses/iou_loss.py             | 28 ++++++-
 ppdet/modeling/losses/yolo_loss.py            | 41 +++++++--
 10 files changed, 323 insertions(+), 14 deletions(-)
 create mode 100755 configs/dcn/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco.yml
 create mode 100644 ppdet/modeling/anchor_heads/iou_aware.py
 create mode 100644 ppdet/modeling/losses/iou_aware_loss.py

diff --git a/configs/dcn/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco.yml b/configs/dcn/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco.yml
new file mode 100755
index 000000000..6177aaac7
--- /dev/null
+++ b/configs/dcn/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco.yml
@@ -0,0 +1,84 @@
+architecture: YOLOv3
+use_gpu: true
+max_iters: 85000
+log_smooth_window: 1
+save_dir: output
+snapshot_iter: 10000
+metric: COCO
+pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar
+weights: output/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco/model_final
+num_classes: 80
+use_fine_grained_loss: true
+
+YOLOv3:
+  backbone: ResNet
+  yolo_head: YOLOv3Head
+  use_fine_grained_loss: true
+
+ResNet:
+  norm_type: sync_bn
+  freeze_at: 0
+  freeze_norm: false
+  norm_decay: 0.
+  depth: 50
+  feature_maps: [3, 4, 5]
+  variant: d
+  dcn_v2_stages: [5]
+
+YOLOv3Head:
+  anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+  anchors: [[10, 13], [16, 30], [33, 23],
+            [30, 61], [62, 45], [59, 119],
+            [116, 90], [156, 198], [373, 326]]
+  norm_decay: 0.
+  iou_aware: true
+  iou_aware_factor: 0.4
+  yolo_loss: YOLOv3Loss
+  nms:
+    background_label: -1
+    keep_top_k: 100
+    nms_threshold: 0.45
+    nms_top_k: 1000
+    normalized: false
+    score_threshold: 0.01
+  drop_block: true
+
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: false
+  use_fine_grained_loss: true
+  iou_loss: IouLoss
+  iou_aware_loss: IouAwareLoss
+
+IouLoss:
+  loss_weight: 2.5
+  max_height: 608
+  max_width: 608
+
+IouAwareLoss:
+  loss_weight: 1.0
+  max_height: 608
+  max_width: 608
+
+LearningRate:
+  base_lr: 0.001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones:
+    - 55000
+    - 75000
+  - !LinearWarmup
+    start_factor: 0.
+    steps: 4000
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0005
+    type: L2
+
+_READER_: 'yolov3_enhance_reader.yml'
diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md
index 5985ce526..c20b090ee 100644
--- a/docs/MODEL_ZOO.md
+++ b/docs/MODEL_ZOO.md
@@ -164,7 +164,7 @@ improved performance mainly by using L1 loss in bounding box width and height re
 random color distortion, random cropping, random expansion, random interpolation method, random flipping. YOLO v3 uses randomly reshaped minibatches in training, so inference can be performed on different image sizes with the same model weights; we provide evaluation results for image sizes 608/416/320 above. Deformable conv is added on stage 5 of the backbone.

-- The YOLO v3 enhanced model improves the precision to 43.2 by introducing deformable conv, DropBlock and IoU loss. See more details in [YOLOv3_ENHANCEMENT](./featured_model/YOLOv3_ENHANCEMENT.md)
+- The YOLO v3 enhanced model improves the precision to 43.6 by introducing deformable conv, DropBlock, IoU loss and IoU aware. See more details in [YOLOv3_ENHANCEMENT](./featured_model/YOLOv3_ENHANCEMENT.md)

 ### RetinaNet
diff --git a/docs/MODEL_ZOO_cn.md b/docs/MODEL_ZOO_cn.md
index 0deaf0ee1..df7ccf9e5 100644
--- a/docs/MODEL_ZOO_cn.md
+++ b/docs/MODEL_ZOO_cn.md
@@ -156,7 +156,7 @@ Paddle provides ImageNet-pretrained backbone models. All pretrained models
 - The table above also lists the accuracy of YOLOv3-DarkNet53 from the original [YOLOv3](https://arxiv.org/abs/1804.02767) paper. Our implementation improves on it mainly through L1 loss for bounding box width and height regression, image mixup, and label smoothing.
 - YOLO v3 is trained for 270 epochs on 8 GPUs with a total batch size of 64. Data augmentation includes mixup, random color distortion, random cropping, random expansion, random interpolation, and random flipping. YOLO v3 randomly reshapes minibatches during training, so the same weights can be evaluated on images of different sizes; we provide results for sizes 608/416/320. Deformable conv is applied at stage 5 of the backbone.
-- The enhanced YOLO v3 model further improves the precision to 43.2 by introducing deformable conv, DropBlock and IoU loss. See [YOLOv3_ENHANCEMENT](./featured_model/YOLOv3_ENHANCEMENT.md) for details.
+- The enhanced YOLO v3 model further improves the precision to 43.6 by introducing deformable conv, DropBlock, IoU loss and IoU aware. See [YOLOv3_ENHANCEMENT](./featured_model/YOLOv3_ENHANCEMENT.md) for details.

 ### RetinaNet
diff --git a/docs/featured_model/YOLOv3_ENHANCEMENT.md b/docs/featured_model/YOLOv3_ENHANCEMENT.md
index 20d746b6d..14ebf7d53 100644
--- a/docs/featured_model/YOLOv3_ENHANCEMENT.md
+++ b/docs/featured_model/YOLOv3_ENHANCEMENT.md
@@ -24,7 +24,9 @@ The PaddleDetection implementation uses [Bag of Freebies for Training Object Det
 4. As a one-stage detector, YOLO v3 naturally trails Faster RCNN, Cascade RCNN and similar architectures in localization accuracy. Adding an [IoU Loss](https://arxiv.org/abs/1908.03851) branch improves BBox localization to some extent and narrows the gap between one-stage and two-stage detectors.

-5. Using a model trained on the [Object365 dataset](https://www.objects365.org/download.html) as the pretrained model for COCO: Object365 contains about 600k images in 365 categories, and pretraining on it instead of COCO further improves YOLOv3's accuracy.
+5. Adding an [IoU Aware](https://arxiv.org/abs/1912.05992) branch, which predicts the IoU between each output BBox and its ground-truth BBox and uses it to correct the score fed to NMS, further improves YOLOv3's prediction quality.
+
+6. Using a model trained on the [Object365 dataset](https://www.objects365.org/download.html) as the pretrained model for COCO: Object365 contains about 600k images in 365 categories, and pretraining on it instead of COCO further improves YOLOv3's accuracy.

 ## Usage

@@ -46,3 +48,4 @@ python tools/train.py -c configs/dcn/yolov3_r50vd_dcn_iouloss_obj365_pretrained_
 | YOLOv3 ResNet50_vd DCN | [Object365 pretrain](https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar) | 42.5 | native: 74.4ms<br>tensorRT-FP32: 35.2ms | [download link](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn_obj365_v2.tar) |
 | YOLOv3 ResNet50_vd DCN DropBlock | [Object365 pretrain](https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar) | 42.8 | native: 74.4ms<br>tensorRT-FP32: 35.2ms | [download link](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn_obj365_dropblock.tar) |
 | YOLOv3 ResNet50_vd DCN DropBlock IoULoss | [Object365 pretrain](https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar) | 43.2 | native: 74.4ms<br>tensorRT-FP32: 35.2ms | [download link](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn_obj365_dropblock_iouloss.tar) |
+| YOLOv3 ResNet50_vd DCN DropBlock IoU-Aware | [Object365 pretrain](https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar) | 43.6 | native: 74.4ms<br>tensorRT-FP32: 35.2ms | [download link](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco.pdparams) |
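
For intuition, the IoU-aware rescoring added in item 5 blends objectness with the predicted IoU as `score = obj^(1 - alpha) * iou^alpha`, where `alpha` is the `iou_aware_factor` (0.4 in the config above). A minimal illustration (editor's sketch, not part of this patch):

```python
def rescore(obj, pred_iou, alpha=0.4):
    # NMS input score: obj^(1 - alpha) * iou^alpha
    return obj ** (1.0 - alpha) * pred_iou ** alpha

# Two boxes with identical objectness: the well-localized one keeps its
# score, the poorly localized one is suppressed before NMS.
print(round(rescore(0.9, 0.9), 3))  # 0.9
print(round(rescore(0.9, 0.3), 3))  # 0.58
```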
diff --git a/ppdet/modeling/anchor_heads/iou_aware.py b/ppdet/modeling/anchor_heads/iou_aware.py
new file mode 100644
index 000000000..9a2c4ee41
--- /dev/null
+++ b/ppdet/modeling/anchor_heads/iou_aware.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+
+
+def _split_ioup(output, an_num, num_classes):
+    """
+    Split the head output into the per-anchor predicted IoU (first `an_num`
+    channels) and the original YOLOv3 output, along the channel dimension.
+    """
+    ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
+    ioup = fluid.layers.sigmoid(ioup)
+
+    oriout = fluid.layers.slice(
+        output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)])
+
+    return (ioup, oriout)
+
+
+def _de_sigmoid(x, eps=1e-7):
+    # inverse of sigmoid: -log(1/x - 1) == logit(x), clipped for stability
+    x = fluid.layers.clip(x, eps, 1 / eps)
+    x = fluid.layers.clip((1 / x - 1.0), eps, 1 / eps)
+    x = -fluid.layers.log(x)
+    return x
+
+
+def _postprocess_output(ioup, output, an_num, num_classes, iou_aware_factor):
+    """
+    Fuse the predicted IoU into each anchor's objectness score.
+    """
+    tensors = []
+    stride = output.shape[1] // an_num
+    for m in range(an_num):
+        # box regression channels (x, y, w, h) pass through unchanged
+        tensors.append(
+            fluid.layers.slice(
+                output,
+                axes=[1],
+                starts=[stride * m + 0],
+                ends=[stride * m + 4]))
+        obj = fluid.layers.slice(
+            output, axes=[1], starts=[stride * m + 4], ends=[stride * m + 5])
+        obj = fluid.layers.sigmoid(obj)
+        ip = fluid.layers.slice(ioup, axes=[1], starts=[m], ends=[m + 1])
+
+        # new_obj = obj^(1 - factor) * iou^factor, mapped back to logit space
+        new_obj = fluid.layers.pow(obj, (
+            1 - iou_aware_factor)) * fluid.layers.pow(ip, iou_aware_factor)
+        new_obj = _de_sigmoid(new_obj)
+
+        tensors.append(new_obj)
+
+        tensors.append(
+            fluid.layers.slice(
+                output,
+                axes=[1],
+                starts=[stride * m + 5],
+                ends=[stride * m + 5 + num_classes]))
+
+    output = fluid.layers.concat(tensors, axis=1)
+
+    return output
+
+
+def get_iou_aware_score(output, an_num, num_classes, iou_aware_factor):
+    ioup, output = _split_ioup(output, an_num, num_classes)
+    output = _postprocess_output(ioup, output, an_num, num_classes,
+                                 iou_aware_factor)
+    return output
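
`_de_sigmoid` above is the logit function: `fluid.layers.yolo_box` applies a sigmoid to the objectness channel when decoding, so the fused score must first be mapped back to logit space. A numpy round-trip check (illustrative sketch, not part of this patch):

```python
import numpy as np

def de_sigmoid(x, eps=1e-7):
    # mirrors _de_sigmoid: -log(1/x - 1) == logit(x), clipped for stability
    x = np.clip(x, eps, 1.0 / eps)
    return -np.log(np.clip(1.0 / x - 1.0, eps, 1.0 / eps))

s = np.array([0.1, 0.5, 0.9])
print(1.0 / (1.0 + np.exp(-de_sigmoid(s))))  # ~[0.1 0.5 0.9]: sigmoid undoes it
```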
yolo_loss="YOLOv3Loss", @@ -68,6 +71,8 @@ class YOLOv3Head(object): self.nms = nms self.prefix_name = weight_prefix_name self.drop_block = drop_block + self.iou_aware = iou_aware + self.iou_aware_factor = iou_aware_factor self.block_size = block_size self.keep_prob = keep_prob if isinstance(nms, dict): @@ -220,7 +225,10 @@ class YOLOv3Head(object): name=self.prefix_name + "yolo_block.{}".format(i)) # out channel number = mask_num * (5 + class_num) - num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) + if self.iou_aware: + num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6) + else: + num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) with fluid.name_scope('yolo_output'): block_out = fluid.layers.conv2d( input=tip, @@ -295,6 +303,11 @@ class YOLOv3Head(object): scores = [] downsample = 32 for i, output in enumerate(outputs): + if self.iou_aware: + output = get_iou_aware_score(output, + len(self.anchor_masks[i]), + self.num_classes, + self.iou_aware_factor) box, score = fluid.layers.yolo_box( x=output, img_size=im_size, diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py index 88f44c90f..d1b4dcca2 100644 --- a/ppdet/modeling/losses/__init__.py +++ b/ppdet/modeling/losses/__init__.py @@ -22,7 +22,9 @@ from . import iou_loss from . import balanced_l1_loss from . import fcos_loss from . import diou_loss_yolo +from . import iou_aware_loss +from .iou_aware_loss import * from .yolo_loss import * from .smooth_l1_loss import * from .giou_loss import * diff --git a/ppdet/modeling/losses/iou_aware_loss.py b/ppdet/modeling/losses/iou_aware_loss.py new file mode 100644 index 000000000..d4677eac3 --- /dev/null +++ b/ppdet/modeling/losses/iou_aware_loss.py @@ -0,0 +1,75 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/ppdet/modeling/losses/iou_aware_loss.py b/ppdet/modeling/losses/iou_aware_loss.py
new file mode 100644
index 000000000..d4677eac3
--- /dev/null
+++ b/ppdet/modeling/losses/iou_aware_loss.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+from ppdet.core.workspace import register, serializable
+from .iou_loss import IouLoss
+
+__all__ = ['IouAwareLoss']
+
+
+@register
+@serializable
+class IouAwareLoss(IouLoss):
+    """
+    IoU-aware loss, see https://arxiv.org/abs/1912.05992
+    Args:
+        loss_weight (float): iou aware loss weight, default is 1.0
+        max_height (int): max height of input to support random shape input
+        max_width (int): max width of input to support random shape input
+    """
+
+    def __init__(self, loss_weight=1.0, max_height=608, max_width=608):
+        super(IouAwareLoss, self).__init__(
+            loss_weight=loss_weight, max_height=max_height, max_width=max_width)
+
+    def __call__(self,
+                 ioup,
+                 x,
+                 y,
+                 w,
+                 h,
+                 tx,
+                 ty,
+                 tw,
+                 th,
+                 anchors,
+                 downsample_ratio,
+                 batch_size,
+                 eps=1.e-10):
+        '''
+        Args:
+            ioup ([Variables]): the predicted IoU
+            x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
+            tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
+            anchors ([float]): list of anchors for current output layer
+            downsample_ratio (float): the downsample ratio for current output layer
+            batch_size (int): training batch size
+            eps (float): small value to keep the denominator away from zero
+        '''
+
+        # IoU of the decoded prediction vs. ground truth, used as a fixed
+        # soft label: gradients flow only into the IoU-prediction branch.
+        iouk = self._iou(x, y, w, h, tx, ty, tw, th, anchors, downsample_ratio,
+                         batch_size, ioup, eps)
+        iouk.stop_gradient = True
+
+        loss_iou_aware = fluid.layers.cross_entropy(ioup, iouk, soft_label=True)
+        loss_iou_aware = loss_iou_aware * self._loss_weight
+        return loss_iou_aware
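
`IouLoss` is then refactored so the IoU computation lives in a standalone `_iou` helper that `IouAwareLoss` reuses, while `IouLoss.__call__` keeps returning `(1 - iou^2) * loss_weight`. For reference, the IoU of two corner-form boxes (pure-Python sketch, illustrative only; variable names follow the Paddle code):

```python
def iou(box, gt, eps=1e-10):
    x1, y1, x2, y2 = box
    x1g, y1g, x2g, y2g = gt
    xkis1, ykis1 = max(x1, x1g), max(y1, y1g)  # intersection corners
    xkis2, ykis2 = min(x2, x2g), min(y2, y2g)
    intsctk = max(0.0, xkis2 - xkis1) * max(0.0, ykis2 - ykis1)
    unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + eps
    return intsctk / unionk

print(iou((0, 0, 2, 2), (1, 1, 3, 3)))  # 1/7 ~= 0.1429
```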
diff --git a/ppdet/modeling/losses/iou_loss.py b/ppdet/modeling/losses/iou_loss.py
index 498ae607a..a61f101d7 100644
--- a/ppdet/modeling/losses/iou_loss.py
+++ b/ppdet/modeling/losses/iou_loss.py
@@ -54,6 +54,7 @@ class IouLoss(object):
                  anchors,
                  downsample_ratio,
                  batch_size,
+                 ioup=None,
                  eps=1.e-10):
         '''
         Args:
@@ -64,6 +65,28 @@ class IouLoss(object):
             batch_size (int): training batch size
             eps (float): small value to keep the denominator away from zero
         '''
+
+        iouk = self._iou(x, y, w, h, tx, ty, tw, th, anchors, downsample_ratio,
+                         batch_size, ioup, eps)
+        loss_iou = 1. - iouk * iouk
+        loss_iou = loss_iou * self._loss_weight
+
+        return loss_iou
+
+    def _iou(self,
+             x,
+             y,
+             w,
+             h,
+             tx,
+             ty,
+             tw,
+             th,
+             anchors,
+             downsample_ratio,
+             batch_size,
+             ioup=None,
+             eps=1.e-10):
         x1, y1, x2, y2 = self._bbox_transform(
             x, y, w, h, anchors, downsample_ratio, batch_size, False)
         x1g, y1g, x2g, y2g = self._bbox_transform(
@@ -83,10 +106,7 @@ class IouLoss(object):
         unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
                                                         ) - intsctk + eps
         iouk = intsctk / unionk
-        loss_iou = 1. - iouk * iouk
-        loss_iou = loss_iou * self._loss_weight
-
-        return loss_iou
+        return iouk

     def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio,
                         batch_size, is_gt):
diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py
index 6b6d51a51..6777608f7 100644
--- a/ppdet/modeling/losses/yolo_loss.py
+++ b/ppdet/modeling/losses/yolo_loss.py
@@ -34,7 +34,7 @@ class YOLOv3Loss(object):
         use_fine_grained_loss (bool): whether to use fine-grained YOLOv3 loss
                                       instead of fluid.layers.yolov3_loss
     """
-    __inject__ = ['iou_loss']
+    __inject__ = ['iou_loss', 'iou_aware_loss']
     __shared__ = ['use_fine_grained_loss']

     def __init__(self,
@@ -42,12 +42,14 @@ class YOLOv3Loss(object):
                  ignore_thresh=0.7,
                  label_smooth=True,
                  use_fine_grained_loss=False,
-                 iou_loss=None):
+                 iou_loss=None,
+                 iou_aware_loss=None):
         self._batch_size = batch_size
         self._ignore_thresh = ignore_thresh
         self._label_smooth = label_smooth
         self._use_fine_grained_loss = use_fine_grained_loss
         self._iou_loss = iou_loss
+        self._iou_aware_loss = iou_aware_loss

     def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
                  anchor_masks, mask_anchors, num_classes, prefix_name):
@@ -107,13 +109,16 @@ class YOLOv3Loss(object):
             "YOLOv3 output layer number not equal target number"
         downsample = 32

-        if self._iou_loss is None:
-            loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
-        else:
-            loss_xys, loss_whs, loss_ious, loss_objs, loss_clss = [], [], [], [], []
+        loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
+        if self._iou_loss is not None:
+            loss_ious = []
+        if self._iou_aware_loss is not None:
+            loss_iou_awares = []
         for i, (output, target,
                 anchors) in enumerate(zip(outputs, targets, mask_anchors)):
             an_num = len(anchors) // 2
+            if self._iou_aware_loss is not None:
+                ioup, output = self._split_ioup(output, an_num, num_classes)
             x, y, w, h, obj, cls = self._split_output(output, an_num,
                                                       num_classes)
             tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)
@@ -137,6 +142,15 @@ class YOLOv3Loss(object):
                 loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
                 loss_ious.append(fluid.layers.reduce_mean(loss_iou))

+            if self._iou_aware_loss is not None:
+                loss_iou_aware = self._iou_aware_loss(
+                    ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
+                    self._batch_size)
+                loss_iou_aware = loss_iou_aware * tobj
+                loss_iou_aware = fluid.layers.reduce_sum(
+                    loss_iou_aware, dim=[1, 2, 3])
+                loss_iou_awares.append(fluid.layers.reduce_mean(loss_iou_aware))
+
             loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
                 output, obj, tobj, gt_box, self._batch_size, anchors,
                 num_classes, downsample, self._ignore_thresh)
@@ -160,8 +174,24 @@ class YOLOv3Loss(object):
         }
         if self._iou_loss is not None:
             losses_all["loss_iou"] = fluid.layers.sum(loss_ious)
+        if self._iou_aware_loss is not None:
+            losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares)
         return losses_all

+    def _split_ioup(self, output, an_num, num_classes):
+        """
+        Split the head output into the per-anchor predicted IoU and the
+        remaining feature map, along the channel dimension.
+        """
+        ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
+        ioup = fluid.layers.sigmoid(ioup)
+        oriout = fluid.layers.slice(
+            output,
+            axes=[1],
+            starts=[an_num],
+            ends=[an_num * (num_classes + 6)])
+        return (ioup, oriout)
+
     def _split_output(self, output, an_num, num_classes):
         """
         Split output feature map to x, y, w, h, objectness, classification
--
GitLab
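
Shape-wise, the inference path added by this patch splits the `an_num` predicted-IoU channels off the head output, fuses them into each anchor's objectness logit, and hands `fluid.layers.yolo_box` a tensor with the usual `an_num * (num_classes + 5)` channels. A self-contained numpy mock of `get_iou_aware_score` (illustrative sketch, not the Paddle implementation):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def de_sigmoid(x, eps=1e-7):
    return -np.log(np.clip(1.0 / np.clip(x, eps, 1.0 / eps) - 1.0, eps, 1.0 / eps))

def iou_aware_score(output, an_num, num_classes, factor=0.4):
    # output: [N, an_num * (num_classes + 6), H, W]; the first an_num
    # channels hold the predicted IoU, the rest is the usual YOLOv3 layout.
    ioup = sigmoid(output[:, :an_num])
    rest = output[:, an_num:].copy()
    stride = num_classes + 5
    for m in range(an_num):
        obj = sigmoid(rest[:, m * stride + 4:m * stride + 5])
        fused = obj ** (1 - factor) * ioup[:, m:m + 1] ** factor
        rest[:, m * stride + 4:m * stride + 5] = de_sigmoid(fused)
    return rest

out = np.random.randn(2, 3 * 86, 19, 19)  # 3 anchors, 80 classes -> 258 channels
print(iou_aware_score(out, an_num=3, num_classes=80).shape)  # (2, 255, 19, 19)
```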