diff --git a/configs/rotate/README.md b/configs/rotate/README.md index db3a6b0ddc12e50627f8d42805ede3b1817b2c46..574cb4ed5ece2992b7d04587bac977ba19f0d5a1 100644 --- a/configs/rotate/README.md +++ b/configs/rotate/README.md @@ -16,6 +16,7 @@ | 模型 | mAP | 学习率策略 | 角度表示 | 数据增广 | GPU数目 | 每GPU图片数目 | 模型下载 | 配置文件 | |:---:|:----:|:---------:|:-----:|:--------:|:-----:|:------------:|:-------:|:------:| | [S2ANet](./s2anet/README.md) | 73.84 | 2x | le135 | - | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/s2anet_alignconv_2x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/s2anet/s2anet_alignconv_2x_dota.yml) | +| [FCOSR](./fcosr/README.md) | 76.62 | 3x | oc | - | 4 | 4 | [model](https://paddledet.bj.bcebos.com/models/fcosr_x50_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/fcosr/fcosr_x50_3x_dota.yml) | **注意:** diff --git a/configs/rotate/README_en.md b/configs/rotate/README_en.md index 03c4d2cee3ff61dc1001c8adcd81864a597935d2..ef5160ec9f4f0b8f8670a7a0989a05b2be5b982d 100644 --- a/configs/rotate/README_en.md +++ b/configs/rotate/README_en.md @@ -15,6 +15,7 @@ Rotated object detection is used to detect rectangular bounding boxes with angle | Model | mAP | Lr Scheduler | Angle | Aug | GPU Number | images/GPU | download | config | |:---:|:----:|:---------:|:-----:|:--------:|:-----:|:------------:|:-------:|:------:| | [S2ANet](./s2anet/README_en.md) | 73.84 | 2x | le135 | - | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/s2anet_alignconv_2x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/s2anet/s2anet_alignconv_2x_dota.yml) | +| [FCOSR](./fcosr/README_en.md) | 76.62 | 3x | oc | - | 4 | 4 | [model](https://paddledet.bj.bcebos.com/models/fcosr_x50_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/fcosr/fcosr_x50_3x_dota.yml) | **Notes:** diff --git a/configs/rotate/fcosr/README.md b/configs/rotate/fcosr/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0113ee1f8d6a9796a8bb91c02787308dd8bbac48 --- /dev/null +++ b/configs/rotate/fcosr/README.md @@ -0,0 +1,91 @@ +简体中文 | [English](README_en.md) + +# FCOSR + +## 内容 +- [简介](#简介) +- [模型库](#模型库) +- [使用说明](#使用说明) +- [预测部署](#预测部署) +- [引用](#引用) + +## 简介 + +[FCOSR](https://arxiv.org/abs/2111.10780)是基于[FCOS](https://arxiv.org/abs/1904.01355)的单阶段Anchor-Free的旋转框检测算法。FCOSR主要聚焦于旋转框的标签匹配策略,提出了椭圆中心采样和模糊样本标签匹配的方法。在loss方面,FCOSR使用了[ProbIoU](https://arxiv.org/abs/2106.06072)避免边界不连续性问题。 + +## 模型库 + +| 模型 | Backbone | mAP | 学习率策略 | 角度表示 | 数据增广 | GPU数目 | 每GPU图片数目 | 模型下载 | 配置文件 | +|:---:|:--------:|:----:|:---------:|:-----:|:--------:|:-----:|:------------:|:-------:|:------:| +| FCOSR-M | ResNeXt-50 | 76.62 | 3x | oc | - | 4 | 4 | [model](https://paddledet.bj.bcebos.com/models/fcosr_x50_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/fcosr/fcosr_x50_3x_dota.yml) | + +**注意:** + +- 如果**GPU卡数**或者**batch size**发生了改变,你需要按照公式 **lrnew = lrdefault * (batch_sizenew * GPU_numbernew) / (batch_sizedefault * GPU_numberdefault)** 调整学习率。 +- 模型库中的模型默认使用单尺度训练单尺度测试。如果数据增广一栏标明MS,意味着使用多尺度训练和多尺度测试。如果数据增广一栏标明RR,意味着使用RandomRotate数据增广进行训练。 + +## 使用说明 + +参考[数据准备](../README.md#数据准备)准备数据。 + +### 训练 + +GPU单卡训练 +``` bash +CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/rotate/fcosr/fcosr_x50_3x_dota.yml +``` + +GPU多卡训练 +``` bash +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m 
paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/rotate/fcosr/fcosr_x50_3x_dota.yml
+```
+
+### 预测
+
+执行以下命令预测单张图片，图片预测结果会默认保存在`output`文件夹下面
+``` bash
+python tools/infer.py -c configs/rotate/fcosr/fcosr_x50_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/fcosr_x50_3x_dota.pdparams --infer_img=demo/P0861__1.0__1154___824.png --draw_threshold=0.5
+```
+
+### DOTA数据集评估
+
+参考[DOTA Task](https://captain-whu.github.io/DOTA/tasks.html), 评估DOTA数据集需要生成一个包含所有检测结果的zip文件，每一类的检测结果储存在一个txt文件中，txt文件中每行格式为：`image_name score x1 y1 x2 y2 x3 y3 x4 y4`。将生成的zip文件提交到[DOTA Evaluation](https://captain-whu.github.io/DOTA/evaluation.html)的Task1进行评估。你可以执行以下命令得到test数据集的预测结果：
+``` bash
+python tools/infer.py -c configs/rotate/fcosr/fcosr_x50_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/fcosr_x50_3x_dota.pdparams --infer_dir=/path/to/test/images --output_dir=output_fcosr --visualize=False --save_results=True
+```
+将预测结果处理成官网评估所需要的格式：
+``` bash
+python configs/rotate/tools/generate_result.py --pred_txt_dir=output_fcosr/ --output_dir=submit/ --data_type=dota10
+
+zip -r submit.zip submit
+```
+
+## 预测部署
+
+部署教程请参考[预测部署](../../../deploy/README.md)
+
+## 引用
+
+```
+@article{li2021fcosr,
+  title={Fcosr: A simple anchor-free rotated detector for aerial object detection},
+  author={Li, Zhonghua and Hou, Biao and Wu, Zitong and Jiao, Licheng and Ren, Bo and Yang, Chen},
+  journal={arXiv preprint arXiv:2111.10780},
+  year={2021}
+}
+
+@inproceedings{tian2019fcos,
+  title={Fcos: Fully convolutional one-stage object detection},
+  author={Tian, Zhi and Shen, Chunhua and Chen, Hao and He, Tong},
+  booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
+  pages={9627--9636},
+  year={2019}
+}
+
+@article{llerena2021gaussian,
+  title={Gaussian Bounding Boxes and Probabilistic Intersection-over-Union for Object Detection},
+  author={Llerena, Jeffri M and Zeni, Luis Felipe and Kristen, Lucas N and Jung, Claudio},
+  journal={arXiv preprint arXiv:2106.06072},
+  year={2021}
+}
+```
diff --git a/configs/rotate/fcosr/README_en.md b/configs/rotate/fcosr/README_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..cf8e49ae47ad2d12badfd5ddfa89cbb3bc3eabe1
--- /dev/null
+++ b/configs/rotate/fcosr/README_en.md
@@ -0,0 +1,92 @@
+English | [简体中文](README.md)
+
+# FCOSR
+
+## Content
+- [Introduction](#Introduction)
+- [Model Zoo](#Model-Zoo)
+- [Getting Started](#Getting-Started)
+- [Deployment](#Deployment)
+- [Citations](#Citations)
+
+## Introduction
+
+[FCOSR](https://arxiv.org/abs/2111.10780) is a one-stage anchor-free rotated object detection model based on [FCOS](https://arxiv.org/abs/1904.01355). FCOSR focuses on the label assignment strategy for oriented bounding boxes and proposes an ellipse center sampling method and a fuzzy sample assignment strategy. For the loss, FCOSR uses [ProbIoU](https://arxiv.org/abs/2106.06072) to avoid the boundary discontinuity problem.
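+
+To make the ProbIoU idea concrete, the standalone NumPy sketch below mirrors the Bhattacharyya-distance computation that this patch adds in `ppdet/modeling/losses/probiou_loss.py` (single box pair; the helper name `probiou_l1` and the NumPy port are ours, for illustration only; the shipped implementation is batched and written in Paddle):
+
+``` python
+import numpy as np
+
+def probiou_l1(box1, box2, eps=1e-3):
+    """ProbIoU L1 loss for one pair of (x, y, w, h, angle) boxes, angle in radians."""
+    def gaussian_form(box):
+        x, y, w, h, ang = box
+        # covariance of an axis-aligned w x h box: w**2/12, h**2/12
+        a_, b_ = w**2 / 12., h**2 / 12.
+        cos, sin = np.cos(ang), np.sin(ang)
+        # rotate the covariance by the box angle
+        a = a_ * cos**2 + b_ * sin**2
+        b = a_ * sin**2 + b_ * cos**2
+        c = (a_ - b_) * cos * sin
+        return x, y, a, b, c
+
+    x1, y1, a1, b1, c1 = gaussian_form(box1)
+    x2, y2, a2, b2, c2 = gaussian_form(box2)
+    # Bhattacharyya distance between the two Gaussians
+    t1 = 0.25 * ((a1 + a2) * (y1 - y2)**2 + (b1 + b2) * (x1 - x2)**2) \
+        + 0.5 * (c1 + c2) * (x2 - x1) * (y1 - y2)
+    t2 = (a1 + a2) * (b1 + b2) - (c1 + c2)**2
+    t3 = (a1 * b1 - c1 * c1) * (a2 * b2 - c2 * c2)
+    bd = t1 / t2 + 0.5 * np.log(t2 / (4 * np.sqrt(max(t3, 0.)) + eps))
+    bd = np.clip(bd, eps, 100.0)
+    # Hellinger-distance based loss: smooth in the angle, no boundary jump
+    return np.sqrt(1.0 - np.exp(-bd) + eps)
+
+# identical boxes whose angles differ by pi give a near-zero loss
+print(probiou_l1((0., 0., 4., 2., 0.), (0., 0., 4., 2., np.pi)))
+```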
+
+## Model Zoo
+
+| Model | Backbone | mAP | Lr Scheduler | Angle | Aug | GPU Number | images/GPU | download | config |
+|:---:|:--------:|:----:|:---------:|:-----:|:--------:|:-----:|:------------:|:-------:|:------:|
+| FCOSR-M | ResNeXt-50 | 76.62 | 3x | oc | - | 4 | 4 | [model](https://paddledet.bj.bcebos.com/models/fcosr_x50_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/fcosr/fcosr_x50_3x_dota.yml) |
+
+**Notes:**
+
+- If **GPU number** or **mini-batch size** is changed, the **learning rate** should be adjusted according to the formula **lrnew = lrdefault * (batch_sizenew * GPU_numbernew) / (batch_sizedefault * GPU_numberdefault)**.
+- Models in the model zoo are trained and tested at a single scale by default. If `MS` is indicated in the data augmentation column, multi-scale training and multi-scale testing are used. If `RR` is indicated in the data augmentation column, RandomRotate data augmentation is used for training.
+
+## Getting Started
+
+Refer to [Data-Preparation](../README_en.md#Data-Preparation) to prepare data.
+
+### Training
+
+Single GPU Training
+``` bash
+CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/rotate/fcosr/fcosr_x50_3x_dota.yml
+```
+
+Multiple GPUs Training
+``` bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/rotate/fcosr/fcosr_x50_3x_dota.yml
+```
+
+### Inference
+
+Run the following command to infer a single image. The inference result will be saved in the `output` directory by default.
+
+``` bash
+python tools/infer.py -c configs/rotate/fcosr/fcosr_x50_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/fcosr_x50_3x_dota.pdparams --infer_img=demo/P0861__1.0__1154___824.png --draw_threshold=0.5
+```
+
+### Evaluation on DOTA Dataset
+Referring to the [DOTA Task](https://captain-whu.github.io/DOTA/tasks.html), you need to submit a zip file containing the results for all test images for evaluation. The detection results of each category are stored in a txt file, each line of which has the following format:
+`image_id score x1 y1 x2 y2 x3 y3 x4 y4`. To evaluate, submit the generated zip file to Task1 of [DOTA Evaluation](https://captain-whu.github.io/DOTA/evaluation.html). A per-class example is sketched below.
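+
+A hypothetical `Task1_plane.txt` (the file name follows the DOTA Task1 convention; every value here is invented for illustration) might begin:
+
+```
+P0006 0.93 318.0 421.5 402.2 421.5 402.2 519.8 318.0 519.8
+P0031 0.71 25.4 60.0 88.1 64.2 85.9 97.3 23.2 93.1
+```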
+You can run the following command to get the inference results of the test dataset:
+``` bash
+python tools/infer.py -c configs/rotate/fcosr/fcosr_x50_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/fcosr_x50_3x_dota.pdparams --infer_dir=/path/to/test/images --output_dir=output_fcosr --visualize=False --save_results=True
+```
+Process the prediction results into the format required for the official website evaluation:
+``` bash
+python configs/rotate/tools/generate_result.py --pred_txt_dir=output_fcosr/ --output_dir=submit/ --data_type=dota10
+
+zip -r submit.zip submit
+```
+
+## Deployment
+
+Please refer to the deployment tutorial: [Deployment](../../../deploy/README_en.md)
+
+## Citations
+
+```
+@article{li2021fcosr,
+  title={Fcosr: A simple anchor-free rotated detector for aerial object detection},
+  author={Li, Zhonghua and Hou, Biao and Wu, Zitong and Jiao, Licheng and Ren, Bo and Yang, Chen},
+  journal={arXiv preprint arXiv:2111.10780},
+  year={2021}
+}
+
+@inproceedings{tian2019fcos,
+  title={Fcos: Fully convolutional one-stage object detection},
+  author={Tian, Zhi and Shen, Chunhua and Chen, Hao and He, Tong},
+  booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
+  pages={9627--9636},
+  year={2019}
+}
+
+@article{llerena2021gaussian,
+  title={Gaussian Bounding Boxes and Probabilistic Intersection-over-Union for Object Detection},
+  author={Llerena, Jeffri M and Zeni, Luis Felipe and Kristen, Lucas N and Jung, Claudio},
+  journal={arXiv preprint arXiv:2106.06072},
+  year={2021}
+}
+```
diff --git a/configs/rotate/fcosr/_base_/fcosr_reader.yml b/configs/rotate/fcosr/_base_/fcosr_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c08af1e8ffd050c1e4a9db209f44c3ec697cf531
--- /dev/null
+++ b/configs/rotate/fcosr/_base_/fcosr_reader.yml
@@ -0,0 +1,45 @@
+worker_num: 4
+image_height: &image_height 1024
+image_width: &image_width 1024
+image_size: &image_size [*image_height, *image_width]
+
+TrainReader:
+  sample_transforms:
+    - Decode: {}
+    - Poly2Array: {}
+    - RandomRFlip: {}
+    - RandomRRotate: {angle_mode: 'value', angle: [0, 90, 180, -90]}
+    - RandomRRotate: {angle_mode: 'value', angle: [30, 60], rotate_prob: 0.5}
+    - RResize: {target_size: *image_size, keep_ratio: True, interp: 2}
+    - Poly2RBox: {filter_threshold: 2, filter_mode: 'edge', rbox_type: 'oc'}
+  batch_transforms:
+    - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+    - Permute: {}
+    - PadRGT: {}
+    - PadBatch: {pad_to_stride: 32}
+  batch_size: 4
+  shuffle: true
+  drop_last: true
+  use_shared_memory: true
+  collate_batch: true
+
+EvalReader:
+  sample_transforms:
+    - Decode: {}
+    - Poly2Array: {}
+    - RResize: {target_size: *image_size, keep_ratio: True, interp: 2}
+    - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+    - Permute: {}
+  batch_transforms:
+    - PadBatch: {pad_to_stride: 32}
+  batch_size: 2
+
+TestReader:
+  sample_transforms:
+    - Decode: {}
+    - Resize: {target_size: *image_size, keep_ratio: True, interp: 2}
+    - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+    - Permute: {}
+  batch_transforms:
+    - PadBatch: {pad_to_stride: 32}
+  batch_size: 8
diff --git a/configs/rotate/fcosr/_base_/fcosr_x50.yml b/configs/rotate/fcosr/_base_/fcosr_x50.yml
new file mode 100644
index 0000000000000000000000000000000000000000..77a4d8a2ff0594aa9f948111092fd6c625d13234
--- /dev/null
+++ b/configs/rotate/fcosr/_base_/fcosr_x50.yml
@@ -0,0 +1,44
@@ +architecture: YOLOv3 +snapshot_epoch: 1 +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNeXt50_32x4d_pretrained.pdparams + +YOLOv3: + backbone: ResNet + neck: FPN + yolo_head: FCOSRHead + post_process: ~ + +ResNet: + depth: 50 + groups: 32 + base_width: 4 + variant: b + norm_type: bn + freeze_at: 0 + return_idx: [1,2,3] + num_stages: 4 + +FPN: + out_channel: 256 + extra_stage: 2 + has_extra_convs: true + use_c5: false + relu_before_extra_convs: true + +FCOSRHead: + feat_channels: 256 + fpn_strides: [8, 16, 32, 64, 128] + stacked_convs: 4 + loss_weight: {class: 1.0, probiou: 1.0} + assigner: + name: FCOSRAssigner + factor: 12 + threshold: 0.23 + boundary: [[-1, 64], [64, 128], [128, 256], [256, 512], [512, 100000000.0]] + nms: + name: MultiClassNMS + nms_top_k: 2000 + keep_top_k: -1 + score_threshold: 0.1 + nms_threshold: 0.1 + normalized: False diff --git a/configs/rotate/fcosr/_base_/optimizer_3x.yml b/configs/rotate/fcosr/_base_/optimizer_3x.yml new file mode 100644 index 0000000000000000000000000000000000000000..859db126bed27471f6d8dcd02761299395ce9468 --- /dev/null +++ b/configs/rotate/fcosr/_base_/optimizer_3x.yml @@ -0,0 +1,20 @@ +epoch: 36 + +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [24, 33] + - !LinearWarmup + start_factor: 0.3333333 + steps: 500 + +OptimizerBuilder: + clip_grad_by_norm: 35. + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 diff --git a/configs/rotate/fcosr/fcosr_x50_3x_dota.yml b/configs/rotate/fcosr/fcosr_x50_3x_dota.yml new file mode 100644 index 0000000000000000000000000000000000000000..d9554d30896ca5e2a3a5eb03725f1f6bb97a7dfc --- /dev/null +++ b/configs/rotate/fcosr/fcosr_x50_3x_dota.yml @@ -0,0 +1,9 @@ +_BASE_: [ + '../../datasets/dota.yml', + '../../runtime.yml', + '_base_/optimizer_3x.yml', + '_base_/fcosr_reader.yml', + '_base_/fcosr_x50.yml' +] + +weights: output/fcosr_x50_3x_dota/model_final diff --git a/deploy/python/infer.py b/deploy/python/infer.py index bc503941cdc86c4553ec1db4b2c4c72fc7d8f0ab..fc84bcd60484ea9340e20d62946008a07a08c3d8 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -856,13 +856,16 @@ def load_predictor(model_dir, if use_dynamic_shape: min_input_shape = { - 'image': [batch_size, 3, trt_min_shape, trt_min_shape] + 'image': [batch_size, 3, trt_min_shape, trt_min_shape], + 'scale_factor': [batch_size, 2] } max_input_shape = { - 'image': [batch_size, 3, trt_max_shape, trt_max_shape] + 'image': [batch_size, 3, trt_max_shape, trt_max_shape], + 'scale_factor': [batch_size, 2] } opt_input_shape = { - 'image': [batch_size, 3, trt_opt_shape, trt_opt_shape] + 'image': [batch_size, 3, trt_opt_shape, trt_opt_shape], + 'scale_factor': [batch_size, 2] } config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, opt_input_shape) diff --git a/ppdet/modeling/architectures/yolo.py b/ppdet/modeling/architectures/yolo.py index ce5be21cd48272939fd0e2ea36c5db439cb02081..f23f0ddd4f7decf4f1c3d304f661c74884207dec 100644 --- a/ppdet/modeling/architectures/yolo.py +++ b/ppdet/modeling/architectures/yolo.py @@ -77,7 +77,10 @@ class YOLOv3(BaseArch): def _forward(self): body_feats = self.backbone(self.inputs) - neck_feats = self.neck(body_feats, self.for_mot) + if self.for_mot: + neck_feats = self.neck(body_feats, self.for_mot) + else: + neck_feats = self.neck(body_feats) if isinstance(neck_feats, dict): assert self.for_mot == True diff --git a/ppdet/modeling/assigners/__init__.py b/ppdet/modeling/assigners/__init__.py 
index f82266b925f8b940bbbcaf646959ea0254c0161f..ded98c9439cd896c99ca47bc3119d39effad3870 100644 --- a/ppdet/modeling/assigners/__init__.py +++ b/ppdet/modeling/assigners/__init__.py @@ -17,9 +17,11 @@ from . import task_aligned_assigner from . import atss_assigner from . import simota_assigner from . import max_iou_assigner +from . import fcosr_assigner from .utils import * from .task_aligned_assigner import * from .atss_assigner import * from .simota_assigner import * from .max_iou_assigner import * +from .fcosr_assigner import * diff --git a/ppdet/modeling/assigners/fcosr_assigner.py b/ppdet/modeling/assigners/fcosr_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..84f991023215b344e59c9f6e1e4f7643b3c00dc0 --- /dev/null +++ b/ppdet/modeling/assigners/fcosr_assigner.py @@ -0,0 +1,227 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from ppdet.core.workspace import register +from ppdet.modeling.rbox_utils import box2corners, check_points_in_polys, paddle_gather + +__all__ = ['FCOSRAssigner'] + +EPS = 1e-9 + + +@register +class FCOSRAssigner(nn.Layer): + """ FCOSR Assigner, refer to https://arxiv.org/abs/2111.10780 for details + + 1. compute normalized gaussian distribution score and refined gaussian distribution score + 2. refer to ellipse center sampling, sample points whose normalized gaussian distribution score is greater than threshold + 3. refer to multi-level sampling, assign ground truth to feature map which follows two conditions. + i). first, the ratio between the short edge of the target and the stride of the feature map is less than 2. + ii). second, the long edge of minimum bounding rectangle of the target is larger than the acceptance range of feature map + 4. 
refer to fuzzy sample label assignment, the points satisfying 2 and 3 will be assigned to the ground truth according to gaussian distribution score + """ + __shared__ = ['num_classes'] + + def __init__(self, + num_classes=80, + factor=12, + threshold=0.23, + boundary=[[-1, 128], [128, 320], [320, 10000]], + score_type='iou'): + super(FCOSRAssigner, self).__init__() + self.num_classes = num_classes + self.factor = factor + self.threshold = threshold + self.boundary = [ + paddle.to_tensor( + l, dtype=paddle.float32).reshape([1, 1, 2]) for l in boundary + ] + self.score_type = score_type + + def get_gaussian_distribution_score(self, points, gt_rboxes, gt_polys): + # projecting points to coordinate system defined by each rbox + # [B, N, 4, 2] -> 4 * [B, N, 1, 2] + a, b, c, d = gt_polys.split(4, axis=2) + # [1, L, 2] -> [1, 1, L, 2] + points = points.unsqueeze(0) + ab = b - a + ad = d - a + # [B, N, 5] -> [B, N, 2], [B, N, 2], [B, N, 1] + xy, wh, angle = gt_rboxes.split([2, 2, 1], axis=-1) + # [B, N, 2] -> [B, N, 1, 2] + xy = xy.unsqueeze(2) + # vector of points to center [B, N, L, 2] + vec = points - xy + # = |ab| * |vec| * cos(theta) [B, N, L] + vec_dot_ab = paddle.sum(vec * ab, axis=-1) + # = |ad| * |vec| * cos(theta) [B, N, L] + vec_dot_ad = paddle.sum(vec * ad, axis=-1) + # norm_ab [B, N, L] + norm_ab = paddle.sum(ab * ab, axis=-1).sqrt() + # norm_ad [B, N, L] + norm_ad = paddle.sum(ad * ad, axis=-1).sqrt() + # min(h, w), [B, N, 1] + min_edge = paddle.min(wh, axis=-1, keepdim=True) + # delta_x, delta_y [B, N, L] + delta_x = vec_dot_ab.pow(2) / (norm_ab.pow(3) * min_edge + EPS) + delta_y = vec_dot_ad.pow(2) / (norm_ad.pow(3) * min_edge + EPS) + # score [B, N, L] + norm_score = paddle.exp(-0.5 * self.factor * (delta_x + delta_y)) + + # simplified calculation + sigma = min_edge / self.factor + refined_score = norm_score / (2 * np.pi * sigma + EPS) + return norm_score, refined_score + + def get_rotated_inside_mask(self, points, gt_polys, scores): + inside_mask = check_points_in_polys(points, gt_polys) + center_mask = scores >= self.threshold + return (inside_mask & center_mask).cast(paddle.float32) + + def get_inside_range_mask(self, points, gt_bboxes, gt_rboxes, stride_tensor, + regress_range): + # [1, L, 2] -> [1, 1, L, 2] + points = points.unsqueeze(0) + # [B, n, 4] -> [B, n, 1, 4] + x1y1, x2y2 = gt_bboxes.unsqueeze(2).split(2, axis=-1) + # [B, n, L, 2] + lt = points - x1y1 + rb = x2y2 - points + # [B, n, L, 4] + ltrb = paddle.concat([lt, rb], axis=-1) + # [B, n, L, 4] -> [B, n, L] + inside_mask = paddle.min(ltrb, axis=-1) > EPS + # regress_range [1, L, 2] -> [1, 1, L, 2] + regress_range = regress_range.unsqueeze(0) + # stride_tensor [1, L, 1] -> [1, 1, L] + stride_tensor = stride_tensor.transpose((0, 2, 1)) + # fcos range + # [B, n, L, 4] -> [B, n, L] + ltrb_max = paddle.max(ltrb, axis=-1) + # [1, 1, L, 2] -> [1, 1, L] + low, high = regress_range[..., 0], regress_range[..., 1] + # [B, n, L] + regress_mask = (ltrb_max >= low) & (ltrb_max <= high) + # mask for rotated + # [B, n, 1] + min_edge = paddle.min(gt_rboxes[..., 2:4], axis=-1, keepdim=True) + # [B, n , L] + rotated_mask = ((min_edge / stride_tensor) < 2.0) & (ltrb_max > high) + mask = inside_mask & (regress_mask | rotated_mask) + return mask.cast(paddle.float32) + + @paddle.no_grad() + def forward(self, + anchor_points, + stride_tensor, + num_anchors_list, + gt_labels, + gt_bboxes, + gt_rboxes, + pad_gt_mask, + bg_index, + pred_rboxes=None): + r""" + + Args: + anchor_points (Tensor, float32): pre-defined anchor points, shape(1, L, 
2), + "x, y" format + stride_tensor (Tensor, float32): stride tensor, shape (1, L, 1) + num_anchors_list (List): num of anchors in each level + gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1) + gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4) + gt_rboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 5) + pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1) + bg_index (int): background index + pred_rboxes (Tensor, float32, optional): predicted bounding boxes, shape(B, L, 5) + Returns: + assigned_labels (Tensor): (B, L) + assigned_rboxes (Tensor): (B, L, 5) + assigned_scores (Tensor): (B, L, C), if pred_rboxes is not None, then output ious + """ + + _, num_anchors, _ = anchor_points.shape + batch_size, num_max_boxes, _ = gt_rboxes.shape + if num_max_boxes == 0: + assigned_labels = paddle.full( + [batch_size, num_anchors], bg_index, dtype=gt_labels.dtype) + assigned_rboxes = paddle.zeros([batch_size, num_anchors, 5]) + assigned_scores = paddle.zeros( + [batch_size, num_anchors, self.num_classes]) + return assigned_labels, assigned_rboxes, assigned_scores + + # get normalized gaussian distribution score and refined distribution score + gt_polys = box2corners(gt_rboxes) + score, refined_score = self.get_gaussian_distribution_score( + anchor_points, gt_rboxes, gt_polys) + inside_mask = self.get_rotated_inside_mask(anchor_points, gt_polys, + score) + regress_ranges = [] + for num, bound in zip(num_anchors_list, self.boundary): + regress_ranges.append(bound.tile((1, num, 1))) + regress_ranges = paddle.concat(regress_ranges, axis=1) + regress_mask = self.get_inside_range_mask( + anchor_points, gt_bboxes, gt_rboxes, stride_tensor, regress_ranges) + # [B, n, L] + mask_positive = inside_mask * regress_mask * pad_gt_mask + refined_score = refined_score * mask_positive - (1. 
- mask_positive) + + argmax_refined_score = refined_score.argmax(axis=-2) + max_refined_score = refined_score.max(axis=-2) + assigned_gt_index = argmax_refined_score + + # assigned target + batch_ind = paddle.arange( + end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1) + assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes + assigned_labels = paddle.gather( + gt_labels.flatten(), assigned_gt_index.flatten(), axis=0) + assigned_labels = assigned_labels.reshape([batch_size, num_anchors]) + assigned_labels = paddle.where( + max_refined_score > 0, assigned_labels, + paddle.full_like(assigned_labels, bg_index)) + + assigned_rboxes = paddle.gather( + gt_rboxes.reshape([-1, 5]), assigned_gt_index.flatten(), axis=0) + assigned_rboxes = assigned_rboxes.reshape([batch_size, num_anchors, 5]) + + assigned_scores = F.one_hot(assigned_labels, self.num_classes + 1) + ind = list(range(self.num_classes + 1)) + ind.remove(bg_index) + assigned_scores = paddle.index_select( + assigned_scores, paddle.to_tensor(ind), axis=-1) + + if self.score_type == 'gaussian': + selected_scores = paddle_gather( + score, 1, argmax_refined_score.unsqueeze(-2)).squeeze(-2) + assigned_scores = assigned_scores * selected_scores.unsqueeze(-1) + elif self.score_type == 'iou': + assert pred_rboxes is not None, 'If score type is iou, pred_rboxes should not be None' + from ext_op import matched_rbox_iou + b, l = pred_rboxes.shape[:2] + iou_score = matched_rbox_iou( + pred_rboxes.reshape((-1, 5)), assigned_rboxes.reshape( + (-1, 5))).reshape((b, l, 1)) + assigned_scores = assigned_scores * iou_score + + return assigned_labels, assigned_rboxes, assigned_scores \ No newline at end of file diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index e831933381c88dbd14f588b045d619333cc537c3..36cacbdec5d708a8e0c3f29362d9c70f294e99a6 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -33,6 +33,7 @@ from . import sparsercnn_head from . import tood_head from . import retina_head from . import ppyoloe_head +from . import fcosr_head from .bbox_head import * from .mask_head import * @@ -55,3 +56,4 @@ from .sparsercnn_head import * from .tood_head import * from .retina_head import * from .ppyoloe_head import * +from .fcosr_head import * diff --git a/ppdet/modeling/heads/fcosr_head.py b/ppdet/modeling/heads/fcosr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..06b84440e8eb1f8e252eaf2c723bbc03bb4ced0a --- /dev/null +++ b/ppdet/modeling/heads/fcosr_head.py @@ -0,0 +1,395 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register +from paddle import ParamAttr +from paddle.regularizer import L2Decay + +from .fcos_head import ScaleReg +from ..initializer import bias_init_with_prob, constant_, normal_ +from ..ops import get_act_fn, anchor_generator +from ..rbox_utils import box2corners +from ..losses import ProbIoULoss +import numpy as np + +__all__ = ['FCOSRHead'] + + +def trunc_div(a, b): + ipt = paddle.divide(a, b) + sign_ipt = paddle.sign(ipt) + abs_ipt = paddle.abs(ipt) + abs_ipt = paddle.floor(abs_ipt) + out = paddle.multiply(sign_ipt, abs_ipt) + return out + + +def fmod(a, b): + return a - trunc_div(a, b) * b + + +def fmod_eval(a, b): + return a - a.divide(b).cast(paddle.int32).cast(paddle.float32) * b + + +class ConvBNLayer(nn.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size=3, + stride=1, + groups=1, + padding=0, + norm_cfg={'name': 'gn', + 'num_groups': 32}, + act=None): + super(ConvBNLayer, self).__init__() + + self.conv = nn.Conv2D( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + bias_attr=False) + + norm_type = norm_cfg['name'] + if norm_type in ['sync_bn', 'bn']: + self.norm = nn.BatchNorm2D( + ch_out, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + else: + groups = norm_cfg.get('num_groups', 1) + self.norm = nn.GroupNorm( + num_groups=groups, + num_channels=ch_out, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self.act = get_act_fn(act) if act is None or isinstance(act, ( + str, dict)) else act + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + x = self.act(x) + + return x + + +@register +class FCOSRHead(nn.Layer): + """ FCOSR Head, refer to https://arxiv.org/abs/2111.10780 for details """ + + __shared__ = ['num_classes', 'trt'] + __inject__ = ['assigner', 'nms'] + + def __init__(self, + num_classes=15, + in_channels=256, + feat_channels=256, + stacked_convs=4, + act='relu', + fpn_strides=[4, 8, 16, 32, 64], + trt=False, + loss_weight={'class': 1.0, + 'probiou': 1.0}, + norm_cfg={'name': 'gn', + 'num_groups': 32}, + assigner='FCOSRAssigner', + nms='MultiClassNMS'): + + super(FCOSRHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + self.fpn_strides = fpn_strides + self.stacked_convs = stacked_convs + self.loss_weight = loss_weight + self.half_pi = paddle.to_tensor( + [1.5707963267948966], dtype=paddle.float32) + self.probiou_loss = ProbIoULoss(mode='l1') + act = get_act_fn( + act, trt=trt) if act is None or isinstance(act, + (str, dict)) else act + self.trt = trt + self.loss_weight = loss_weight + self.assigner = assigner + self.nms = nms + # stem + self.stem_cls = nn.LayerList() + self.stem_reg = nn.LayerList() + for i in range(self.stacked_convs): + self.stem_cls.append( + ConvBNLayer( + self.in_channels[i], + feat_channels, + filter_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act=act)) + self.stem_reg.append( + ConvBNLayer( + self.in_channels[i], + feat_channels, + filter_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act=act)) + + self.scales = nn.LayerList( + [ScaleReg() for _ in range(len(fpn_strides))]) + + # prediction + self.pred_cls = nn.Conv2D(feat_channels, self.num_classes, 3, padding=1) + + self.pred_xy = nn.Conv2D(feat_channels, 2, 3, padding=1) + + self.pred_wh = nn.Conv2D(feat_channels, 2, 3, 
padding=1) + + self.pred_angle = nn.Conv2D(feat_channels, 1, 3, padding=1) + + self._init_weights() + + def _init_weights(self): + for cls_, reg_ in zip(self.stem_cls, self.stem_reg): + normal_(cls_.conv.weight, std=0.01) + normal_(reg_.conv.weight, std=0.01) + + bias_cls = bias_init_with_prob(0.01) + normal_(self.pred_cls.weight, std=0.01) + constant_(self.pred_cls.bias, bias_cls) + normal_(self.pred_xy.weight, std=0.01) + normal_(self.pred_wh.weight, std=0.01) + normal_(self.pred_angle.weight, std=0.01) + + @classmethod + def from_config(cls, cfg, input_shape): + return {'in_channels': [i.channels for i in input_shape], } + + def _generate_anchors(self, feats): + if self.trt: + anchor_points = [] + for feat, stride in zip(feats, self.fpn_strides): + _, _, h, w = paddle.shape(feat) + anchor, _ = anchor_generator( + feat, + stride * 4, + 1.0, [1.0, 1.0, 1.0, 1.0], [stride, stride], + offset=0.5) + x1, y1, x2, y2 = paddle.split(anchor, 4, axis=-1) + xc = (x1 + x2 + 1) / 2 + yc = (y1 + y2 + 1) / 2 + anchor_point = paddle.concat( + [xc, yc], axis=-1).reshape((1, h * w, 2)) + anchor_points.append(anchor_point) + anchor_points = paddle.concat(anchor_points, axis=1) + return anchor_points, None, None + else: + anchor_points = [] + stride_tensor = [] + num_anchors_list = [] + for i, stride in enumerate(self.fpn_strides): + _, _, h, w = feats[i].shape + shift_x = (paddle.arange(end=w) + 0.5) * stride + shift_y = (paddle.arange(end=h) + 0.5) * stride + shift_y, shift_x = paddle.meshgrid(shift_y, shift_x) + anchor_point = paddle.cast( + paddle.stack( + [shift_x, shift_y], axis=-1), dtype='float32') + anchor_points.append(anchor_point.reshape([1, -1, 2])) + stride_tensor.append( + paddle.full( + [1, h * w, 1], stride, dtype='float32')) + num_anchors_list.append(h * w) + anchor_points = paddle.concat(anchor_points, axis=1) + stride_tensor = paddle.concat(stride_tensor, axis=1) + return anchor_points, stride_tensor, num_anchors_list + + def forward(self, feats, target=None): + if self.training: + return self.forward_train(feats, target) + else: + return self.forward_eval(feats, target) + + def forward_train(self, feats, target=None): + anchor_points, stride_tensor, num_anchors_list = self._generate_anchors( + feats) + cls_pred_list, reg_pred_list = [], [] + for stride, feat, scale in zip(self.fpn_strides, feats, self.scales): + # cls + cls_feat = feat + for cls_layer in self.stem_cls: + cls_feat = cls_layer(cls_feat) + cls_pred = F.sigmoid(self.pred_cls(cls_feat)) + cls_pred_list.append(cls_pred.flatten(2).transpose((0, 2, 1))) + # reg + reg_feat = feat + for reg_layer in self.stem_reg: + reg_feat = reg_layer(reg_feat) + + reg_xy = scale(self.pred_xy(reg_feat)) * stride + reg_wh = F.elu(scale(self.pred_wh(reg_feat)) + 1.) 
* stride + reg_angle = self.pred_angle(reg_feat) + reg_angle = fmod(reg_angle, self.half_pi) + reg_pred = paddle.concat([reg_xy, reg_wh, reg_angle], axis=1) + reg_pred_list.append(reg_pred.flatten(2).transpose((0, 2, 1))) + + cls_pred_list = paddle.concat(cls_pred_list, axis=1) + reg_pred_list = paddle.concat(reg_pred_list, axis=1) + + return self.get_loss([ + cls_pred_list, reg_pred_list, anchor_points, stride_tensor, + num_anchors_list + ], target) + + def forward_eval(self, feats, target=None): + cls_pred_list, reg_pred_list = [], [] + anchor_points, _, _ = self._generate_anchors(feats) + for stride, feat, scale in zip(self.fpn_strides, feats, self.scales): + b, _, h, w = paddle.shape(feat) + # cls + cls_feat = feat + for cls_layer in self.stem_cls: + cls_feat = cls_layer(cls_feat) + cls_pred = F.sigmoid(self.pred_cls(cls_feat)) + cls_pred_list.append(cls_pred.reshape([b, self.num_classes, h * w])) + # reg + reg_feat = feat + for reg_layer in self.stem_reg: + reg_feat = reg_layer(reg_feat) + + reg_xy = scale(self.pred_xy(reg_feat)) * stride + reg_wh = F.elu(scale(self.pred_wh(reg_feat)) + 1.) * stride + reg_angle = self.pred_angle(reg_feat) + reg_angle = fmod_eval(reg_angle, self.half_pi) + reg_pred = paddle.concat([reg_xy, reg_wh, reg_angle], axis=1) + reg_pred = reg_pred.reshape([b, 5, h * w]).transpose((0, 2, 1)) + reg_pred_list.append(reg_pred) + + cls_pred_list = paddle.concat(cls_pred_list, axis=2) + reg_pred_list = paddle.concat(reg_pred_list, axis=1) + reg_pred_list = self._bbox_decode(anchor_points, reg_pred_list) + return cls_pred_list, reg_pred_list + + def _bbox_decode(self, points, reg_pred_list): + xy, wha = paddle.split(reg_pred_list, [2, 3], axis=-1) + xy = xy + points + return paddle.concat([xy, wha], axis=-1) + + def _box2corners(self, pred_bboxes): + """ convert (x, y, w, h, angle) to (x1, y1, x2, y2, x3, y3, x4, y4) + + Args: + pred_bboxes (Tensor): [B, N, 5] + + Returns: + polys (Tensor): [B, N, 8] + """ + x, y, w, h, angle = paddle.split(pred_bboxes, 5, axis=-1) + cos_a_half = paddle.cos(angle) * 0.5 + sin_a_half = paddle.sin(angle) * 0.5 + w_x = cos_a_half * w + w_y = sin_a_half * w + h_x = -sin_a_half * h + h_y = cos_a_half * h + return paddle.concat( + [ + x + w_x + h_x, y + w_y + h_y, x - w_x + h_x, y - w_y + h_y, + x - w_x - h_x, y - w_y - h_y, x + w_x - h_x, y + w_y - h_y + ], + axis=-1) + + def get_loss(self, head_outs, gt_meta): + cls_pred_list, reg_pred_list, anchor_points, stride_tensor, num_anchors_list = head_outs + gt_labels = gt_meta['gt_class'] + gt_bboxes = gt_meta['gt_bbox'] + gt_rboxes = gt_meta['gt_rbox'] + pad_gt_mask = gt_meta['pad_gt_mask'] + # decode + pred_rboxes = self._bbox_decode(anchor_points, reg_pred_list) + # label assignment + assigned_labels, assigned_rboxes, assigned_scores = \ + self.assigner( + anchor_points, + stride_tensor, + num_anchors_list, + gt_labels, + gt_bboxes, + gt_rboxes, + pad_gt_mask, + self.num_classes, + pred_rboxes + ) + + # reg_loss + mask_positive = (assigned_labels != self.num_classes) + num_pos = mask_positive.sum().item() + if num_pos > 0: + bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 5]) + pred_rboxes_pos = paddle.masked_select(pred_rboxes, + bbox_mask).reshape([-1, 5]) + assigned_rboxes_pos = paddle.masked_select( + assigned_rboxes, bbox_mask).reshape([-1, 5]) + bbox_weight = paddle.masked_select( + assigned_scores.sum(-1), mask_positive).reshape([-1]) + avg_factor = bbox_weight.sum() + loss_probiou = self.probiou_loss(pred_rboxes_pos, + assigned_rboxes_pos) + loss_probiou = paddle.sum(loss_probiou 
* bbox_weight) / avg_factor + else: + loss_probiou = pred_rboxes.sum() * 0. + + avg_factor = max(num_pos, 1.0) + # cls_loss + loss_cls = self._qfocal_loss( + cls_pred_list, assigned_scores, reduction='sum') + loss_cls = loss_cls / avg_factor + + loss = self.loss_weight['class'] * loss_cls + \ + self.loss_weight['probiou'] * loss_probiou + out_dict = { + 'loss': loss, + 'loss_probiou': loss_probiou, + 'loss_cls': loss_cls + } + return out_dict + + @staticmethod + def _qfocal_loss(score, label, gamma=2.0, reduction='sum'): + weight = (score - label).pow(gamma) + loss = F.binary_cross_entropy( + score, label, weight=weight, reduction=reduction) + return loss + + def post_process(self, head_outs, scale_factor): + pred_scores, pred_rboxes = head_outs + # [B, N, 5] -> [B, N, 4, 2] -> [B, N, 8] + pred_rboxes = self._box2corners(pred_rboxes) + # scale bbox to origin + scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1) + scale_factor = paddle.concat( + [ + scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x, + scale_y + ], + axis=-1).reshape([-1, 1, 8]) + pred_rboxes /= scale_factor + bbox_pred, bbox_num, _ = self.nms(pred_rboxes, pred_scores) + return bbox_pred, bbox_num diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py index 0e6ebe9069ea0671bce74ea4496863f0cb052803..0c946d92827f61a9ac132587413ac6c8fc0df89e 100644 --- a/ppdet/modeling/losses/__init__.py +++ b/ppdet/modeling/losses/__init__.py @@ -27,6 +27,7 @@ from . import detr_loss from . import sparsercnn_loss from . import focal_loss from . import smooth_l1_loss +from . import probiou_loss from .yolo_loss import * from .iou_aware_loss import * @@ -44,3 +45,4 @@ from .sparsercnn_loss import * from .focal_loss import * from .smooth_l1_loss import * from .pose3d_loss import * +from .probiou_loss import * diff --git a/ppdet/modeling/losses/probiou_loss.py b/ppdet/modeling/losses/probiou_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a1c75879ee02ea8cd721272ad1f12fb3b96a67 --- /dev/null +++ b/ppdet/modeling/losses/probiou_loss.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import paddle +import paddle.nn.functional as F + +from ppdet.core.workspace import register, serializable + +__all__ = ['ProbIoULoss'] + + +def gbb_form(boxes): + xy, wh, angle = paddle.split(boxes, [2, 2, 1], axis=-1) + return paddle.concat([xy, wh.pow(2) / 12., angle], axis=-1) + + +def rotated_form(a_, b_, angles): + cos_a = paddle.cos(angles) + sin_a = paddle.sin(angles) + a = a_ * paddle.pow(cos_a, 2) + b_ * paddle.pow(sin_a, 2) + b = a_ * paddle.pow(sin_a, 2) + b_ * paddle.pow(cos_a, 2) + c = (a_ - b_) * cos_a * sin_a + return a, b, c + + +def probiou_loss(pred, target, eps=1e-3, mode='l1'): + """ + pred -> a matrix [N,5](x,y,w,h,angle - in radians) containing ours predicted box ;in case of HBB angle == 0 + target -> a matrix [N,5](x,y,w,h,angle - in radians) containing ours target box ;in case of HBB angle == 0 + eps -> threshold to avoid infinite values + mode -> ('l1' in [0,1] or 'l2' in [0,inf]) metrics according our paper + + """ + + gbboxes1 = gbb_form(pred) + gbboxes2 = gbb_form(target) + + x1, y1, a1_, b1_, c1_ = gbboxes1[:, + 0], gbboxes1[:, + 1], gbboxes1[:, + 2], gbboxes1[:, + 3], gbboxes1[:, + 4] + x2, y2, a2_, b2_, c2_ = gbboxes2[:, + 0], gbboxes2[:, + 1], gbboxes2[:, + 2], gbboxes2[:, + 3], gbboxes2[:, + 4] + + a1, b1, c1 = rotated_form(a1_, b1_, c1_) + a2, b2, c2 = rotated_form(a2_, b2_, c2_) + + t1 = 0.25 * ((a1 + a2) * (paddle.pow(y1 - y2, 2)) + (b1 + b2) * (paddle.pow(x1 - x2, 2))) + \ + 0.5 * ((c1+c2)*(x2-x1)*(y1-y2)) + t2 = (a1 + a2) * (b1 + b2) - paddle.pow(c1 + c2, 2) + t3_ = (a1 * b1 - c1 * c1) * (a2 * b2 - c2 * c2) + t3 = 0.5 * paddle.log(t2 / (4 * paddle.sqrt(F.relu(t3_)) + eps)) + + B_d = (t1 / t2) + t3 + # B_d = t1 + t2 + t3 + + B_d = paddle.clip(B_d, min=eps, max=100.0) + l1 = paddle.sqrt(1.0 - paddle.exp(-B_d) + eps) + l_i = paddle.pow(l1, 2.0) + l2 = -paddle.log(1.0 - l_i + eps) + + if mode == 'l1': + probiou = l1 + if mode == 'l2': + probiou = l2 + + return probiou + + +@serializable +@register +class ProbIoULoss(object): + """ ProbIoU Loss, refer to https://arxiv.org/abs/2106.06072 for details """ + + def __init__(self, mode='l1', eps=1e-3): + super(ProbIoULoss, self).__init__() + self.mode = mode + self.eps = eps + + def __call__(self, pred_rboxes, assigned_rboxes): + return probiou_loss(pred_rboxes, assigned_rboxes, self.eps, self.mode) diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index 025b60050237824673b6070a3d910a1895904270..fb9d98cf0f35458eb2af063487b7664a3fd8c2cc 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -37,6 +37,7 @@ __all__ = [ 'silu', 'swish', 'identity', + 'anchor_generator' ] @@ -117,6 +118,101 @@ def batch_norm(ch, return norm_layer +@paddle.jit.not_to_static +def anchor_generator(input, + anchor_sizes=None, + aspect_ratios=None, + variance=[0.1, 0.1, 0.2, 0.2], + stride=None, + offset=0.5): + """ + **Anchor generator operator** + Generate anchors for Faster RCNN algorithm. + Each position of the input produce N anchors, N = + size(anchor_sizes) * size(aspect_ratios). The order of generated anchors + is firstly aspect_ratios loop then anchor_sizes loop. + Args: + input(Variable): 4-D Tensor with shape [N,C,H,W]. The input feature map. + anchor_sizes(float32|list|tuple, optional): The anchor sizes of generated + anchors, given in absolute pixels e.g. [64., 128., 256., 512.]. + For instance, the anchor size of 64 means the area of this anchor + equals to 64**2. 
None by default. + aspect_ratios(float32|list|tuple, optional): The height / width ratios + of generated anchors, e.g. [0.5, 1.0, 2.0]. None by default. + variance(list|tuple, optional): The variances to be used in box + regression deltas. The data type is float32, [0.1, 0.1, 0.2, 0.2] by + default. + stride(list|tuple, optional): The anchors stride across width and height. + The data type is float32. e.g. [16.0, 16.0]. None by default. + offset(float32, optional): Prior boxes center offset. 0.5 by default. + Returns: + Tuple: + Anchors(Variable): The output anchors with a layout of [H, W, num_anchors, 4]. + H is the height of input, W is the width of input, + num_anchors is the box count of each position. + Each anchor is in (xmin, ymin, xmax, ymax) format an unnormalized. + + Variances(Variable): The expanded variances of anchors + with a layout of [H, W, num_priors, 4]. + H is the height of input, W is the width of input + num_anchors is the box count of each position. + Each variance is in (xcenter, ycenter, w, h) format. + Examples: + .. code-block:: python + import paddle.fluid as fluid + conv1 = fluid.data(name='conv1', shape=[None, 48, 16, 16], dtype='float32') + anchor, var = fluid.layers.anchor_generator( + input=conv1, + anchor_sizes=[64, 128, 256, 512], + aspect_ratios=[0.5, 1.0, 2.0], + variance=[0.1, 0.1, 0.2, 0.2], + stride=[16.0, 16.0], + offset=0.5) + """ + + def _is_list_or_tuple_(data): + return (isinstance(data, list) or isinstance(data, tuple)) + + if not _is_list_or_tuple_(anchor_sizes): + anchor_sizes = [anchor_sizes] + if not _is_list_or_tuple_(aspect_ratios): + aspect_ratios = [aspect_ratios] + if not (_is_list_or_tuple_(stride) and len(stride) == 2): + raise ValueError('stride should be a list or tuple ', + 'with length 2, (stride_width, stride_height).') + + anchor_sizes = list(map(float, anchor_sizes)) + aspect_ratios = list(map(float, aspect_ratios)) + stride = list(map(float, stride)) + + if in_dynamic_mode(): + attrs = ('anchor_sizes', anchor_sizes, 'aspect_ratios', aspect_ratios, + 'variances', variance, 'stride', stride, 'offset', offset) + anchor, var = C_ops.anchor_generator(input, *attrs) + return anchor, var + + helper = LayerHelper("anchor_generator", **locals()) + dtype = helper.input_dtype() + attrs = { + 'anchor_sizes': anchor_sizes, + 'aspect_ratios': aspect_ratios, + 'variances': variance, + 'stride': stride, + 'offset': offset + } + + anchor = helper.create_variable_for_type_inference(dtype) + var = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="anchor_generator", + inputs={"Input": input}, + outputs={"Anchors": anchor, + "Variances": var}, + attrs=attrs, ) + anchor.stop_gradient = True + var.stop_gradient = True + return anchor, var + @paddle.jit.not_to_static def distribute_fpn_proposals(fpn_rois, diff --git a/ppdet/modeling/rbox_utils.py b/ppdet/modeling/rbox_utils.py index 19bca8d8a124b66fde39959ff36db4ff24680a58..bde5320cb74ed85451b17a84016f314ac07398a7 100644 --- a/ppdet/modeling/rbox_utils.py +++ b/ppdet/modeling/rbox_utils.py @@ -157,3 +157,85 @@ def rbox2poly_np(rboxes): polys.append(poly) polys = np.array(polys) return polys + + +# rbox function implemented using paddle +def box2corners(box): + """convert box coordinate to corners + Args: + box (Tensor): (B, N, 5) with (x, y, w, h, alpha) angle is in [0, 90) + Returns: + corners (Tensor): (B, N, 4, 2) with (x1, y1, x2, y2, x3, y3, x4, y4) + """ + B = box.shape[0] + x, y, w, h, alpha = paddle.split(box, 5, axis=-1) + x4 = paddle.to_tensor( + [0.5, 0.5, -0.5, 
-0.5], dtype=paddle.float32).reshape(
+            (1, 1, 4))  # (1,1,4)
+    x4 = x4 * w  # (B, N, 4)
+    y4 = paddle.to_tensor(
+        [-0.5, 0.5, 0.5, -0.5], dtype=paddle.float32).reshape((1, 1, 4))
+    y4 = y4 * h  # (B, N, 4)
+    corners = paddle.stack([x4, y4], axis=-1)  # (B, N, 4, 2)
+    sin = paddle.sin(alpha)
+    cos = paddle.cos(alpha)
+    row1 = paddle.concat([cos, sin], axis=-1)
+    row2 = paddle.concat([-sin, cos], axis=-1)  # (B, N, 2)
+    rot_T = paddle.stack([row1, row2], axis=-2)  # (B, N, 2, 2)
+    rotated = paddle.bmm(corners.reshape([-1, 4, 2]), rot_T.reshape([-1, 2, 2]))
+    rotated = rotated.reshape([B, -1, 4, 2])  # (B*N, 4, 2) -> (B, N, 4, 2)
+    rotated[..., 0] += x
+    rotated[..., 1] += y
+    return rotated
+
+
+def paddle_gather(x, dim, index):
+    index_shape = index.shape
+    index_flatten = index.flatten()
+    if dim < 0:
+        dim = len(x.shape) + dim
+    nd_index = []
+    for k in range(len(x.shape)):
+        if k == dim:
+            nd_index.append(index_flatten)
+        else:
+            reshape_shape = [1] * len(x.shape)
+            reshape_shape[k] = x.shape[k]
+            x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
+            x_arange = x_arange.reshape(reshape_shape)
+            dim_index = paddle.expand(x_arange, index_shape).flatten()
+            nd_index.append(dim_index)
+    ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
+    paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
+    return paddle_out
+
+
+def check_points_in_polys(points, polys):
+    """Check whether points are inside rotated boxes given as polygons
+    Args:
+        points (tensor): (1, L, 2) anchor points
+        polys (tensor): [B, N, 4, 2] gt_polys
+    Returns:
+        is_in_polys (tensor): (B, N, L)
+    """
+    # [1, L, 2] -> [1, 1, L, 2]
+    points = points.unsqueeze(0)
+    # [B, N, 4, 2] -> 4 * [B, N, 1, 2]
+    a, b, c, d = polys.split(4, axis=2)
+    ab = b - a
+    ad = d - a
+    # [B, N, L, 2]
+    ap = points - a
+    # [B, N, 1]
+    norm_ab = paddle.sum(ab * ab, axis=-1)
+    # [B, N, 1]
+    norm_ad = paddle.sum(ad * ad, axis=-1)
+    # [B, N, L] dot product
+    ap_dot_ab = paddle.sum(ap * ab, axis=-1)
+    # [B, N, L] dot product
+    ap_dot_ad = paddle.sum(ap * ad, axis=-1)
+    # a point lies inside the box iff its projections onto ab and ad
+    # both fall within the corresponding edge lengths
+    is_in_polys = (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (
+        ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)
+    return is_in_polys
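As a quick usage sketch of the two helpers added to `ppdet/modeling/rbox_utils.py` above (not part of the patch; runnable once the patch is applied, with arbitrary example values):

``` python
import paddle
from ppdet.modeling.rbox_utils import box2corners, check_points_in_polys

# one 4x2 box centered at (5, 5), rotated by pi/4: shape [B=1, N=1, 5]
rboxes = paddle.to_tensor([[[5., 5., 4., 2., 0.7853981634]]])
polys = box2corners(rboxes)  # [1, 1, 4, 2] corner coordinates
# two query points: the box center (inside) and a far point (outside)
points = paddle.to_tensor([[[5., 5.], [9., 1.]]])  # [1, L=2, 2]
print(check_points_in_polys(points, polys))  # [[[True, False]]]
```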