diff --git a/configs/faster_rcnn/faster_rcnn_r34_fpn_multiscaletest_1x_coco.yml b/configs/faster_rcnn/faster_rcnn_r34_fpn_multiscaletest_1x_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..559d5f1fe9fdcbf42189383a69f9d1a056792cda --- /dev/null +++ b/configs/faster_rcnn/faster_rcnn_r34_fpn_multiscaletest_1x_coco.yml @@ -0,0 +1,22 @@ +_BASE_: [ + 'faster_rcnn_r34_fpn_1x_coco.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet34_pretrained.pdparams +weights: output/faster_rcnn_r34_fpn_multiscaletest_1x_coco/model_final + +EvalReader: + sample_transforms: + - Decode: {} +# - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900], use_flip: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + +TestReader: + sample_transforms: + - Decode: {} +# - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900], use_flip: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} \ No newline at end of file diff --git a/docs/tutorials/config_annotation/multi_scale_test_config.md b/docs/tutorials/config_annotation/multi_scale_test_config.md new file mode 100644 index 0000000000000000000000000000000000000000..5d553ca0da92aa404fab72c889699a49a7583bdd --- /dev/null +++ b/docs/tutorials/config_annotation/multi_scale_test_config.md @@ -0,0 +1,45 @@ +# Multi Scale Test Configuration + +Tags: Configuration + +--- +```yaml + +##################################### Multi scale test configuration ##################################### + +EvalReader: + sample_transforms: + - Decode: {} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900]} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + +TestReader: + sample_transforms: + - Decode: {} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900]} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} +``` + +--- + +Multi Scale Test is a TTA (Test Time Augmentation) method, it can improve object detection performance. + +The input image will be scaled into different scales, then model generated predictions (bboxes) at different scales, finally all the predictions will be combined to generate final prediction. (Here **NMS** is used to aggregate the predictions.) + +## _MultiscaleTestResize_ option + +`MultiscaleTestResize` option is used to enable multi scale test prediction. + +`origin_target_size: [800, 1333]` means the input image will be scaled to 800 (for short edge) and 1333 (max edge length cannot be greater than 1333) at first + +`target_size: [700 , 900]` property is used to specify different scales. + +It can be plugged into evaluation process or test (inference) process, by adding `MultiscaleTestResize` entry to `EvalReader.sample_transforms` or `TestReader.sample_transforms` + +--- + +###Note + +Now only CascadeRCNN, FasterRCNN and MaskRCNN are supported for multi scale testing. And batch size must be 1. \ No newline at end of file diff --git a/docs/tutorials/config_annotation/multi_scale_test_config_cn.md b/docs/tutorials/config_annotation/multi_scale_test_config_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..acf942c35ebf7c8eeb41304258759e0163d3dcf9 --- /dev/null +++ b/docs/tutorials/config_annotation/multi_scale_test_config_cn.md @@ -0,0 +1,45 @@ +# 多尺度测试的配置 + +标签: 配置 + +--- +```yaml + +##################################### 多尺度测试的配置 ##################################### + +EvalReader: + sample_transforms: + - Decode: {} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900]} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + +TestReader: + sample_transforms: + - Decode: {} + - MultiscaleTestResize: {origin_target_size: [800, 1333], target_size: [700 , 900]} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} +``` + +--- + +多尺度测试是一种TTA方法(测试时增强),可以用于提高目标检测的准确率 + +输入图像首先被缩放为不同尺度的图像,然后模型对这些不同尺度的图像进行预测,最后将这些不同尺度上的预测结果整合为最终预测结果。(这里使用了**NMS**来整合不同尺度的预测结果) + +## _MultiscaleTestResize_ 选项 + +`MultiscaleTestResize` 选项用于开启多尺度测试. + +`origin_target_size: [800, 1333]` 项代表输入图像首先缩放为短边为800,最长边不超过1333. + +`target_size: [700 , 900]` 项设置不同的预测尺度。 + +通过在`EvalReader.sample_transforms`或`TestReader.sample_transforms`中设置`MultiscaleTestResize`项,可以在评估过程或预测过程中开启多尺度测试。 + +--- + +###注意 + +目前多尺度测试只支持CascadeRCNN, FasterRCNN and MaskRCNN网络, 并且batch size需要是1. \ No newline at end of file diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index bacd53b75421bce4f2e95ba96ec56704ae4bd4be..e988e720108ea0c09b56bf3d642b3f795b540275 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -16,6 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import typing + try: from collections.abc import Sequence except Exception: @@ -58,7 +60,13 @@ class PadBatch(BaseOperator): """ coarsest_stride = self.pad_to_stride - max_shape = np.array([data['image'].shape for data in samples]).max( + # multi scale input is nested list + if isinstance(samples, typing.Sequence) and len(samples) > 0 and isinstance(samples[0], typing.Sequence): + inner_samples = samples[0] + else: + inner_samples = samples + + max_shape = np.array([data['image'].shape for data in inner_samples]).max( axis=0) if coarsest_stride > 0: max_shape[1] = int( @@ -66,7 +74,7 @@ class PadBatch(BaseOperator): max_shape[2] = int( np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride) - for data in samples: + for data in inner_samples: im = data['image'] im_c, im_h, im_w = im.shape[:] padding_im = np.zeros( diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index edb02709ba3f7587a030c5ada3726a5ccb0325b0..353797309913c09244197e7b345de5c9b20b50df 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -22,6 +22,7 @@ import copy import time import numpy as np +import typing from PIL import Image import paddle @@ -428,7 +429,11 @@ class Trainer(object): for metric in self._metrics: metric.update(data, outs) - sample_num += data['im_id'].numpy().shape[0] + # multi-scale inputs: all inputs have same im_id + if isinstance(data, typing.Sequence): + sample_num += data[0]['im_id'].numpy().shape[0] + else: + sample_num += data['im_id'].numpy().shape[0] self._compose_callback.on_step_end(self.status) self.status['sample_num'] = sample_num @@ -471,7 +476,10 @@ class Trainer(object): outs = self.model(data) for key in ['im_shape', 'scale_factor', 'im_id']: - outs[key] = data[key] + if isinstance(data, typing.Sequence): + outs[key] = data[0][key] + else: + outs[key] = data[key] for key, value in outs.items(): if hasattr(value, 'numpy'): outs[key] = value.numpy() diff --git a/ppdet/metrics/metrics.py b/ppdet/metrics/metrics.py index 65b18efd82eb6f47fa9e7c2da550663693b00e59..ea9861319f17767e8aa5f914c15d95ceb53638d0 100644 --- a/ppdet/metrics/metrics.py +++ b/ppdet/metrics/metrics.py @@ -21,6 +21,7 @@ import sys import json import paddle import numpy as np +import typing from .map_utils import prune_zero_padding, DetectionMAP from .coco_utils import get_infer_results, cocoapi_eval @@ -97,7 +98,11 @@ class COCOMetric(Metric): for k, v in outputs.items(): outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v - im_id = inputs['im_id'] + # multi-scale inputs: all inputs have same im_id + if isinstance(inputs, typing.Sequence): + im_id = inputs[0]['im_id'] + else: + im_id = inputs['im_id'] outs['im_id'] = im_id.numpy() if isinstance(im_id, paddle.Tensor) else im_id diff --git a/ppdet/modeling/architectures/meta_arch.py b/ppdet/modeling/architectures/meta_arch.py index 00085428d75bfc8d99533a22a08060354713e637..d914b4e98dc4f4a3aeed4cbce757694e1fb6de33 100644 --- a/ppdet/modeling/architectures/meta_arch.py +++ b/ppdet/modeling/architectures/meta_arch.py @@ -2,9 +2,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np import paddle import paddle.nn as nn +import typing + from ppdet.core.workspace import register +from static.ppdet.utils.post_process import nms __all__ = ['BaseArch'] @@ -25,7 +29,53 @@ class BaseArch(nn.Layer): if self.training: out = self.get_loss() else: - out = self.get_pred() + inputs_list = [] + # multi-scale input + if not isinstance(inputs, typing.Sequence): + inputs_list.append(inputs) + else: + inputs_list.extend(inputs) + + outs = [] + for inp in inputs_list: + self.inputs = inp + outs.append(self.get_pred()) + + # multi-scale test + if len(outs)>1: + out = self.merge_multi_scale_predictions(outs) + else: + out = outs[0] + return out + + def merge_multi_scale_predictions(self, outs): + # default values for architectures not included in following list + num_classes = 80 + nms_threshold = 0.5 + keep_top_k = 100 + + if self.__class__.__name__ in ('CascadeRCNN', 'FasterRCNN', 'MaskRCNN'): + num_classes = self.bbox_head.num_classes + keep_top_k = self.bbox_post_process.nms.keep_top_k + nms_threshold = self.bbox_post_process.nms.nms_threshold + else: + raise Exception("Multi scale test only supports CascadeRCNN, FasterRCNN and MaskRCNN for now") + + final_boxes = [] + all_scale_outs = paddle.concat([o['bbox'] for o in outs]).numpy() + for c in range(num_classes): + idxs = all_scale_outs[:, 0] == c + if np.count_nonzero(idxs) == 0: + continue + r = nms(all_scale_outs[idxs, 1:], nms_threshold) + final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1)) + out = np.concatenate(final_boxes) + out = np.concatenate(sorted(out, key=lambda e: e[1])[-keep_top_k:]).reshape((-1, 6)) + out = { + 'bbox': paddle.to_tensor(out), + 'bbox_num': paddle.to_tensor(np.array([out.shape[0], ])) + } + return out def build_inputs(self, data, input_def): diff --git a/ppdet/modeling/tests/imgs/coco2017_val2017_000000000139.jpg b/ppdet/modeling/tests/imgs/coco2017_val2017_000000000139.jpg new file mode 100644 index 0000000000000000000000000000000000000000..19023f718333c56c70776c79201dc03d742c1ed3 Binary files /dev/null and b/ppdet/modeling/tests/imgs/coco2017_val2017_000000000139.jpg differ diff --git a/ppdet/modeling/tests/imgs/coco2017_val2017_000000000724.jpg b/ppdet/modeling/tests/imgs/coco2017_val2017_000000000724.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2a17e0c6ee400dcba762c4d56dea03d7e124b9c5 Binary files /dev/null and b/ppdet/modeling/tests/imgs/coco2017_val2017_000000000724.jpg differ diff --git a/ppdet/modeling/tests/test_mstest.py b/ppdet/modeling/tests/test_mstest.py new file mode 100644 index 0000000000000000000000000000000000000000..57d1d169fecbf235227aa388ae0d2e3d92330d0c --- /dev/null +++ b/ppdet/modeling/tests/test_mstest.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest +from ppdet.core.workspace import load_config +from ppdet.engine import Trainer + +class TestMultiScaleInference(unittest.TestCase): + def setUp(self): + self.set_config() + + def set_config(self): + self.mstest_cfg_file = 'configs/faster_rcnn/faster_rcnn_r34_fpn_multiscaletest_1x_coco.yml' + + # test evaluation with multi scale test + def test_eval_mstest(self): + cfg = load_config(self.mstest_cfg_file) + trainer = Trainer(cfg, mode='eval') + + cfg.weights = 'https://paddledet.bj.bcebos.com/models/faster_rcnn_r34_fpn_1x_coco.pdparams' + trainer.load_weights(cfg.weights) + + trainer.evaluate() + + # test inference with multi scale test + def test_infer_mstest(self): + cfg = load_config(self.mstest_cfg_file) + trainer = Trainer(cfg, mode='test') + + cfg.weights = 'https://paddledet.bj.bcebos.com/models/faster_rcnn_r34_fpn_1x_coco.pdparams' + trainer.load_weights(cfg.weights) + tests_img_root = os.path.join(os.path.dirname(__file__), 'imgs') + + # input images to predict + imgs = ['coco2017_val2017_000000000139.jpg', 'coco2017_val2017_000000000724.jpg'] + imgs = [os.path.join(tests_img_root, img) for img in imgs] + trainer.predict(imgs, + draw_threshold=0.5, + output_dir='output', + save_txt=True) + + +if __name__ == '__main__': + unittest.main()