diff --git a/configs/faster_rcnn/README.md b/configs/faster_rcnn/README.md index 644205a73d47e9eb3f843db7e333c2e7320d95c6..33a34d4de72ce5221624aaab48fd6b4f47124d3b 100644 --- a/configs/faster_rcnn/README.md +++ b/configs/faster_rcnn/README.md @@ -8,6 +8,7 @@ | ResNet50-vd | Faster | 1 | 1x | ---- | 37.6 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/faster_rcnn/faster_rcnn_r50_vd_1x_coco.yml) | | ResNet101 | Faster | 1 | 1x | ---- | 39.0 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_r101_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/faster_rcnn/faster_rcnn_r101_1x_coco.yml) | | ResNet34-FPN | Faster | 1 | 1x | ---- | 37.8 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_r34_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/faster_rcnn/faster_rcnn_r34_fpn_1x_coco.yml) | +| ResNet34-FPN-MultiScaleTest | Faster | 1 | 1x | ---- | 38.2 | - | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/faster_rcnn/faster_rcnn_r34_fpn_multiscaletest_1x_coco.yml) | | ResNet34-vd-FPN | Faster | 1 | 1x | ---- | 38.5 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_r34_vd_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/faster_rcnn/faster_rcnn_r34_vd_fpn_1x_coco.yml) | | ResNet50-FPN | Faster | 1 | 1x | ---- | 38.4 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.yml) | | ResNet50-FPN | Faster | 1 | 2x | ---- | 40.0 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_fpn_2x_coco.pdparams) | 
[配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/faster_rcnn/faster_rcnn_r50_fpn_2x_coco.yml) | diff --git a/configs/faster_rcnn/faster_rcnn_r34_fpn_multiscaletest_1x_coco.yml b/configs/faster_rcnn/faster_rcnn_r34_fpn_multiscaletest_1x_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..559d5f1fe9fdcbf42189383a69f9d1a056792cda --- /dev/null +++ b/configs/faster_rcnn/faster_rcnn_r34_fpn_multiscaletest_1x_coco.yml @@ -0,0 +1,22 @@ +_BASE_: [ + 'faster_rcnn_r34_fpn_1x_coco.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet34_pretrained.pdparams +weights: output/faster_rcnn_r34_fpn_multiscaletest_1x_coco/model_final + +EvalReader: + sample_transforms: + - Decode: {} +# - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900], use_flip: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + +TestReader: + sample_transforms: + - Decode: {} +# - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900], use_flip: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} \ No newline at end of file diff --git a/docs/tutorials/config_annotation/multi_scale_test_config.md b/docs/tutorials/config_annotation/multi_scale_test_config.md new file mode 100644 index 0000000000000000000000000000000000000000..1b6b6bb1fd4d08e696ad8d0d729e18207f3220d8 --- /dev/null +++ b/docs/tutorials/config_annotation/multi_scale_test_config.md @@ -0,0 +1,45 @@ +# Multi Scale Test Configuration + +Tags: Configuration + +--- +```yaml + +##################################### Multi scale test configuration ##################################### + +EvalReader: + sample_transforms: + - Decode: {} + - MultiscaleTestResize: 
{origin_target_size: [800, 1333], target_size: [700 , 900]} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + +TestReader: + sample_transforms: + - Decode: {} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900]} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} +``` + +--- + +Multi Scale Test is a TTA (Test Time Augmentation) method that can improve object detection performance. + +The input image is first resized to several different scales, the model then generates predictions (bboxes) at each scale, and finally all the predictions are combined into the final prediction. (Here **NMS** is used to aggregate the predictions.) + +## _MultiscaleTestResize_ option + +The `MultiscaleTestResize` option is used to enable multi scale test prediction. + +`origin_target_size: [800, 1333]` means the input image is first scaled so that its short edge is 800, while its longest edge cannot be greater than 1333. + +The `target_size: [700 , 900]` property is used to specify the different scales. + +It can be plugged into the evaluation process or the test (inference) process by adding a `MultiscaleTestResize` entry to `EvalReader.sample_transforms` or `TestReader.sample_transforms`. + +--- + +### Note + +Currently only CascadeRCNN, FasterRCNN and MaskRCNN are supported for multi scale testing, and the batch size must be 1. 
\ No newline at end of file diff --git a/docs/tutorials/config_annotation/multi_scale_test_config_cn.md b/docs/tutorials/config_annotation/multi_scale_test_config_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..36851def51e7ae3a414b78df656100b5072685c0 --- /dev/null +++ b/docs/tutorials/config_annotation/multi_scale_test_config_cn.md @@ -0,0 +1,45 @@ +# 多尺度测试的配置 + +标签: 配置 + +--- +```yaml + +##################################### 多尺度测试的配置 ##################################### + +EvalReader: + sample_transforms: + - Decode: {} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900]} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + +TestReader: + sample_transforms: + - Decode: {} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900]} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} +``` + +--- + +多尺度测试是一种TTA方法(测试时增强),可以用于提高目标检测的准确率。 + +输入图像首先被缩放为不同尺度的图像,然后模型对这些不同尺度的图像进行预测,最后将这些不同尺度上的预测结果整合为最终预测结果。(这里使用了**NMS**来整合不同尺度的预测结果) + +## _MultiscaleTestResize_ 选项 + +`MultiscaleTestResize` 选项用于开启多尺度测试。 + +`origin_target_size: [800, 1333]` 项代表输入图像首先缩放为短边为800,最长边不超过1333。 + +`target_size: [700 , 900]` 项设置不同的预测尺度。 + +通过在`EvalReader.sample_transforms`或`TestReader.sample_transforms`中设置`MultiscaleTestResize`项,可以在评估过程或预测过程中开启多尺度测试。 + +--- + +### 注意 + +目前多尺度测试只支持CascadeRCNN、FasterRCNN和MaskRCNN网络,并且batch size需要是1。 
\ No newline at end of file diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index e43fb7d20d00050bebf8462cafb2125296104a40..c731a5fb3457196f78f17ca64ac730baf7ebf366 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -16,6 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import typing + try: from collections.abc import Sequence except Exception: @@ -69,7 +71,13 @@ class PadBatch(BaseOperator): """ coarsest_stride = self.pad_to_stride - max_shape = np.array([data['image'].shape for data in samples]).max( + # multi scale input is nested list + if isinstance(samples, typing.Sequence) and len(samples) > 0 and isinstance(samples[0], typing.Sequence): + inner_samples = samples[0] + else: + inner_samples = samples + + max_shape = np.array([data['image'].shape for data in inner_samples]).max( axis=0) if coarsest_stride > 0: max_shape[1] = int( @@ -77,7 +85,7 @@ class PadBatch(BaseOperator): max_shape[2] = int( np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride) - for data in samples: + for data in inner_samples: im = data['image'] im_c, im_h, im_w = im.shape[:] padding_im = np.zeros( diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index dc739ff6217966c4be4f4dbb18dd0697b5ca538b..455e74474d40daebf05d59ed419ba536a77a14d5 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -22,6 +22,7 @@ import copy import time import numpy as np +import typing from PIL import Image, ImageOps import paddle @@ -473,7 +474,11 @@ class Trainer(object): for metric in self._metrics: metric.update(data, outs) - sample_num += data['im_id'].numpy().shape[0] + # multi-scale inputs: all inputs have same im_id + if isinstance(data, typing.Sequence): + sample_num += data[0]['im_id'].numpy().shape[0] + else: + sample_num += data['im_id'].numpy().shape[0] self._compose_callback.on_step_end(self.status) 
self.status['sample_num'] = sample_num @@ -517,7 +522,10 @@ class Trainer(object): outs = self.model(data) for key in ['im_shape', 'scale_factor', 'im_id']: - outs[key] = data[key] + if isinstance(data, typing.Sequence): + outs[key] = data[0][key] + else: + outs[key] = data[key] for key, value in outs.items(): if hasattr(value, 'numpy'): outs[key] = value.numpy() diff --git a/ppdet/metrics/metrics.py b/ppdet/metrics/metrics.py index f9913b7fb81713d87e26e176f80b7d26848808d5..3925267d7f9e5656033e4c851d8b52f1031867ab 100644 --- a/ppdet/metrics/metrics.py +++ b/ppdet/metrics/metrics.py @@ -21,6 +21,7 @@ import sys import json import paddle import numpy as np +import typing from .map_utils import prune_zero_padding, DetectionMAP from .coco_utils import get_infer_results, cocoapi_eval @@ -98,7 +99,11 @@ class COCOMetric(Metric): for k, v in outputs.items(): outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v - im_id = inputs['im_id'] + # multi-scale inputs: all inputs have same im_id + if isinstance(inputs, typing.Sequence): + im_id = inputs[0]['im_id'] + else: + im_id = inputs['im_id'] outs['im_id'] = im_id.numpy() if isinstance(im_id, paddle.Tensor) else im_id diff --git a/ppdet/modeling/architectures/meta_arch.py b/ppdet/modeling/architectures/meta_arch.py index d9875e18395f8ec5724d9887d0becf899b0ee1b7..d01c34735653050c7b78799f06cbf72b85979a61 100644 --- a/ppdet/modeling/architectures/meta_arch.py +++ b/ppdet/modeling/architectures/meta_arch.py @@ -2,9 +2,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np import paddle import paddle.nn as nn +import typing + from ppdet.core.workspace import register +from ppdet.modeling.post_process import nms __all__ = ['BaseArch'] @@ -53,7 +57,53 @@ class BaseArch(nn.Layer): if self.training: out = self.get_loss() else: - out = self.get_pred() + inputs_list = [] + # multi-scale input + if not isinstance(inputs, typing.Sequence): + 
def nms(dets, thresh):
    """Apply classic DPM-style greedy NMS to a set of scored boxes.

    Boxes are visited from the highest score to the lowest. Every box that
    has not been suppressed yet is kept, and it suppresses each
    lower-scoring box whose IoU with it is ``>= thresh``.

    Args:
        dets (np.ndarray): 2-D array with one detection per row, laid out
            as ``[score, x1, y1, x2, y2]``.
        thresh (float): IoU threshold at or above which a lower-scoring
            box is suppressed.

    Returns:
        np.ndarray: The surviving rows of ``dets``, in their original row
        order (they are NOT re-sorted by score). An empty input yields an
        empty array with the same number of columns.
    """
    num_dets = dets.shape[0]
    if num_dets == 0:
        return dets[[], :]

    scores = dets[:, 0]
    x1 = dets[:, 1]
    y1 = dets[:, 2]
    x2 = dets[:, 3]
    y2 = dets[:, 4]

    # Widths and heights use the "+ 1" convention throughout this function.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Indices of the detections sorted by descending score.
    order = scores.argsort()[::-1]

    # A boolean mask replaces the former ``dtype=np.int`` buffer: the
    # ``np.int`` alias was removed in NumPy 1.24 and raises AttributeError.
    suppressed = np.zeros(num_dets, dtype=bool)

    for rank in range(num_dets):
        i = order[rank]
        if suppressed[i]:
            continue

        # All lower-scoring candidates are compared against box i at once.
        # Re-marking an already suppressed box is harmless, so they do not
        # need to be filtered out first.
        rest = order[rank + 1:]
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)
        suppressed[rest[ovr >= thresh]] = True

    return dets[~suppressed, :]
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest +from ppdet.core.workspace import load_config +from ppdet.engine import Trainer + +class TestMultiScaleInference(unittest.TestCase): + def setUp(self): + self.set_config() + + def set_config(self): + self.mstest_cfg_file = 'configs/faster_rcnn/faster_rcnn_r34_fpn_multiscaletest_1x_coco.yml' + + # test evaluation with multi scale test + def test_eval_mstest(self): + cfg = load_config(self.mstest_cfg_file) + trainer = Trainer(cfg, mode='eval') + + cfg.weights = 'https://paddledet.bj.bcebos.com/models/faster_rcnn_r34_fpn_1x_coco.pdparams' + trainer.load_weights(cfg.weights) + + trainer.evaluate() + + # test inference with multi scale test + def test_infer_mstest(self): + cfg = load_config(self.mstest_cfg_file) + trainer = Trainer(cfg, mode='test') + + cfg.weights = 'https://paddledet.bj.bcebos.com/models/faster_rcnn_r34_fpn_1x_coco.pdparams' + trainer.load_weights(cfg.weights) + tests_img_root = os.path.join(os.path.dirname(__file__), 'imgs') + + # input images to predict + imgs = ['coco2017_val2017_000000000139.jpg', 'coco2017_val2017_000000000724.jpg'] + imgs = [os.path.join(tests_img_root, img) for img in imgs] + trainer.predict(imgs, + draw_threshold=0.5, + output_dir='output', + save_txt=True) + + +if __name__ == '__main__': + unittest.main()