From 2cda4b289b7710937f4056c06b2868600c8154fd Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Tue, 4 Feb 2020 14:42:45 +0800 Subject: [PATCH] add libra rcnn (#198) add libra rcnn, including * libra loss * libra fpn(bfp) * libra sampling --- .gitignore | 1 + configs/libra_rcnn/README.md | 23 + .../libra_rcnn/libra_rcnn_r101_vd_fpn_1x.yml | 117 +++++ .../libra_rcnn/libra_rcnn_r50_vd_fpn_1x.yml | 117 +++++ ppdet/modeling/backbones/__init__.py | 4 +- ppdet/modeling/backbones/bfp.py | 156 ++++++ ppdet/modeling/backbones/nonlocal_helper.py | 201 ++++---- ppdet/modeling/losses/__init__.py | 2 + ppdet/modeling/losses/balanced_l1_loss.py | 73 +++ ppdet/modeling/ops.py | 457 +++++++++++++++++- ppdet/utils/bbox_utils.py | 83 ++++ tools/eval.py | 4 +- 12 files changed, 1114 insertions(+), 124 deletions(-) create mode 100644 configs/libra_rcnn/README.md create mode 100644 configs/libra_rcnn/libra_rcnn_r101_vd_fpn_1x.yml create mode 100644 configs/libra_rcnn/libra_rcnn_r50_vd_fpn_1x.yml create mode 100644 ppdet/modeling/backbones/bfp.py create mode 100644 ppdet/modeling/losses/balanced_l1_loss.py create mode 100644 ppdet/utils/bbox_utils.py diff --git a/.gitignore b/.gitignore index 99c138402..e3187193c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ # Byte-compiled / optimized / DLL files __pycache__/ +.ipynb_checkpoints/ *.py[cod] # C extensions diff --git a/configs/libra_rcnn/README.md b/configs/libra_rcnn/README.md new file mode 100644 index 000000000..63f3b7bd9 --- /dev/null +++ b/configs/libra_rcnn/README.md @@ -0,0 +1,23 @@ +# Libra R-CNN: Towards Balanced Learning for Object Detection + +## Introduction + +- Libra R-CNN: Towards Balanced Learning for Object Detection +: [https://arxiv.org/abs/1904.02701](https://arxiv.org/abs/1904.02701) + +``` +@inproceedings{pang2019libra, + title={Libra R-CNN: Towards Balanced Learning for Object Detection}, + author={Pang, Jiangmiao and Chen, Kai and Shi, Jianping and Feng, Huajun and Ouyang, Wanli and Dahua Lin}, + 
booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + year={2019} +} +``` + + +## Model Zoo + +| Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | +| :---------------------- | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | +| ResNet50-vd-BFP | Faster | 2 | 1x | 18.247 | 40.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/libra_rcnn_r50_vd_fpn_1x.tar) | +| ResNet101-vd-BFP | Faster | 2 | 1x | 14.865 | 42.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/libra_rcnn_r101_vd_fpn_1x.tar) | diff --git a/configs/libra_rcnn/libra_rcnn_r101_vd_fpn_1x.yml b/configs/libra_rcnn/libra_rcnn_r101_vd_fpn_1x.yml new file mode 100644 index 000000000..1c66935cb --- /dev/null +++ b/configs/libra_rcnn/libra_rcnn_r101_vd_fpn_1x.yml @@ -0,0 +1,117 @@ +architecture: FasterRCNN +max_iters: 90000 +snapshot_iter: 10000 +use_gpu: true +log_smooth_window: 20 +save_dir: output +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar +weights: output/libra_rcnn_r101_vd_fpn_1x/model_final +metric: COCO +num_classes: 81 + +FasterRCNN: + backbone: ResNet + fpn: BFP + rpn_head: FPNRPNHead + roi_extractor: FPNRoIAlign + bbox_head: BBoxHead + bbox_assigner: LibraBBoxAssigner + +ResNet: + depth: 101 + feature_maps: [2, 3, 4, 5] + freeze_at: 2 + norm_type: bn + variant: d + +BFP: + base_neck: + max_level: 6 + min_level: 2 + num_chan: 256 + spatial_scale: [0.03125, 0.0625, 0.125, 0.25] + refine_level: 2 + refine_type: nonlocal + nonlocal_reduction: 1.0 + +FPNRPNHead: + anchor_generator: + anchor_sizes: [32, 64, 128, 256, 512] + aspect_ratios: [0.5, 1.0, 2.0] + stride: [16.0, 16.0] + variance: [1.0, 1.0, 1.0, 1.0] + anchor_start_size: 32 + max_level: 6 + min_level: 2 + num_chan: 256 + rpn_target_assign: + rpn_batch_size_per_im: 256 + rpn_fg_fraction: 0.5 + rpn_negative_overlap: 0.3 + 
rpn_positive_overlap: 0.7 + rpn_straddle_thresh: 0.0 + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + post_nms_top_n: 2000 + pre_nms_top_n: 2000 + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + post_nms_top_n: 1000 + pre_nms_top_n: 1000 + +FPNRoIAlign: + canconical_level: 4 + canonical_size: 224 + max_level: 5 + min_level: 2 + box_resolution: 7 + sampling_ratio: 2 + +LibraBBoxAssigner: + batch_size_per_im: 512 + bbox_reg_weights: [0.1, 0.1, 0.2, 0.2] + bg_thresh_hi: 0.5 + bg_thresh_lo: 0.0 + fg_fraction: 0.25 + fg_thresh: 0.5 + +BBoxHead: + head: TwoFCHead + nms: + keep_top_k: 100 + nms_threshold: 0.5 + score_threshold: 0.05 + bbox_loss: BalancedL1Loss + +BalancedL1Loss: + alpha: 0.5 + gamma: 1.5 + beta: 1.0 + loss_weight: 1.0 + +TwoFCHead: + mlp_dim: 1024 + +LearningRate: + base_lr: 0.02 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [60000, 80000] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 + +_READER_: '../faster_fpn_reader.yml' +TrainReader: + batch_size: 2 diff --git a/configs/libra_rcnn/libra_rcnn_r50_vd_fpn_1x.yml b/configs/libra_rcnn/libra_rcnn_r50_vd_fpn_1x.yml new file mode 100644 index 000000000..e28654cfb --- /dev/null +++ b/configs/libra_rcnn/libra_rcnn_r50_vd_fpn_1x.yml @@ -0,0 +1,117 @@ +architecture: FasterRCNN +max_iters: 90000 +snapshot_iter: 10000 +use_gpu: true +log_smooth_window: 20 +save_dir: output +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar +weights: output/libra_rcnn_r50_vd_fpn_1x/model_final +metric: COCO +num_classes: 81 + +FasterRCNN: + backbone: ResNet + fpn: BFP + rpn_head: FPNRPNHead + roi_extractor: FPNRoIAlign + bbox_head: BBoxHead + bbox_assigner: LibraBBoxAssigner + +ResNet: + depth: 50 + feature_maps: [2, 3, 4, 5] + freeze_at: 2 + norm_type: bn + variant: d + +BFP: + base_neck: + max_level: 6 + min_level: 2 + num_chan: 256 + spatial_scale: 
[0.03125, 0.0625, 0.125, 0.25] + refine_level: 2 + refine_type: nonlocal + nonlocal_reduction: 1.0 + +FPNRPNHead: + anchor_generator: + anchor_sizes: [32, 64, 128, 256, 512] + aspect_ratios: [0.5, 1.0, 2.0] + stride: [16.0, 16.0] + variance: [1.0, 1.0, 1.0, 1.0] + anchor_start_size: 32 + max_level: 6 + min_level: 2 + num_chan: 256 + rpn_target_assign: + rpn_batch_size_per_im: 256 + rpn_fg_fraction: 0.5 + rpn_negative_overlap: 0.3 + rpn_positive_overlap: 0.7 + rpn_straddle_thresh: 0.0 + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + post_nms_top_n: 2000 + pre_nms_top_n: 2000 + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + post_nms_top_n: 1000 + pre_nms_top_n: 1000 + +FPNRoIAlign: + canconical_level: 4 + canonical_size: 224 + max_level: 5 + min_level: 2 + box_resolution: 7 + sampling_ratio: 2 + +LibraBBoxAssigner: + batch_size_per_im: 512 + bbox_reg_weights: [0.1, 0.1, 0.2, 0.2] + bg_thresh_hi: 0.5 + bg_thresh_lo: 0.0 + fg_fraction: 0.25 + fg_thresh: 0.5 + +BBoxHead: + head: TwoFCHead + nms: + keep_top_k: 100 + nms_threshold: 0.5 + score_threshold: 0.05 + bbox_loss: BalancedL1Loss + +BalancedL1Loss: + alpha: 0.5 + gamma: 1.5 + beta: 1.0 + loss_weight: 1.0 + +TwoFCHead: + mlp_dim: 1024 + +LearningRate: + base_lr: 0.02 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [60000, 80000] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 + +_READER_: '../faster_fpn_reader.yml' +TrainReader: + batch_size: 2 diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index 22f78004b..7cd17ec09 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -27,6 +27,7 @@ from . import cb_resnet from . import res2net from . import hrnet from . import hrfpn +from . 
import bfp from .resnet import * from .resnext import * @@ -40,4 +41,5 @@ from .faceboxnet import * from .cb_resnet import * from .res2net import * from .hrnet import * -from .hrfpn import * \ No newline at end of file +from .hrfpn import * +from .bfp import * diff --git a/ppdet/modeling/backbones/bfp.py b/ppdet/modeling/backbones/bfp.py new file mode 100644 index 000000000..1dc03de69 --- /dev/null +++ b/ppdet/modeling/backbones/bfp.py @@ -0,0 +1,156 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from collections import OrderedDict + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Xavier +from paddle.fluid.regularizer import L2Decay + +from ppdet.core.workspace import register + +from .nonlocal_helper import add_space_nonlocal +from .fpn import FPN + +__all__ = ['BFP'] + + +@register +class BFP(object): + """ + Libra R-CNN, see https://arxiv.org/abs/1904.02701 + Args: + base_neck (dict): basic neck before balanced feature pyramid (bfp) + refine_level (int): index of integration and refine level of bfp + refine_type (str): refine type, None, conv or nonlocal + nonlocal_reduction (float): channel reduction level if refine_type is nonlocal + with_bias (bool): whether the nonlocal module contains bias + with_scale (bool): whether to scale feature in nonlocal module or not + """ + __inject__ = ['base_neck'] + + def __init__(self, + base_neck=FPN().__dict__, + refine_level=2, + refine_type="nonlocal", + nonlocal_reduction=1, + with_bias=True, + with_scale=False): + if isinstance(base_neck, dict): + self.base_neck = FPN(**base_neck) + self.refine_level = refine_level + self.refine_type = refine_type + self.nonlocal_reduction = nonlocal_reduction + self.with_bias = with_bias + self.with_scale = with_scale + + def get_output(self, body_dict): + # top-down order + res_dict, spatial_scale = self.base_neck.get_output(body_dict) + res_dict = self.get_output_bfp(res_dict) + return res_dict, spatial_scale + + def get_output_bfp(self, body_dict): + body_name_list = list(body_dict.keys()) + num_backbone_stages = len(body_name_list) + + self.num_levels = len(body_dict) + + # step 1: gather multi-level features by resize and average + feats = [] + refine_level_name = body_name_list[self.refine_level] + + for i in range(self.num_levels): + curr_fpn_name = body_name_list[i] + pool_stride 
= 2**(i - self.refine_level) + pool_size = [ + body_dict[refine_level_name].shape[2], + body_dict[refine_level_name].shape[3] + ] + if i > self.refine_level: + gathered = fluid.layers.pool2d( + input=body_dict[curr_fpn_name], + pool_type='max', + pool_size=pool_stride, + pool_stride=pool_stride, + ceil_mode=True, ) + else: + gathered = self._resize_input_tensor( + body_dict[curr_fpn_name], body_dict[refine_level_name], + 1.0 / pool_stride) + feats.append(gathered) + + bsf = sum(feats) / len(feats) + + # step 2: refine gathered features + if self.refine_type == "conv": + bsf = fluid.layers.conv2d( + bsf, + bsf.shape[1], + filter_size=3, + padding=1, + param_attr=ParamAttr(name="bsf_w"), + bias_attr=ParamAttr(name="bsf_b"), + name="bsf") + elif self.refine_type == "nonlocal": + dim_in = bsf.shape[1] + nonlocal_name = "nonlocal_bsf" + bsf = add_space_nonlocal( + bsf, + bsf.shape[1], + bsf.shape[1], + nonlocal_name, + int(bsf.shape[1] / self.nonlocal_reduction), + with_bias=self.with_bias, + with_scale=self.with_scale) + + # step 3: scatter refined features to multi-levels by a residual path + fpn_dict = {} + fpn_name_list = [] + for i in range(self.num_levels): + curr_fpn_name = body_name_list[i] + pool_stride = 2**(self.refine_level - i) + if i >= self.refine_level: + residual = self._resize_input_tensor( + bsf, body_dict[curr_fpn_name], 1.0 / pool_stride) + else: + residual = fluid.layers.pool2d( + input=bsf, + pool_type='max', + pool_size=pool_stride, + pool_stride=pool_stride, + ceil_mode=True, ) + + fpn_dict[curr_fpn_name] = residual + body_dict[curr_fpn_name] + fpn_name_list.append(curr_fpn_name) + + res_dict = OrderedDict([(k, fpn_dict[k]) for k in fpn_name_list]) + return res_dict + + def _resize_input_tensor(self, body_input, ref_output, scale): + shape = fluid.layers.shape(ref_output) + shape_hw = fluid.layers.slice(shape, axes=[0], starts=[2], ends=[4]) + out_shape_ = shape_hw + out_shape = fluid.layers.cast(out_shape_, dtype='int32') + 
out_shape.stop_gradient = True + body_output = fluid.layers.resize_nearest( + body_input, scale=scale, actual_shape=out_shape) + return body_output diff --git a/ppdet/modeling/backbones/nonlocal_helper.py b/ppdet/modeling/backbones/nonlocal_helper.py index 9953fb31f..d33ae61bb 100644 --- a/ppdet/modeling/backbones/nonlocal_helper.py +++ b/ppdet/modeling/backbones/nonlocal_helper.py @@ -5,95 +5,72 @@ from __future__ import unicode_literals import paddle.fluid as fluid from paddle.fluid import ParamAttr +from paddle.fluid.initializer import ConstantInitializer - -nonlocal_params = { - "use_zero_init_conv" : False, - "conv_init_std" : 0.01, - "no_bias" : True, - "use_maxpool" : False, - "use_softmax" : True, - "use_bn" : False, - "use_scale" : True, # vital for the model prformance!!! - "use_affine" : False, - "bn_momentum" : 0.9, - "bn_epsilon" : 1.0000001e-5, - "bn_init_gamma" : 0.9, - "weight_decay_bn":1.e-4, - -} - -def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride = 2): - cur = input - theta = fluid.layers.conv2d(input = cur, num_filters = dim_inner, \ - filter_size = [1, 1], stride = [1, 1], \ - padding = [0, 0], \ - param_attr=ParamAttr(name = prefix + '_theta' + "_w", \ - initializer = fluid.initializer.Normal(loc = 0.0, - scale = nonlocal_params["conv_init_std"])), \ - bias_attr = ParamAttr(name = prefix + '_theta' + "_b", \ - initializer = fluid.initializer.Constant(value = 0.)) \ - if not nonlocal_params["no_bias"] else False, \ - name = prefix + '_theta') +def space_nonlocal(input, + dim_in, + dim_out, + prefix, + dim_inner, + with_bias=False, + with_scale=True): + theta = fluid.layers.conv2d( + input=input, + num_filters=dim_inner, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr(name=prefix + '_theta_w'), + bias_attr=ParamAttr( + name=prefix + '_theta_b', initializer=ConstantInitializer(value=0.)) + if with_bias else False) theta_shape = theta.shape - theta_shape_op = fluid.layers.shape( theta ) + 
theta_shape_op = fluid.layers.shape(theta) theta_shape_op.stop_gradient = True - - if nonlocal_params["use_maxpool"]: - max_pool = fluid.layers.pool2d(input = cur, \ - pool_size = [max_pool_stride, max_pool_stride], \ - pool_type = 'max', \ - pool_stride = [max_pool_stride, max_pool_stride], \ - pool_padding = [0, 0], \ - name = prefix + '_pool') - else: - max_pool = cur - - phi = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \ - filter_size = [1, 1], stride = [1, 1], \ - padding = [0, 0], \ - param_attr = ParamAttr(name = prefix + '_phi' + "_w", \ - initializer = fluid.initializer.Normal(loc = 0.0, - scale = nonlocal_params["conv_init_std"])), \ - bias_attr = ParamAttr(name = prefix + '_phi' + "_b", \ - initializer = fluid.initializer.Constant(value = 0.)) \ - if (nonlocal_params["no_bias"] == 0) else False, \ - name = prefix + '_phi') - phi_shape = phi.shape - - g = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \ - filter_size = [1, 1], stride = [1, 1], \ - padding = [0, 0], \ - param_attr = ParamAttr(name = prefix + '_g' + "_w", \ - initializer = fluid.initializer.Normal(loc = 0.0, scale = nonlocal_params["conv_init_std"])), \ - bias_attr = ParamAttr(name = prefix + '_g' + "_b", \ - initializer = fluid.initializer.Constant(value = 0.)) if (nonlocal_params["no_bias"] == 0) else False, \ - name = prefix + '_g') - g_shape = g.shape + # we have to use explicit batch size (to support arbitrary spacetime size) # e.g. 
(8, 1024, 4, 14, 14) => (8, 1024, 784) - theta = fluid.layers.reshape(theta, shape=(0, 0, -1) ) + theta = fluid.layers.reshape(theta, shape=(0, 0, -1)) theta = fluid.layers.transpose(theta, [0, 2, 1]) + + phi = fluid.layers.conv2d( + input=input, + num_filters=dim_inner, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr(name=prefix + '_phi_w'), + bias_attr=ParamAttr( + name=prefix + '_phi_b', initializer=ConstantInitializer(value=0.)) + if with_bias else False, + name=prefix + '_phi') phi = fluid.layers.reshape(phi, [0, 0, -1]) - theta_phi = fluid.layers.matmul(theta, phi, name = prefix + '_affinity') + + theta_phi = fluid.layers.matmul(theta, phi) + + g = fluid.layers.conv2d( + input=input, + num_filters=dim_inner, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr(name=prefix + '_g_w'), + bias_attr=ParamAttr( + name=prefix + '_g_b', initializer=ConstantInitializer(value=0.)) + if with_bias else False, + name=prefix + '_g') g = fluid.layers.reshape(g, [0, 0, -1]) - - if nonlocal_params["use_softmax"]: - if nonlocal_params["use_scale"]: - theta_phi_sc = fluid.layers.scale(theta_phi, scale = dim_inner**-.5) - else: - theta_phi_sc = theta_phi - p = fluid.layers.softmax(theta_phi_sc, name = prefix + '_affinity' + '_prob') - else: - # not clear about what is doing in xlw's code - p = None # not implemented - raise "Not implemented when not use softmax" + + # scale + if with_scale: + theta_phi = fluid.layers.scale(theta_phi, scale=dim_inner**-.5) + p = fluid.layers.softmax(theta_phi) # note g's axis[2] corresponds to p's axis[2] # e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1) p = fluid.layers.transpose(p, [0, 2, 1]) - t = fluid.layers.matmul(g, p, name = prefix + '_y') + t = fluid.layers.matmul(g, p) # reshape back # e.g. 
(8, 1024, 784) => (8, 1024, 4, 14, 14) @@ -104,56 +81,40 @@ def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride = t_re = fluid.layers.reshape(t, shape=[n, ch, h, w]) blob_out = t_re - blob_out = fluid.layers.conv2d(input = blob_out, num_filters = dim_out, \ - filter_size = [1, 1], stride = [1, 1], padding = [0, 0], \ - param_attr = ParamAttr(name = prefix + '_out' + "_w", \ - initializer = fluid.initializer.Constant(value = 0.) \ - if nonlocal_params["use_zero_init_conv"] \ - else fluid.initializer.Normal(loc = 0.0, - scale = nonlocal_params["conv_init_std"])), \ - bias_attr = ParamAttr(name = prefix + '_out' + "_b", \ - initializer = fluid.initializer.Constant(value = 0.)) \ - if (nonlocal_params["no_bias"] == 0) else False, \ - name = prefix + '_out') + blob_out = fluid.layers.conv2d( + input=blob_out, + num_filters=dim_out, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr( + name=prefix + '_out_w', initializer=ConstantInitializer(value=0.0)), + bias_attr=ParamAttr( + name=prefix + '_out_b', initializer=ConstantInitializer(value=0.0)) + if with_bias else False, + name=prefix + '_out') blob_out_shape = blob_out.shape - - - if nonlocal_params["use_bn"]: - bn_name = prefix + "_bn" - blob_out = fluid.layers.batch_norm(blob_out, \ - # is_test = test_mode, \ - momentum = nonlocal_params["bn_momentum"], \ - epsilon = nonlocal_params["bn_epsilon"], \ - name = bn_name, \ - param_attr = ParamAttr(name = bn_name + "_s", \ - initializer = fluid.initializer.Constant(value = nonlocal_params["bn_init_gamma"]), \ - regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \ - bias_attr = ParamAttr(name = bn_name + "_b", \ - regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \ - moving_mean_name = bn_name + "_rm", \ - moving_variance_name = bn_name + "_riv") # add bn - - if nonlocal_params["use_affine"]: - affine_scale = fluid.layers.create_parameter(\ - shape=[blob_out_shape[1]], dtype = 
blob_out.dtype, \ - attr=ParamAttr(name=prefix + '_affine' + '_s'), \ - default_initializer = fluid.initializer.Constant(value = 1.)) - affine_bias = fluid.layers.create_parameter(\ - shape=[blob_out_shape[1]], dtype = blob_out.dtype, \ - attr=ParamAttr(name=prefix + '_affine' + '_b'), \ - default_initializer = fluid.initializer.Constant(value = 0.)) - blob_out = fluid.layers.affine_channel(blob_out, scale = affine_scale, \ - bias = affine_bias, name = prefix + '_affine') # add affine - return blob_out -def add_space_nonlocal(input, dim_in, dim_out, prefix, dim_inner ): +def add_space_nonlocal(input, + dim_in, + dim_out, + prefix, + dim_inner, + with_bias=False, + with_scale=True): ''' add_space_nonlocal: Non-local Neural Networks: see https://arxiv.org/abs/1711.07971 ''' - conv = space_nonlocal(input, dim_in, dim_out, prefix, dim_inner) - output = fluid.layers.elementwise_add(input, conv, name = prefix + '_sum') + conv = space_nonlocal( + input, + dim_in, + dim_out, + prefix, + dim_inner, + with_bias=with_bias, + with_scale=with_scale) + output = input + conv return output - diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py index 3179d3b99..eabac9bf8 100644 --- a/ppdet/modeling/losses/__init__.py +++ b/ppdet/modeling/losses/__init__.py @@ -19,9 +19,11 @@ from . import smooth_l1_loss from . import giou_loss from . import diou_loss from . import iou_loss +from . import balanced_l1_loss from .yolo_loss import * from .smooth_l1_loss import * from .giou_loss import * from .diou_loss import * from .iou_loss import * +from .balanced_l1_loss import * diff --git a/ppdet/modeling/losses/balanced_l1_loss.py b/ppdet/modeling/losses/balanced_l1_loss.py new file mode 100644 index 000000000..08e2087f8 --- /dev/null +++ b/ppdet/modeling/losses/balanced_l1_loss.py @@ -0,0 +1,73 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from paddle import fluid +from ppdet.core.workspace import register, serializable + +__all__ = ['BalancedL1Loss'] + + +@register +@serializable +class BalancedL1Loss(object): + """ + Balanced L1 Loss, see https://arxiv.org/abs/1904.02701 + Args: + alpha (float): hyper parameter of BalancedL1Loss, see more details in the paper + gamma (float): hyper parameter of BalancedL1Loss, see more details in the paper + beta (float): hyper parameter of BalancedL1Loss, see more details in the paper + loss_weights (float): loss weight + """ + + def __init__(self, alpha=0.5, gamma=1.5, beta=1.0, loss_weight=1.0): + super(BalancedL1Loss, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.beta = beta + self.loss_weight = loss_weight + + def __call__( + self, + x, + y, + inside_weight=None, + outside_weight=None, ): + alpha = self.alpha + gamma = self.gamma + beta = self.beta + loss_weight = self.loss_weight + diff = fluid.layers.abs(x - y) + b = np.e**(gamma / alpha) - 1 + less_beta = diff < beta + ge_beta = diff >= beta + less_beta = fluid.layers.cast(x=less_beta, dtype='float32') + ge_beta = fluid.layers.cast(x=ge_beta, dtype='float32') + less_beta.stop_gradient = True + ge_beta.stop_gradient = True + loss_1 = less_beta * ( + alpha / b * (b * diff + 1) * fluid.layers.log(b * diff / 
beta + 1) - + alpha * diff) + loss_2 = ge_beta * (gamma * diff + gamma / b - alpha * beta) + iou_weights = 1.0 + if inside_weight is not None and outside_weight is not None: + iou_weights = inside_weight * outside_weight + loss = (loss_1 + loss_2) * iou_weights + loss = fluid.layers.reduce_sum(loss, dim=-1) * loss_weight + return loss diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index 1a4b5f4ec..a94e1f451 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -19,12 +19,13 @@ from paddle import fluid from paddle.fluid.param_attr import ParamAttr from paddle.fluid.regularizer import L2Decay from ppdet.core.workspace import register, serializable +from ppdet.utils.bbox_utils import bbox_overlaps, box_to_delta __all__ = [ 'AnchorGenerator', 'DropBlock', 'RPNTargetAssign', 'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDOutputDecoder', 'RetinaTargetAssign', - 'RetinaOutputDecoder', 'ConvNorm', 'MultiClassSoftNMS' + 'RetinaOutputDecoder', 'ConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner' ] @@ -546,6 +547,460 @@ class BBoxAssigner(object): self.use_random = shuffle_before_sample +@register +class LibraBBoxAssigner(object): + def __init__(self, + batch_size_per_im=512, + fg_fraction=.25, + fg_thresh=.5, + bg_thresh_hi=.5, + bg_thresh_lo=0., + bbox_reg_weights=[0.1, 0.1, 0.2, 0.2], + num_classes=81, + shuffle_before_sample=True, + is_cls_agnostic=False, + num_bins=3): + super(LibraBBoxAssigner, self).__init__() + self.batch_size_per_im = batch_size_per_im + self.fg_fraction = fg_fraction + self.fg_thresh = fg_thresh + self.bg_thresh_hi = bg_thresh_hi + self.bg_thresh_lo = bg_thresh_lo + self.bbox_reg_weights = bbox_reg_weights + self.class_nums = num_classes + self.use_random = shuffle_before_sample + self.is_cls_agnostic = is_cls_agnostic + self.num_bins = num_bins + + def __call__( + self, + rpn_rois, + gt_classes, + is_crowd, + gt_boxes, + im_info, ): + return 
self.generate_proposal_label_libra( + rpn_rois=rpn_rois, + gt_classes=gt_classes, + is_crowd=is_crowd, + gt_boxes=gt_boxes, + im_info=im_info, + batch_size_per_im=self.batch_size_per_im, + fg_fraction=self.fg_fraction, + fg_thresh=self.fg_thresh, + bg_thresh_hi=self.bg_thresh_hi, + bg_thresh_lo=self.bg_thresh_lo, + bbox_reg_weights=self.bbox_reg_weights, + class_nums=self.class_nums, + use_random=self.use_random, + is_cls_agnostic=self.is_cls_agnostic, + is_cascade_rcnn=False) + + def generate_proposal_label_libra( + self, rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, + batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi, + bg_thresh_lo, bbox_reg_weights, class_nums, use_random, + is_cls_agnostic, is_cascade_rcnn): + num_bins = self.num_bins + + def create_tmp_var(program, name, dtype, shape, lod_level=None): + return program.current_block().create_var( + name=name, dtype=dtype, shape=shape, lod_level=lod_level) + + def _sample_pos(max_overlaps, max_classes, pos_inds, num_expected): + if len(pos_inds) <= num_expected: + return pos_inds + else: + unique_gt_inds = np.unique(max_classes[pos_inds]) + num_gts = len(unique_gt_inds) + num_per_gt = int(round(num_expected / float(num_gts)) + 1) + + sampled_inds = [] + for i in unique_gt_inds: + inds = np.nonzero(max_classes == i)[0] + before_len = len(inds) + inds = list(set(inds) & set(pos_inds)) + after_len = len(inds) + if len(inds) > num_per_gt: + inds = np.random.choice( + inds, size=num_per_gt, replace=False) + sampled_inds.extend(list(inds)) # combine as a new sampler + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array( + list(set(pos_inds) - set(sampled_inds))) + assert len(sampled_inds)+len(extra_inds) == len(pos_inds), \ + "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format( + len(sampled_inds), len(extra_inds), len(pos_inds)) + if len(extra_inds) > num_extra: + extra_inds = np.random.choice( + extra_inds, 
size=num_extra, replace=False) + sampled_inds.extend(extra_inds.tolist()) + elif len(sampled_inds) > num_expected: + sampled_inds = np.random.choice( + sampled_inds, size=num_expected, replace=False) + return sampled_inds + + def sample_via_interval(max_overlaps, full_set, num_expected, floor_thr, + num_bins, bg_thresh_hi): + max_iou = max_overlaps.max() + iou_interval = (max_iou - floor_thr) / num_bins + per_num_expected = int(num_expected / num_bins) + + sampled_inds = [] + for i in range(num_bins): + start_iou = floor_thr + i * iou_interval + end_iou = floor_thr + (i + 1) * iou_interval + + tmp_set = set( + np.where( + np.logical_and(max_overlaps >= start_iou, max_overlaps < + end_iou))[0]) + tmp_inds = list(tmp_set & full_set) + + if len(tmp_inds) > per_num_expected: + tmp_sampled_set = np.random.choice( + tmp_inds, size=per_num_expected, replace=False) + else: + tmp_sampled_set = np.array(tmp_inds, dtype=np.int) + sampled_inds.append(tmp_sampled_set) + + sampled_inds = np.concatenate(sampled_inds) + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array(list(full_set - set(sampled_inds))) + assert len(sampled_inds)+len(extra_inds) == len(full_set), \ + "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format( + len(sampled_inds), len(extra_inds), len(full_set)) + + if len(extra_inds) > num_extra: + extra_inds = np.random.choice( + extra_inds, num_extra, replace=False) + sampled_inds = np.concatenate([sampled_inds, extra_inds]) + + return sampled_inds + + def _sample_neg(max_overlaps, + max_classes, + neg_inds, + num_expected, + floor_thr=-1, + floor_fraction=0, + num_bins=3, + bg_thresh_hi=0.5): + if len(neg_inds) <= num_expected: + return neg_inds + else: + # balance sampling for negative samples + neg_set = set(neg_inds) + if floor_thr > 0: + floor_set = set( + np.where( + np.logical_and(max_overlaps >= 0, max_overlaps < + floor_thr))[0]) + iou_sampling_set = set( + 
np.where(max_overlaps >= floor_thr)[0]) + elif floor_thr == 0: + floor_set = set(np.where(max_overlaps == 0)[0]) + iou_sampling_set = set( + np.where(max_overlaps > floor_thr)[0]) + else: + floor_set = set() + iou_sampling_set = set( + np.where(max_overlaps > floor_thr)[0]) + floor_thr = 0 + + floor_neg_inds = list(floor_set & neg_set) + iou_sampling_neg_inds = list(iou_sampling_set & neg_set) + + num_expected_iou_sampling = int(num_expected * + (1 - floor_fraction)) + if len(iou_sampling_neg_inds) > num_expected_iou_sampling: + if num_bins >= 2: + iou_sampled_inds = sample_via_interval( + max_overlaps, + set(iou_sampling_neg_inds), + num_expected_iou_sampling, floor_thr, num_bins, + bg_thresh_hi) + else: + iou_sampled_inds = np.random.choice( + iou_sampling_neg_inds, + size=num_expected_iou_sampling, + replace=False) + else: + iou_sampled_inds = np.array( + iou_sampling_neg_inds, dtype=np.int) + num_expected_floor = num_expected - len(iou_sampled_inds) + if len(floor_neg_inds) > num_expected_floor: + sampled_floor_inds = np.random.choice( + floor_neg_inds, size=num_expected_floor, replace=False) + else: + sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int) + sampled_inds = np.concatenate( + (sampled_floor_inds, iou_sampled_inds)) + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array(list(neg_set - set(sampled_inds))) + if len(extra_inds) > num_extra: + extra_inds = np.random.choice( + extra_inds, size=num_extra, replace=False) + sampled_inds = np.concatenate((sampled_inds, extra_inds)) + return sampled_inds + + def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, + batch_size_per_im, fg_fraction, fg_thresh, + bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, + class_nums, use_random, is_cls_agnostic, + is_cascade_rcnn): + rois_per_image = int(batch_size_per_im) + fg_rois_per_im = int(np.round(fg_fraction * rois_per_image)) + + # Roidb + im_scale = im_info[2] + inv_im_scale = 1. 
def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
                 batch_size_per_im, fg_fraction, fg_thresh,
                 bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
                 class_nums, use_random, is_cls_agnostic,
                 is_cascade_rcnn):
    """Sample foreground/background RoIs for one image and build targets.

    The RPN proposals are rescaled by ``1 / im_info[2]``, stacked below the
    ground-truth boxes, matched to ground truth by IoU, split into
    foreground (IoU >= ``fg_thresh``) and background
    (``bg_thresh_lo`` <= IoU < ``bg_thresh_hi``) and, outside cascade mode,
    sub-sampled to at most ``batch_size_per_im`` RoIs.

    Returns:
        dict: ``rois``, ``labels_int32``, ``bbox_targets``,
        ``bbox_inside_weights`` and ``bbox_outside_weights`` arrays for the
        sampled RoIs.
    """
    rois_per_image = int(batch_size_per_im)
    fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))

    # Roidb
    # Bring the proposals back to the scale of the ground-truth boxes.
    im_scale = im_info[2]
    inv_im_scale = 1. / im_scale
    rpn_rois = rpn_rois * inv_im_scale
    if is_cascade_rcnn:
        # Drop the leading rows; presumably these are the GT boxes appended
        # by the previous cascade stage -- TODO confirm against the caller.
        rpn_rois = rpn_rois[gt_boxes.shape[0]:, :]
    # GT boxes occupy the first rows, proposals follow.
    boxes = np.vstack([gt_boxes, rpn_rois])
    gt_overlaps = np.zeros((boxes.shape[0], class_nums))
    box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32)
    if len(gt_boxes) > 0:
        proposal_to_gt_overlaps = bbox_overlaps(boxes, gt_boxes)

        overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
        overlaps_max = proposal_to_gt_overlaps.max(axis=1)
        # Boxes with non-zero overlap with at least one gt box
        overlapped_boxes_ind = np.where(overlaps_max > 0)[0]

        overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
            overlapped_boxes_ind]]

        # Record, per overlapping box, its best IoU in the column of the
        # matched GT class and the index of the matched GT box.
        for idx in range(len(overlapped_boxes_ind)):
            gt_overlaps[overlapped_boxes_ind[
                idx], overlapped_boxes_gt_classes[idx]] = overlaps_max[
                    overlapped_boxes_ind[idx]]
            box_to_gt_ind_map[overlapped_boxes_ind[
                idx]] = overlaps_argmax[overlapped_boxes_ind[idx]]

    # Rows flagged by is_crowd get overlap -1, which is below both the
    # foreground and (for non-negative bg_thresh_lo) background thresholds.
    # Assumes is_crowd is aligned with gt_boxes, whose rows come first in
    # `boxes` -- TODO confirm.
    crowd_ind = np.where(is_crowd)[0]
    gt_overlaps[crowd_ind] = -1

    max_overlaps = gt_overlaps.max(axis=1)
    max_classes = gt_overlaps.argmax(axis=1)

    # Cascade RCNN Decode Filter
    if is_cascade_rcnn:
        ws = boxes[:, 2] - boxes[:, 0] + 1
        hs = boxes[:, 3] - boxes[:, 1] + 1
        keep = np.where((ws > 0) & (hs > 0))[0]
        # NOTE(review): `boxes` is filtered by `keep`, but max_overlaps,
        # max_classes and box_to_gt_ind_map are not, so the indices computed
        # below only line up with `boxes` when no box is dropped -- confirm.
        boxes = boxes[keep]
        fg_inds = np.where(max_overlaps >= fg_thresh)[0]
        bg_inds = np.where((max_overlaps < bg_thresh_hi) & (
            max_overlaps >= bg_thresh_lo))[0]
        # Cascade mode keeps every foreground and background box.
        fg_rois_per_this_image = fg_inds.shape[0]
        bg_rois_per_this_image = bg_inds.shape[0]
    else:
        # Foreground
        fg_inds = np.where(max_overlaps >= fg_thresh)[0]
        fg_rois_per_this_image = np.minimum(fg_rois_per_im,
                                            fg_inds.shape[0])
        # Sample foreground if there are too many
        if fg_inds.shape[0] > fg_rois_per_this_image:
            if use_random:
                fg_inds = _sample_pos(max_overlaps, max_classes,
                                      fg_inds, fg_rois_per_this_image)
        # Without use_random this simply keeps the first entries.
        fg_inds = fg_inds[:fg_rois_per_this_image]

        # Background
        bg_inds = np.where((max_overlaps < bg_thresh_hi) & (
            max_overlaps >= bg_thresh_lo))[0]
        bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
        bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
                                            bg_inds.shape[0])
        assert bg_rois_per_this_image >= 0, "bg_rois_per_this_image must be >= 0 but got {}".format(
            bg_rois_per_this_image)

        # Sample background if there are too many
        if bg_inds.shape[0] > bg_rois_per_this_image:
            if use_random:
                # libra neg sample
                # `num_bins` is not a parameter of this function; presumably
                # it is bound in the enclosing scope -- TODO confirm.
                bg_inds = _sample_neg(
                    max_overlaps,
                    max_classes,
                    bg_inds,
                    bg_rois_per_this_image,
                    num_bins=num_bins,
                    bg_thresh_hi=bg_thresh_hi)
        bg_inds = bg_inds[:bg_rois_per_this_image]

    # Foreground indices first, background after; the slices below rely on
    # that ordering.
    keep_inds = np.append(fg_inds, bg_inds)
    sampled_labels = max_classes[keep_inds]  # N x 1
    # Everything after the foreground part is labelled as class 0.
    sampled_labels[fg_rois_per_this_image:] = 0
    sampled_boxes = boxes[keep_inds]  # N x 324
    sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]]
    # Background rows get the first GT box as a placeholder target.
    # NOTE(review): gt_boxes[0] requires at least one GT box -- confirm the
    # caller never passes an image without annotations.
    sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0]
    bbox_label_targets = _compute_targets(
        sampled_boxes, sampled_gts, sampled_labels, bbox_reg_weights)
    bbox_targets, bbox_inside_weights = _expand_bbox_targets(
        bbox_label_targets, class_nums, is_cls_agnostic)
    # Outside weights are 1 wherever the inside weights are positive.
    bbox_outside_weights = np.array(
        bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype)
    # Scale rois
    sampled_rois = sampled_boxes * im_scale

    # Faster RCNN blobs
    frcn_blobs = dict(
        rois=sampled_rois,
        labels_int32=sampled_labels,
        bbox_targets=bbox_targets,
        bbox_inside_weights=bbox_inside_weights,
        bbox_outside_weights=bbox_outside_weights)
    return frcn_blobs
def _compute_targets(roi_boxes, gt_boxes, labels, bbox_reg_weights):
    """Build the per-RoI regression targets with the label prepended.

    Args:
        roi_boxes (np.ndarray): sampled RoIs, shape (N, 4).
        gt_boxes (np.ndarray): ground-truth box paired with each RoI,
            shape (N, 4).
        labels (np.ndarray): 1-D array with the class label of each RoI.
        bbox_reg_weights (sequence): four weights forwarded to
            ``box_to_delta``.

    Returns:
        np.ndarray: float32 array of shape (N, 5) holding
        ``[label, <4 deltas from box_to_delta>]`` per row.
    """
    assert roi_boxes.shape[0] == gt_boxes.shape[0]
    assert roi_boxes.shape[1] == 4
    assert gt_boxes.shape[1] == 4

    # The zero-filled placeholder the original allocated here was dead code:
    # it was overwritten by the box_to_delta result right away.
    bbox_reg_weights = np.asarray(bbox_reg_weights)
    targets = box_to_delta(
        ex_boxes=roi_boxes, gt_boxes=gt_boxes, weights=bbox_reg_weights)

    return np.hstack([labels[:, np.newaxis], targets]).astype(
        np.float32, copy=False)
def _expand_bbox_targets(bbox_targets_input, class_nums,
                         is_cls_agnostic):
    """Spread compact ``[label, 4 values]`` rows into per-class slots.

    Args:
        bbox_targets_input (np.ndarray): array whose column 0 is the class
            label and whose remaining columns are the 4 regression values.
        class_nums (int): number of classes (used when not class agnostic).
        is_cls_agnostic (bool): when True only two 4-wide slots exist and
            every foreground row is written to slot 1.

    Returns:
        tuple: ``(bbox_targets, bbox_inside_weights)``, both zero-filled
        arrays of shape ``(N, 4 * slots)``; rows with a label > 0 carry
        their regression values, and weights of 1.0, in their slot.
    """
    labels = bbox_targets_input[:, 0]
    num_rows = labels.shape[0]

    if is_cls_agnostic:
        width = 4 * 2
    else:
        width = 4 * class_nums

    bbox_targets = np.zeros((num_rows, width))
    bbox_inside_weights = np.zeros(bbox_targets.shape)

    # Only rows with a positive label receive targets.
    for row in np.where(labels > 0)[0]:
        if is_cls_agnostic:
            slot = 1
        else:
            slot = int(labels[row])
        lo = slot * 4
        hi = lo + 4
        bbox_targets[row, lo:hi] = bbox_targets_input[row, 1:]
        bbox_inside_weights[row, lo:hi] = (1.0, 1.0, 1.0, 1.0)

    return bbox_targets, bbox_inside_weights
def generate_func(
        rpn_rois,
        gt_classes,
        is_crowd,
        gt_boxes,
        im_info, ):
    """py_func callback: sample RoIs per image and pack them as LoDTensors.

    The level-0 LoD offsets of ``rpn_rois`` and ``gt_classes`` are used to
    slice out each image's proposals and annotations, ``_sample_rois`` is run
    on every slice, and the per-image results are stacked and returned as
    five LoDTensors sharing one LoD: rois, labels_int32, bbox_targets,
    bbox_inside_weights and bbox_outside_weights.
    """
    # The LoD must be read before the tensors are turned into ndarrays.
    roi_offsets = rpn_rois.lod()[0]
    gt_offsets = gt_classes.lod()[0]

    # convert
    rpn_rois = np.array(rpn_rois)
    gt_classes = np.array(gt_classes)
    is_crowd = np.array(is_crowd)
    gt_boxes = np.array(gt_boxes)
    im_info = np.array(im_info)

    per_image_rois = []
    per_image_labels = []
    per_image_targets = []
    per_image_inside = []
    per_image_outside = []
    lod = [0]

    for idx in range(len(roi_offsets) - 1):
        roi_slice = slice(roi_offsets[idx], roi_offsets[idx + 1])
        gt_slice = slice(gt_offsets[idx], gt_offsets[idx + 1])

        blobs = _sample_rois(
            rpn_rois[roi_slice], gt_classes[gt_slice], is_crowd[gt_slice],
            gt_boxes[gt_slice], im_info[idx], batch_size_per_im,
            fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo,
            bbox_reg_weights, class_nums, use_random, is_cls_agnostic,
            is_cascade_rcnn)

        # Running offset of the sampled RoIs of this image.
        lod.append(blobs['rois'].shape[0] + lod[-1])
        per_image_rois.append(blobs['rois'])
        per_image_labels.append(blobs['labels_int32'].reshape(-1, 1))
        per_image_targets.append(blobs['bbox_targets'])
        per_image_inside.append(blobs['bbox_inside_weights'])
        per_image_outside.append(blobs['bbox_outside_weights'])

    def _pack(chunks, dtype):
        # create lod-tensor for return
        # notice that the func create_lod_tensor does not work well here
        tensor = fluid.LoDTensor()
        tensor.set_lod([lod])
        tensor.set(np.vstack(chunks).astype(dtype), fluid.CPUPlace())
        return tensor

    return (_pack(per_image_rois, "float32"),
            _pack(per_image_labels, "int32"),
            _pack(per_image_targets, "float32"),
            _pack(per_image_inside, "float32"),
            _pack(per_image_outside, "float32"))
bbox_inside_weights, + bbox_outside_weights + ] + + fluid.layers.py_func( + func=generate_func, + x=[rpn_rois, gt_classes, is_crowd, gt_boxes, im_info], + out=outs) + return outs + + @register class RoIAlign(object): __op__ = fluid.layers.roi_align diff --git a/ppdet/utils/bbox_utils.py b/ppdet/utils/bbox_utils.py new file mode 100644 index 000000000..ff16e8b9d --- /dev/null +++ b/ppdet/utils/bbox_utils.py @@ -0,0 +1,83 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def bbox_overlaps(boxes_1, boxes_2):
    """Pairwise IoU between two sets of boxes.

    Box extents use the inclusive-pixel convention (width = x2 - x1 + 1).

    Args:
        boxes_1 (np.ndarray): shape (N, 4), rows are x1, y1, x2, y2.
        boxes_2 (np.ndarray): shape (M, 4), rows are x1, y1, x2, y2.

    Returns:
        np.ndarray: shape (N, M); entry (i, j) is the IoU of ``boxes_1[i]``
        and ``boxes_2[j]``.
    """
    assert boxes_1.shape[1] == 4 and boxes_2.shape[1] == 4

    # Width-1 column slices keep boxes_1 coordinates as (N, 1) so they
    # broadcast against the (M,) coordinates of boxes_2 into (N, M).
    x1_1 = boxes_1[:, 0:1]
    y1_1 = boxes_1[:, 1:2]
    x2_1 = boxes_1[:, 2:3]
    y2_1 = boxes_1[:, 3:4]
    area_1 = (x2_1 - x1_1 + 1) * (y2_1 - y1_1 + 1)

    x1_2 = boxes_2[:, 0]
    y1_2 = boxes_2[:, 1]
    x2_2 = boxes_2[:, 2]
    y2_2 = boxes_2[:, 3]
    area_2 = (x2_2 - x1_2 + 1) * (y2_2 - y1_2 + 1)

    xx1 = np.maximum(x1_1, x1_2)
    yy1 = np.maximum(y1_1, y1_2)
    xx2 = np.minimum(x2_1, x2_2)
    yy2 = np.minimum(y2_1, y2_2)

    # Clamp at zero so disjoint boxes contribute no intersection.
    w = np.maximum(0.0, xx2 - xx1 + 1)
    h = np.maximum(0.0, yy2 - yy1 + 1)
    inter = w * h

    ovr = inter / (area_1 + area_2 - inter)
    return ovr


def box_to_delta(ex_boxes, gt_boxes, weights):
    """Encode ground-truth boxes as regression deltas relative to ex_boxes.

    Args:
        ex_boxes (np.ndarray): shape (N, 4) reference boxes, x1, y1, x2, y2.
        gt_boxes (np.ndarray): shape (N, 4) target boxes, x1, y1, x2, y2.
        weights (sequence): four divisors applied to dx, dy, dw and dh.

    Returns:
        np.ndarray: shape (N, 4) with columns dx, dy, dw, dh, where dx/dy
        are the centre offsets normalised by the reference size and dw/dh
        are log size ratios, each divided by its weight.
    """
    ex_w = ex_boxes[:, 2] - ex_boxes[:, 0] + 1
    ex_h = ex_boxes[:, 3] - ex_boxes[:, 1] + 1
    ex_ctr_x = ex_boxes[:, 0] + 0.5 * ex_w
    ex_ctr_y = ex_boxes[:, 1] + 0.5 * ex_h

    gt_w = gt_boxes[:, 2] - gt_boxes[:, 0] + 1
    gt_h = gt_boxes[:, 3] - gt_boxes[:, 1] + 1
    gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_w
    gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_h

    dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]
    dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]
    dw = (np.log(gt_w / ex_w)) / weights[2]
    dh = (np.log(gt_h / ex_h)) / weights[3]

    targets = np.vstack([dx, dy, dw, dh]).transpose()
    return targets
loader.set_sample_list_generator(reader, place) + dataset = cfg['EvalReader']['dataset'] + # eval already exists json file if FLAGS.json_eval: logger.info( @@ -123,8 +125,6 @@ def main(): callable(model.is_bbox_normalized): is_bbox_normalized = model.is_bbox_normalized() - dataset = cfg['EvalReader']['dataset'] - sub_eval_prog = None sub_keys = None sub_values = None -- GitLab