diff --git a/configs/datasets/sniper_coco_detection.yml b/configs/datasets/sniper_coco_detection.yml new file mode 100644 index 0000000000000000000000000000000000000000..b5cff989f5b58e79836e95efa2070c580e5edc44 --- /dev/null +++ b/configs/datasets/sniper_coco_detection.yml @@ -0,0 +1,47 @@ +metric: SNIPERCOCO +num_classes: 80 + +TrainDataset: + !SniperCOCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + allow_empty: true + is_trainset: true + image_target_sizes: [2000, 1000] + valid_box_ratio_ranges: [[-1, 0.1],[0.08, -1]] + chip_target_size: 512 + chip_target_stride: 200 + use_neg_chip: false + max_neg_num_per_im: 8 + + +EvalDataset: + !SniperCOCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + allow_empty: true + is_trainset: false + image_target_sizes: [2000, 1000] + valid_box_ratio_ranges: [[-1, 0.1], [0.08, -1]] + chip_target_size: 512 + chip_target_stride: 200 + max_per_img: -1 + nms_thresh: 0.5 + +TestDataset: + !SniperCOCODataSet + image_dir: val2017 + dataset_dir: dataset/coco + is_trainset: false + image_target_sizes: [2000, 1000] + valid_box_ratio_ranges: [[-1, 0.1],[0.08, -1]] + chip_target_size: 500 + chip_target_stride: 200 + max_per_img: -1 + nms_thresh: 0.5 + + diff --git a/configs/datasets/sniper_visdrone_detection.yml b/configs/datasets/sniper_visdrone_detection.yml new file mode 100644 index 0000000000000000000000000000000000000000..f6c12a9516b71026e3451e10b128aec1fbf96160 --- /dev/null +++ b/configs/datasets/sniper_visdrone_detection.yml @@ -0,0 +1,47 @@ +metric: SNIPERCOCO +num_classes: 9 + +TrainDataset: + !SniperCOCODataSet + image_dir: train + anno_path: annotations/train.json + dataset_dir: dataset/VisDrone2019_coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + allow_empty: true + is_trainset: true + image_target_sizes: [8145, 2742] + valid_box_ratio_ranges: [[-1, 0.03142857142857144], [0.02333211853008726, -1]] + chip_target_size: 1536 + chip_target_stride: 1184 + use_neg_chip: false + max_neg_num_per_im: 8 + + +EvalDataset: + !SniperCOCODataSet + image_dir: val + anno_path: annotations/val.json + dataset_dir: dataset/VisDrone2019_coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + allow_empty: true + is_trainset: false + image_target_sizes: [8145, 2742] + valid_box_ratio_ranges: [[-1, 0.03142857142857144], [0.02333211853008726, -1]] + chip_target_size: 1536 + chip_target_stride: 1184 + max_per_img: -1 + nms_thresh: 0.5 + +TestDataset: + !SniperCOCODataSet + image_dir: val + dataset_dir: dataset/VisDrone2019_coco + is_trainset: false + image_target_sizes: [8145, 2742] + valid_box_ratio_ranges: [[-1, 0.03142857142857144], [0.02333211853008726, -1]] + chip_target_size: 1536 + chip_target_stride: 1184 + max_per_img: -1 + nms_thresh: 0.5 + + diff --git a/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_visdrone.yml b/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_visdrone.yml new file mode 100644 index 0000000000000000000000000000000000000000..e6f5abedb57c2bf30af264486889f69c47f36a73 --- /dev/null +++ b/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_visdrone.yml @@ -0,0 +1,29 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_1x.yml', + '_base_/faster_rcnn_r50_fpn.yml', + '_base_/faster_fpn_reader.yml', +] +weights: output/faster_rcnn_r50_fpn_1x_coco_visdrone/model_final + + 
+metric: COCO
+num_classes: 9
+
+TrainDataset:
+  !COCODataSet
+    image_dir: train
+    anno_path: annotations/train.json
+    dataset_dir: dataset/VisDrone2019_coco
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+  !COCODataSet
+    image_dir: val
+    anno_path: annotations/val.json
+    dataset_dir: dataset/VisDrone2019_coco
+
+TestDataset:
+  !ImageFolder
+    anno_path: annotations/val.json
diff --git a/configs/ppyolo/ppyolo_r50vd_dcn_1x_visdrone.yml b/configs/ppyolo/ppyolo_r50vd_dcn_1x_visdrone.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0f02b35a0c02a6f1f0789a65b8d867a314d89e3d
--- /dev/null
+++ b/configs/ppyolo/ppyolo_r50vd_dcn_1x_visdrone.yml
@@ -0,0 +1,54 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  './_base_/ppyolo_r50vd_dcn.yml',
+  './_base_/optimizer_1x.yml',
+  './_base_/ppyolo_reader.yml',
+]
+
+snapshot_epoch: 8
+use_ema: true
+weights: output/ppyolo_r50vd_dcn_1x_visdrone_coco/model_final
+
+epoch: 192
+
+LearningRate:
+  base_lr: 0.005
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones:
+    - 153
+    - 173
+  - !LinearWarmup
+    start_factor: 0.
+    steps: 4000
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0005
+    type: L2
+
+
+metric: COCO
+num_classes: 9
+
+TrainDataset:
+  !COCODataSet
+    image_dir: train
+    anno_path: annotations/train.json
+    dataset_dir: dataset/VisDrone2019_coco
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+  !COCODataSet
+    image_dir: val
+    anno_path: annotations/val.json
+    dataset_dir: dataset/VisDrone2019_coco
+
+TestDataset:
+  !ImageFolder
+    anno_path: annotations/val.json
\ No newline at end of file
diff --git a/configs/sniper/README.md b/configs/sniper/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0cf4bf7a2504f84d7444a0181beaaff42c1c99ef
--- /dev/null
+++ b/configs/sniper/README.md
@@ -0,0 +1,51 @@
+English | [简体中文](README_cn.md)
+
+# SNIPER: Efficient Multi-Scale Training
+
+## Model Zoo
+
+| sniper | GPU number | images/GPU | Backbone | Network | lr schedule | Box AP | download | config |
+| :---------------- | :-------------------: | :------------------: | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
+| w/o sniper | 4 | 1 | ResNet-r50-FPN | Faster R-CNN | 1x | 23.3 | [faster_rcnn_r50_fpn_1x_visdrone](https://bj.bcebos.com/v1/paddledet/models/faster_rcnn_r50_fpn_1x_visdrone.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/sniper/faster_rcnn_r50_fpn_1x_sniper_coco.yml) |
+| w sniper | 4 | 1 | ResNet-r50-FPN | Faster R-CNN | 1x | 29.7 | [faster_rcnn_r50_fpn_1x_sniper_visdrone](https://bj.bcebos.com/v1/paddledet/models/faster_rcnn_r50_fpn_1x_sniper_visdrone.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml) |
+
+## Getting Started
+### 1. Training
+a. optional: Run `tools/sniper_params_stats.py` to get `image_target_sizes`, `valid_box_ratio_ranges`, `chip_target_size` and `chip_target_stride`, then modify these params in configs/datasets/sniper_coco_detection.yml
+```bash
+python tools/sniper_params_stats.py FasterRCNN annotations/instances_train2017.json
+```
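+
+These statistics follow the bookkeeping used by `AnnoCropper` (see `ppdet/data/crop_utils/annotation_cropper.py`): the resizing scale is `target_size / raw_size`, and a box's ratio is the length of its long side over the long side of its image. A rough numeric sketch of how the two relate (toy values, not output of the tool):
+```python
+# Toy illustration of the scale/ratio bookkeeping behind the SNIPER configs.
+image_raw_size = 4000               # long side of the original image
+image_target_sizes = [2000, 1000]   # as in sniper_coco_detection.yml
+box_raw_size = 120                  # long side of one ground-truth box
+
+for target in image_target_sizes:
+    scale = target / image_raw_size             # resizing scale of this pyramid level
+    box_ratio = box_raw_size / image_raw_size   # == box_target_size / target
+    box_target_size = box_raw_size * scale
+    print(f"target={target}: scale={scale:.3f}, box_ratio={box_ratio:.3f}, "
+          f"box_target_size={box_target_size:.1f}")
+
+# A box is assigned to the pyramid level whose valid_box_ratio_range contains
+# box_ratio, e.g. [[-1, 0.1], [0.08, -1]], where -1 means unbounded.
+```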
+b. optional: Train a detector to get negative proposals.
+```bash
+python -m paddle.distributed.launch --log_dir=./sniper/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml --save_proposals --proposals_path=./proposals.json &>sniper.log 2>&1 &
+```
+c. Train the model
+```bash
+python -m paddle.distributed.launch --log_dir=./sniper/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml --eval &>sniper.log 2>&1 &
+```
+
+### 2. Evaluation
+Evaluate SNIPER on a custom dataset with a single GPU using the following commands:
+```bash
+# use the checkpoint saved during training
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml -o weights=output/faster_rcnn_r50_fpn_2x_sniper_coco/model_final
+```
+
+### 3. Inference
+Run inference on a single GPU with the following commands. Use `--infer_img` to infer a single image and `--infer_dir` to infer all images in a directory.
+
+```bash
+# inference on a single image
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml -o weights=output/faster_rcnn_r50_fpn_2x_sniper_coco/model_final --infer_img=demo/P0861__1.0__1154___824.png
+
+# inference on all images in the directory
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml -o weights=output/faster_rcnn_r50_fpn_2x_sniper_coco/model_final --infer_dir=demo
+```
+
+## Citations
+@misc{1805.09300,
+Author = {Bharat Singh and Mahyar Najibi and Larry S. Davis},
+Title = {SNIPER: Efficient Multi-Scale Training},
+Year = {2018},
+Eprint = {arXiv:1805.09300},
+}
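Under the hood, `tools/eval.py` and `tools/infer.py` run the model over chip crops and then stitch per-chip detections back onto the original images (see the `SniperCOCODataSet` and `Trainer` changes later in this diff). A minimal sketch of driving that merge step programmatically, assuming the VisDrone data layout from `configs/datasets/sniper_visdrone_detection.yml` is in place (config path illustrative):
```python
# Sketch: programmatic chip merging, mirroring the trainer changes in this PR.
from ppdet.core.workspace import load_config, create

cfg = load_config('configs/sniper/faster_rcnn_r50_fpn_1x_sniper_visdrone.yml')
dataset = create('EvalDataset')   # a SniperCOCODataSet (is_trainset: false)
dataset.parse_dataset()           # slices every eval image into chip records

# `results` would normally be filled by running the model over the chip
# loader; each entry carries 'bbox', 'bbox_num' and 'im_id' numpy arrays.
results = []

# Chip detections are mapped back to original-image coordinates, filtered by
# each scale's valid box-ratio range, then merged with per-image NMS.
merged = dataset.anno_cropper.aggregate_chips_detections(results)
```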
diff --git a/configs/sniper/README_cn.md b/configs/sniper/README_cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..2151b403258bedc6b33c87c177634e9738d19ea5
--- /dev/null
+++ b/configs/sniper/README_cn.md
@@ -0,0 +1,53 @@
+简体中文 | [English](README.md)
+
+# SNIPER: Efficient Multi-Scale Training
+
+## 模型库
+| 有无sniper | GPU个数 | 每张GPU图片个数 | 骨架网络 | 网络类型 | 学习率策略 | Box AP | 模型下载 | 配置文件 |
+| :---------------- | :-------------------: | :------------------: | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
+| w/o sniper | 4 | 1 | ResNet-r50-FPN | Faster R-CNN | 1x | 23.3 | [faster_rcnn_r50_fpn_1x_visdrone](https://bj.bcebos.com/v1/paddledet/models/faster_rcnn_r50_fpn_1x_visdrone.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/sniper/faster_rcnn_r50_fpn_1x_sniper_coco.yml) |
+| w sniper | 4 | 1 | ResNet-r50-FPN | Faster R-CNN | 1x | 29.7 | [faster_rcnn_r50_fpn_1x_sniper_visdrone](https://bj.bcebos.com/v1/paddledet/models/faster_rcnn_r50_fpn_1x_sniper_visdrone.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml) |
+
+## 使用说明
+### 1. 训练
+a. 可选:统计数据集信息,获得数据缩放尺度、有效框范围、chip尺度和步长等参数,修改configs/datasets/sniper_coco_detection.yml中对应参数
+```bash
+python tools/sniper_params_stats.py FasterRCNN annotations/instances_train2017.json
+```
+b. 可选:训练检测器,生成负样本
+```bash
+python -m paddle.distributed.launch --log_dir=./sniper/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml --save_proposals --proposals_path=./proposals.json &>sniper.log 2>&1 &
+```
+c. 训练模型
+```bash
+python -m paddle.distributed.launch --log_dir=./sniper/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml --eval &>sniper.log 2>&1 &
+```
+
+### 2. 评估
+使用单GPU通过如下命令一键式评估模型效果
+```bash
+# 使用训练保存的checkpoint
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml -o weights=output/faster_rcnn_r50_fpn_2x_sniper_coco/model_final
+```
+
+### 3. 推理
+使用单GPU通过如下命令一键式推理图像,通过`--infer_img`指定图像路径,或通过`--infer_dir`指定目录并推理目录下所有图像
+
+```bash
+# 推理单张图像
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml -o weights=output/faster_rcnn_r50_fpn_2x_sniper_coco/model_final --infer_img=demo/P0861__1.0__1154___824.png
+
+# 推理目录下所有图像
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml -o weights=output/faster_rcnn_r50_fpn_2x_sniper_coco/model_final --infer_dir=demo
+```
+
+## Citations
+@misc{1805.09300,
+Author = {Bharat Singh and Mahyar Najibi and Larry S. Davis},
+Title = {SNIPER: Efficient Multi-Scale Training},
+Year = {2018},
+Eprint = {arXiv:1805.09300},
+}
diff --git a/configs/sniper/_base_/faster_fpn_reader.yml b/configs/sniper/_base_/faster_fpn_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..363ca4664b9effb1317e6661732f99113b7d1bff
--- /dev/null
+++ b/configs/sniper/_base_/faster_fpn_reader.yml
@@ -0,0 +1,40 @@
+worker_num: 2
+TrainReader:
+  sample_transforms:
+  - SniperDecodeCrop: {}
+  - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True}
+  - RandomFlip: {prob: 0.5}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: true
+  drop_last: true
+  collate_batch: false
+
+
+EvalReader:
+  sample_transforms:
+  - SniperDecodeCrop: {}
+  - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
+
+
+TestReader:
+  sample_transforms:
+  - SniperDecodeCrop: {}
+  - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
diff --git a/configs/sniper/_base_/faster_reader.yml b/configs/sniper/_base_/faster_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..5c3b348024e8d48289db85706ddd6454f40c0815
--- /dev/null
+++ b/configs/sniper/_base_/faster_reader.yml
@@ -0,0 +1,41 @@
+worker_num: 2
+TrainReader:
+  sample_transforms:
+  - SniperDecodeCrop: {}
+  - RandomResize: {target_size: [[800, 1333]], interp: 2, keep_ratio: True}
+  - RandomFlip: {prob: 0.5}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: -1}
+  batch_size: 1
+  shuffle: true
+  drop_last: true
+  collate_batch: false
+  use_shared_memory: true
+
+
+EvalReader:
+  sample_transforms:
+  - SniperDecodeCrop: {}
+  - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: -1}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
+
+
+TestReader:
+  sample_transforms:
+  - SniperDecodeCrop: {}
- Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: -1} + batch_size: 1 + shuffle: false + drop_last: false diff --git a/configs/sniper/_base_/ppyolo_reader.yml b/configs/sniper/_base_/ppyolo_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..f88e908c903b256bc08fa209f0f2368e3d58596b --- /dev/null +++ b/configs/sniper/_base_/ppyolo_reader.yml @@ -0,0 +1,40 @@ +worker_num: 2 +TrainReader: + inputs_def: + num_max_boxes: 50 + sample_transforms: + - SniperDecodeCrop: {} + - RandomDistort: {} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {} + - RandomFlip: {} + batch_transforms: + - BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeBox: {} + - PadBox: {num_max_boxes: 50} + - BboxXYXY2XYWH: {} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + - Gt2YoloTarget: {anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]], anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]], downsample_ratios: [32, 16, 8]} + batch_size: 8 + shuffle: true + drop_last: true + use_shared_memory: true + +EvalReader: + sample_transforms: + - SniperDecodeCrop: {} + - Resize: {target_size: [608, 608], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + batch_size: 8 + +TestReader: + inputs_def: + image_shape: [3, 608, 608] + sample_transforms: + - SniperDecodeCrop: {} + - Resize: {target_size: [608, 608], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + batch_size: 1 diff --git a/configs/sniper/faster_rcnn_r50_fpn_1x_sniper_visdrone.yml b/configs/sniper/faster_rcnn_r50_fpn_1x_sniper_visdrone.yml new file mode 100644 index 0000000000000000000000000000000000000000..aaf22e12659ed4928fc6d7941c8b3d966ab0d82f --- /dev/null +++ b/configs/sniper/faster_rcnn_r50_fpn_1x_sniper_visdrone.yml @@ -0,0 +1,11 @@ +_BASE_: [ + '../datasets/sniper_visdrone_detection.yml', + '../runtime.yml', + '../faster_rcnn/_base_/faster_rcnn_r50_fpn.yml', + '../faster_rcnn/_base_/optimizer_1x.yml', + '_base_/faster_fpn_reader.yml', +] +weights: output/faster_rcnn_r50_1x_visdrone_coco/model_final +find_unused_parameters: true + + diff --git a/configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml b/configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..daebb5154525a4c04d4e42073ba7e2d6140215db --- /dev/null +++ b/configs/sniper/faster_rcnn_r50_fpn_2x_sniper_coco.yml @@ -0,0 +1,15 @@ +_BASE_: [ + 'faster_rcnn_r50_fpn_1x_sniper_coco.yml', +] +weights: output/faster_rcnn_r50_fpn_2x_sniper_coco/model_final + +epoch: 24 +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [16, 22] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 diff --git a/configs/sniper/faster_rcnn_r50_vd_1x_coco_sniper.yml b/configs/sniper/faster_rcnn_r50_vd_1x_coco_sniper.yml new file mode 100644 index 0000000000000000000000000000000000000000..fbf674dc72dfc479bc4fd8a31beed0edb67f4818 --- /dev/null +++ b/configs/sniper/faster_rcnn_r50_vd_1x_coco_sniper.yml @@ -0,0 +1,19 @@ 
+_BASE_: [
+  '../datasets/sniper_coco_detection.yml',
+  '../runtime.yml',
+  '../faster_rcnn/_base_/optimizer_1x.yml',
+  '../faster_rcnn/_base_/faster_rcnn_r50.yml',
+  '_base_/faster_reader.yml',
+]
+
+pretrain_weights: https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_1x_coco.pdparams
+weights: output/faster_rcnn_r50_vd_1x_coco/model_final
+
+ResNet:
+  # index 0 stands for res2
+  depth: 50
+  variant: d
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [2]
+  num_stages: 3
diff --git a/configs/sniper/faster_rcnn_r50_vd_fpn_2x_sniper_coco.yml b/configs/sniper/faster_rcnn_r50_vd_fpn_2x_sniper_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3667319cef8c1e1529d447b01c8d8ba7ca8b4168
--- /dev/null
+++ b/configs/sniper/faster_rcnn_r50_vd_fpn_2x_sniper_coco.yml
@@ -0,0 +1,25 @@
+_BASE_: [
+  'faster_rcnn_r50_fpn_1x_sniper_coco.yml',
+]
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_pretrained.pdparams
+weights: output/faster_rcnn_r50_vd_fpn_2x_sniper_coco/model_final
+
+ResNet:
+  # index 0 stands for res2
+  depth: 50
+  variant: d
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [0,1,2,3]
+  num_stages: 4
+
+epoch: 24
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [16, 22]
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 1000
diff --git a/configs/sniper/ppyolo_r50vd_dcn_1x_sniper.yml b/configs/sniper/ppyolo_r50vd_dcn_1x_sniper.yml
new file mode 100644
index 0000000000000000000000000000000000000000..da634b2b1fb4339a4d457cd7d12df8b0fb21af91
--- /dev/null
+++ b/configs/sniper/ppyolo_r50vd_dcn_1x_sniper.yml
@@ -0,0 +1,32 @@
+_BASE_: [
+  '../datasets/sniper_coco_detection.yml',
+  '../runtime.yml',
+  '../ppyolo/_base_/ppyolo_r50vd_dcn.yml',
+  '../ppyolo/_base_/optimizer_1x.yml',
+  './_base_/ppyolo_reader.yml',
+]
+
+snapshot_epoch: 8
+use_ema: true
+weights: output/ppyolo_r50vd_dcn_1x_minicoco/model_final
+
+
+LearningRate:
+  base_lr: 0.005
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones:
+    - 153
+    - 173
+  - !LinearWarmup
+    start_factor: 0.
+    steps: 4000
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0005
+    type: L2
diff --git a/configs/sniper/ppyolo_r50vd_dcn_1x_sniper_visdrone.yml b/configs/sniper/ppyolo_r50vd_dcn_1x_sniper_visdrone.yml
new file mode 100644
index 0000000000000000000000000000000000000000..58cce35056a4de23370f0e2f80ddabf410c262af
--- /dev/null
+++ b/configs/sniper/ppyolo_r50vd_dcn_1x_sniper_visdrone.yml
@@ -0,0 +1,33 @@
+_BASE_: [
+  '../datasets/sniper_visdrone_detection.yml',
+  '../runtime.yml',
+  '../ppyolo/_base_/ppyolo_r50vd_dcn.yml',
+  '../ppyolo/_base_/optimizer_1x.yml',
+  './_base_/ppyolo_reader.yml',
+]
+
+snapshot_epoch: 8
+use_ema: true
+weights: output/ppyolo_r50vd_dcn_1x_visdrone/model_final
+
+
+
+LearningRate:
+  base_lr: 0.005
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones:
+    - 153
+    - 173
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 4000
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0005
+    type: L2
diff --git a/ppdet/data/crop_utils/__init__.py b/ppdet/data/crop_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61d5aa213694a29c4820ead6e2a74123c2df44e8
--- /dev/null
+++ b/ppdet/data/crop_utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/ppdet/data/crop_utils/annotation_cropper.py b/ppdet/data/crop_utils/annotation_cropper.py
new file mode 100644
index 0000000000000000000000000000000000000000..93a9a1f75fe46a15336553ea2689c78681780877
--- /dev/null
+++ b/ppdet/data/crop_utils/annotation_cropper.py
@@ -0,0 +1,542 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+import random
+import numpy as np
+from copy import deepcopy
+from typing import List, Tuple
+from collections import defaultdict
+
+from .chip_box_utils import nms, transform_chip_boxes2image_boxes
+from .chip_box_utils import find_chips_to_cover_overlaped_boxes
+from .chip_box_utils import transform_chip_box
+from .chip_box_utils import intersection_over_box
+
+
+class AnnoCropper(object):
+    def __init__(self, image_target_sizes: List[int],
+                 valid_box_ratio_ranges: List[List[float]],
+                 chip_target_size: int, chip_target_stride: int,
+                 use_neg_chip: bool = False,
+                 max_neg_num_per_im: int = 8,
+                 max_per_img: int = -1,
+                 nms_thresh: float = 0.5
+                 ):
+        """
+        Generate chips by chip_target_size and chip_target_stride.
+        These two parameters work just like kernel_size and stride in a CNN.
+
+        Each image has its raw size; after resizing, it reaches its target size.
+        The resizing scale = target_size / raw_size, and the same holds for the
+        chips of the image.
+        box_ratio = box_raw_size / image_raw_size = box_target_size / image_target_size
+        The 'size' mentioned above is the length of the long side of an image, box or chip.
+
+        :param image_target_sizes: [2000, 1000]
+        :param valid_box_ratio_ranges: [[-1, 0.1], [0.08, -1]]
+        :param chip_target_size: 500
+        :param chip_target_stride: 200
+        """
+        self.target_sizes = image_target_sizes
+        self.valid_box_ratio_ranges = valid_box_ratio_ranges
+        assert len(self.target_sizes) == len(self.valid_box_ratio_ranges)
+        self.scale_num = len(self.target_sizes)
+        self.chip_target_size = chip_target_size  # is target size
+        self.chip_target_stride = chip_target_stride  # is target stride
+        self.use_neg_chip = use_neg_chip
+        self.max_neg_num_per_im = max_neg_num_per_im
+        self.max_per_img = max_per_img
+        self.nms_thresh = nms_thresh
+
+    def crop_anno_records(self, records: List[dict]):
+        """
+        The main logic:
+        # foreach record (image):
+        #    foreach scale:
+        #       1 generate chips by chip size and stride for each scale
+        #       2 get pos chips
+        #       - validate boxes: current scale; h,w >= 1
+        #       - find pos chips greedily by valid gt boxes in each scale
+        #       - for every valid gt box, find its corresponding pos chips in each scale
+        #       3 get neg chips
+        #       - if given proposals, find neg boxes in them which are not in pos chips
+        #       - if neg boxes were found in the last step, find neg chips and assign neg boxes to them as in step 2
+        #       4 sample neg chips if there are too many per image
+        #    transform this image-scale annotations to chip (pos chips & neg chips) annotations
+
+        :param records: standard coco_record but with an extra key `proposals` (Px4),
+            predicted by the stage-1 model, which may contain neg boxes.
+        :return: new_records, list of dict like
+        {
+            'im_file': 'fake_image1.jpg',
+            'im_id': np.array([1]),  # new _global_chip_id as im_id
+            'h': h,  # chip height
+            'w': w,  # chip width
+            'is_crowd': is_crowd,  # Nx1 -> Mx1
+            'gt_class': gt_class,  # Nx1 -> Mx1
+            'gt_bbox': gt_bbox,  # Nx4 -> Mx4, 4 represents [x1,y1,x2,y2]
+            'gt_poly': gt_poly,  # [None]xN -> [None]xM
+            'chip': [x1, y1, x2, y2]  # added
+        }
+
+        Attention:
+        ------------------------------>x
+        |
+        |   (x1,y1)------
+        |      |        |
+        |      |        |
+        |      |        |
+        |      |        |
+        |      |        |
+        |      ----------
+        |              (x2,y2)
+        |
+        ↓
+        y
+
+        If we use [x1, y1, x2, y2] to represent boxes or chips,
+        (x1,y1) is the left-top point which is in the box,
+        but (x2,y2) is the right-bottom point which is not in the box.
+        So x1 in [0, w-1], x2 in [1, w], y1 in [0, h-1], y2 in [1, h].
+        You can use x2-x1 to get the width, and image[y1:y2, x1:x2] to get the box area.
+        """
+
+        self.chip_records = []
+        self._global_chip_id = 1
+        for r in records:
+            self._cur_im_pos_chips = []  # element: (chip, boxes_idx), chip is [x1, y1, x2, y2], boxes_idx is List[int]
+            self._cur_im_neg_chips = []  # element: (chip, neg_box_num)
+            for scale_i in range(self.scale_num):
+                self._get_current_scale_parameters(scale_i, r)
+
+                # Cx4
+                chips = self._create_chips(r['h'], r['w'], self._cur_scale)
+
+                # dict: chipid->[box_id, ...]
+ pos_chip2boxes_idx = self._get_valid_boxes_and_pos_chips(r['gt_bbox'], chips) + + # dict: chipid->neg_box_num + neg_chip2box_num = self._get_neg_boxes_and_chips(chips, list(pos_chip2boxes_idx.keys()), r.get('proposals', None)) + + self._add_to_cur_im_chips(chips, pos_chip2boxes_idx, neg_chip2box_num) + + cur_image_records = self._trans_all_chips2annotations(r) + self.chip_records.extend(cur_image_records) + return self.chip_records + + def _add_to_cur_im_chips(self, chips, pos_chip2boxes_idx, neg_chip2box_num): + for pos_chipid, boxes_idx in pos_chip2boxes_idx.items(): + chip = np.array(chips[pos_chipid]) # copy chips slice + self._cur_im_pos_chips.append((chip, boxes_idx)) + + if neg_chip2box_num is None: + return + + for neg_chipid, neg_box_num in neg_chip2box_num.items(): + chip = np.array(chips[neg_chipid]) + self._cur_im_neg_chips.append((chip, neg_box_num)) + + def _trans_all_chips2annotations(self, r): + gt_bbox = r['gt_bbox'] + im_file = r['im_file'] + is_crowd = r['is_crowd'] + gt_class = r['gt_class'] + # gt_poly = r['gt_poly'] # [None]xN + # remaining keys: im_id, h, w + chip_records = self._trans_pos_chips2annotations(im_file, gt_bbox, is_crowd, gt_class) + + if not self.use_neg_chip: + return chip_records + + sampled_neg_chips = self._sample_neg_chips() + neg_chip_records = self._trans_neg_chips2annotations(im_file, sampled_neg_chips) + chip_records.extend(neg_chip_records) + return chip_records + + def _trans_pos_chips2annotations(self, im_file, gt_bbox, is_crowd, gt_class): + chip_records = [] + for chip, boxes_idx in self._cur_im_pos_chips: + chip_bbox, final_boxes_idx = transform_chip_box(gt_bbox, boxes_idx, chip) + x1, y1, x2, y2 = chip + chip_h = y2 - y1 + chip_w = x2 - x1 + rec = { + 'im_file': im_file, + 'im_id': np.array([self._global_chip_id]), + 'h': chip_h, + 'w': chip_w, + 'gt_bbox': chip_bbox, + 'is_crowd': is_crowd[final_boxes_idx].copy(), + 'gt_class': gt_class[final_boxes_idx].copy(), + # 'gt_poly': [None] * len(final_boxes_idx), + 'chip': chip + } + self._global_chip_id += 1 + chip_records.append(rec) + return chip_records + + def _sample_neg_chips(self): + pos_num = len(self._cur_im_pos_chips) + neg_num = len(self._cur_im_neg_chips) + sample_num = min(pos_num + 2, self.max_neg_num_per_im) + assert sample_num >= 1 + if neg_num <= sample_num: + return self._cur_im_neg_chips + + candidate_num = int(sample_num * 1.5) + candidate_neg_chips = sorted(self._cur_im_neg_chips, key=lambda x: -x[1])[:candidate_num] + random.shuffle(candidate_neg_chips) + sampled_neg_chips = candidate_neg_chips[:sample_num] + return sampled_neg_chips + + def _trans_neg_chips2annotations(self, im_file: str, sampled_neg_chips: List[Tuple]): + chip_records = [] + for chip, neg_box_num in sampled_neg_chips: + x1, y1, x2, y2 = chip + chip_h = y2 - y1 + chip_w = x2 - x1 + rec = { + 'im_file': im_file, + 'im_id': np.array([self._global_chip_id]), + 'h': chip_h, + 'w': chip_w, + 'gt_bbox': np.zeros((0, 4), dtype=np.float32), + 'is_crowd': np.zeros((0, 1), dtype=np.int32), + 'gt_class': np.zeros((0, 1), dtype=np.int32), + # 'gt_poly': [], + 'chip': chip + } + self._global_chip_id += 1 + chip_records.append(rec) + return chip_records + + def _get_current_scale_parameters(self, scale_i, r): + im_size = max(r['h'], r['w']) + im_target_size = self.target_sizes[scale_i] + self._cur_im_size, self._cur_im_target_size = im_size, im_target_size + self._cur_scale = self._get_current_scale(im_target_size, im_size) + self._cur_valid_ratio_range = self.valid_box_ratio_ranges[scale_i] + + def 
_get_current_scale(self, im_target_size, im_size):
+        return im_target_size / im_size
+
+    def _create_chips(self, h: int, w: int, scale: float):
+        """
+        Generate chips by chip_target_size and chip_target_stride.
+        These two parameters work just like kernel_size and stride in a CNN.
+        :return: chips, Cx4, xy in raw size dimension
+        """
+        chip_size = self.chip_target_size  # omit target for simplicity
+        stride = self.chip_target_stride
+        width = int(scale * w)
+        height = int(scale * h)
+        min_chip_location_diff = 20  # in target size
+
+        assert chip_size >= stride
+        chip_overlap = chip_size - stride
+        if (width - chip_overlap) % stride > min_chip_location_diff:  # the remainder not covered by stride is large: keep it as an extra chip
+            w_steps = max(1, int(math.ceil((width - chip_overlap) / stride)))
+        else:  # the remainder is small: drop it
+            w_steps = max(1, int(math.floor((width - chip_overlap) / stride)))
+        if (height - chip_overlap) % stride > min_chip_location_diff:
+            h_steps = max(1, int(math.ceil((height - chip_overlap) / stride)))
+        else:
+            h_steps = max(1, int(math.floor((height - chip_overlap) / stride)))
+
+        chips = list()
+        for j in range(h_steps):
+            for i in range(w_steps):
+                x1 = i * stride
+                y1 = j * stride
+                x2 = min(x1 + chip_size, width)
+                y2 = min(y1 + chip_size, height)
+                chips.append([x1, y1, x2, y2])
+
+        # check chip size
+        for item in chips:
+            if item[2] - item[0] > chip_size * 1.1 or item[3] - item[1] > chip_size * 1.1:
+                raise ValueError(item)
+        chips = np.array(chips, dtype=np.float64)
+
+        raw_size_chips = chips / scale
+        return raw_size_chips
+
+    def _get_valid_boxes_and_pos_chips(self, gt_bbox, chips):
+        valid_ratio_range = self._cur_valid_ratio_range
+        im_size = self._cur_im_size
+        scale = self._cur_scale
+        # Nx4 N
+        valid_boxes, valid_boxes_idx = self._validate_boxes(valid_ratio_range, im_size, gt_bbox, scale)
+        # dict: chipid->[box_id, ...]
+        pos_chip2boxes_idx = self._find_pos_chips(chips, valid_boxes, valid_boxes_idx)
+        return pos_chip2boxes_idx
+
+    def _validate_boxes(self, valid_ratio_range: List[float],
+                        im_size: int,
+                        gt_boxes: 'np.array of Nx4',
+                        scale: float):
+        """
+        :return: valid_boxes: Nx4, valid_boxes_idx: N
+        """
+        ws = (gt_boxes[:, 2] - gt_boxes[:, 0]).astype(np.int32)
+        hs = (gt_boxes[:, 3] - gt_boxes[:, 1]).astype(np.int32)
+        maxs = np.maximum(ws, hs)
+        box_ratio = maxs / im_size
+        mins = np.minimum(ws, hs)
+        target_mins = mins * scale
+
+        low = valid_ratio_range[0] if valid_ratio_range[0] > 0 else 0
+        high = valid_ratio_range[1] if valid_ratio_range[1] > 0 else np.finfo(np.float64).max
+
+        valid_boxes_idx = np.nonzero((low <= box_ratio) & (box_ratio < high) & (target_mins >= 2))[0]
+        valid_boxes = gt_boxes[valid_boxes_idx]
+        return valid_boxes, valid_boxes_idx
+
+    def _find_pos_chips(self, chips: 'Cx4', valid_boxes: 'Bx4', valid_boxes_idx: 'B'):
+        """
+        :return: pos_chip2boxes_idx, dict: chipid->[box_id, ...]
+        """
+        iob = intersection_over_box(chips, valid_boxes)  # overlap, CxB
+
+        iob_threshold_to_find_chips = 1.
+        pos_chip_ids, _ = self._find_chips_to_cover_overlaped_boxes(iob, iob_threshold_to_find_chips)
+        pos_chip_ids = set(pos_chip_ids)
+
+        iob_threshold_to_assign_box = 0.5
+        pos_chip2boxes_idx = self._assign_boxes_to_pos_chips(
+            iob, iob_threshold_to_assign_box, pos_chip_ids, valid_boxes_idx)
+        return pos_chip2boxes_idx
+
+    def _find_chips_to_cover_overlaped_boxes(self, iob, overlap_threshold):
+        return find_chips_to_cover_overlaped_boxes(iob, overlap_threshold)
+
+    def _assign_boxes_to_pos_chips(self, iob, overlap_threshold, pos_chip_ids, valid_boxes_idx):
+        chip_ids, box_ids = np.nonzero(iob >= overlap_threshold)
+        pos_chip2boxes_idx = defaultdict(list)
+        for chip_id, box_id in zip(chip_ids, box_ids):
+            if chip_id not in pos_chip_ids:
+                continue
+            raw_gt_box_idx = valid_boxes_idx[box_id]
+            pos_chip2boxes_idx[chip_id].append(raw_gt_box_idx)
+        return pos_chip2boxes_idx
+
+    def _get_neg_boxes_and_chips(self, chips: 'Cx4', pos_chip_ids: 'D', proposals: 'Px4'):
+        """
+        :param chips: Cx4 chip boxes
+        :param pos_chip_ids: ids of positive chips
+        :param proposals: Px4 stage-1 proposals
+        :return: neg_chip2box_num, None or dict: chipid->neg_box_num
+        """
+        if not self.use_neg_chip:
+            return None
+
+        # proposals may be None at training time
+        if proposals is None or len(proposals) < 1:
+            return None
+
+        valid_ratio_range = self._cur_valid_ratio_range
+        im_size = self._cur_im_size
+        scale = self._cur_scale
+
+        valid_props, _ = self._validate_boxes(valid_ratio_range, im_size, proposals, scale)
+        neg_boxes = self._find_neg_boxes(chips, pos_chip_ids, valid_props)
+        neg_chip2box_num = self._find_neg_chips(chips, pos_chip_ids, neg_boxes)
+        return neg_chip2box_num
+
+    def _find_neg_boxes(self, chips: 'Cx4', pos_chip_ids: 'D', valid_props: 'Px4'):
+        """
+        :return: neg_boxes: Nx4
+        """
+        if len(pos_chip_ids) == 0:
+            return valid_props
+
+        pos_chips = chips[pos_chip_ids]
+        iob = intersection_over_box(pos_chips, valid_props)
+        overlap_per_prop = np.max(iob, axis=0)
+        non_overlap_props_idx = overlap_per_prop < 0.5
+        neg_boxes = valid_props[non_overlap_props_idx]
+        return neg_boxes
+
+    def _find_neg_chips(self, chips: 'Cx4', pos_chip_ids: 'D', neg_boxes: 'Nx4'):
+        """
+        :return: neg_chip2box_num, dict: chipid->neg_box_num
+        """
+        neg_chip_ids = np.setdiff1d(np.arange(len(chips)), pos_chip_ids)
+        neg_chips = chips[neg_chip_ids]
+
+        iob = intersection_over_box(neg_chips, neg_boxes)
+        iob_threshold_to_find_chips = 0.7
+        chosen_neg_chip_ids, chip_id2overlap_box_num = \
+            self._find_chips_to_cover_overlaped_boxes(iob, iob_threshold_to_find_chips)
+
+        neg_chipid2box_num = {}
+        for cid in chosen_neg_chip_ids:
+            box_num = chip_id2overlap_box_num[cid]
+            raw_chip_id = neg_chip_ids[cid]
+            neg_chipid2box_num[raw_chip_id] = box_num
+        return neg_chipid2box_num
+
+    def crop_infer_anno_records(self, records: List[dict]):
+        """
+        Transform image records to chip records.
+        :param records: list of image records
+        :return: new_records, list of dict like
+        {
+            'im_file': 'fake_image1.jpg',
+            'im_id': np.array([1]),  # new _global_chip_id as im_id
+            'h': h,  # chip height
+            'w': w,  # chip width
+            'chip': [x1, y1, x2, y2]  # added
+            'ori_im_h': ori_im_h  # added, original image height
+            'ori_im_w': ori_im_w  # added, original image width
+            'scale_i': 0  # added
+        }
+        """
+        self.chip_records = []
+        self._global_chip_id = 1  # im_ids start from 1
+        self._global_chip_id2img_id = {}
+
+        for r in records:
+            for scale_i in range(self.scale_num):
+                self._get_current_scale_parameters(scale_i, r)
+                # Cx4
+                chips = self._create_chips(r['h'], r['w'], self._cur_scale)
+                cur_img_chip_record = self._get_chips_records(r, chips, scale_i)
self.chip_records.extend(cur_img_chip_record)
+
+        return self.chip_records
+
+    def _get_chips_records(self, rec, chips, scale_i):
+        cur_img_chip_records = []
+        ori_im_h = rec["h"]
+        ori_im_w = rec["w"]
+        im_file = rec["im_file"]
+        ori_im_id = rec["im_id"]
+        for chip in chips:
+            chip_rec = {}
+            x1, y1, x2, y2 = chip
+            chip_h = y2 - y1
+            chip_w = x2 - x1
+            chip_rec["im_file"] = im_file
+            chip_rec["im_id"] = self._global_chip_id
+            chip_rec["h"] = chip_h
+            chip_rec["w"] = chip_w
+            chip_rec["chip"] = chip
+            chip_rec["ori_im_h"] = ori_im_h
+            chip_rec["ori_im_w"] = ori_im_w
+            chip_rec["scale_i"] = scale_i
+
+            self._global_chip_id2img_id[self._global_chip_id] = int(ori_im_id)
+            self._global_chip_id += 1
+            cur_img_chip_records.append(chip_rec)
+
+        return cur_img_chip_records
+
+    def aggregate_chips_detections(self, results, records=None):
+        """
+        # 1. transform chip dets to image dets
+        # 2. nms boxes per image
+        # 3. format output results
+        :param results: list of chip-level detection dicts
+        :param records: chip records; defaults to self.chip_records
+        :return: merged per-image results
+        """
+        results = deepcopy(results)
+        records = records if records else self.chip_records
+        img_id2bbox = self._transform_chip2image_bboxes(results, records)
+        nms_img_id2bbox = self._nms_dets(img_id2bbox)
+        aggregate_results = self._reformat_results(nms_img_id2bbox)
+        return aggregate_results
+
+    def _transform_chip2image_bboxes(self, results, records):
+        # 1. Transform chip dets to image dets;
+        # 2. Filter valid range;
+        # 3. Reformat and aggregate chip dets to get scale_cls_dets
+        img_id2bbox = defaultdict(list)
+        for result in results:
+            bbox_locs = result['bbox']
+            bbox_nums = result['bbox_num']
+            if len(bbox_locs) == 1 and bbox_locs[0][0] == -1:  # current batch has no detections
+                # bbox_locs = array([[-1.]], dtype=float32); bbox_nums = [[1]]
+                # MultiClassNMS output: if there are no detected boxes for any image, lod is set to {1} and Out contains a single value, -1.
+                continue
+            im_ids = result['im_id']  # replace with range(len(bbox_nums))
+
+            last_bbox_num = 0
+            for idx, im_id in enumerate(im_ids):
+
+                cur_bbox_len = bbox_nums[idx]
+                bboxes = bbox_locs[last_bbox_num: last_bbox_num + cur_bbox_len]
+                last_bbox_num += cur_bbox_len
+                # box: [num_id, score, xmin, ymin, xmax, ymax]
+                if len(bboxes) == 0:  # current image has no detections
+                    continue
+
+                chip_rec = records[int(im_id) - 1]  # im_id starts from 1, type is np.int64
+                image_size = max(chip_rec["ori_im_h"], chip_rec["ori_im_w"])
+
+                bboxes = transform_chip_boxes2image_boxes(bboxes, chip_rec["chip"], chip_rec["ori_im_h"], chip_rec["ori_im_w"])
+
+                scale_i = chip_rec["scale_i"]
+                cur_scale = self._get_current_scale(self.target_sizes[scale_i], image_size)
+                _, valid_boxes_idx = self._validate_boxes(self.valid_box_ratio_ranges[scale_i], image_size,
+                                                          bboxes[:, 2:], cur_scale)
+                ori_img_id = self._global_chip_id2img_id[int(im_id)]
+
+                img_id2bbox[ori_img_id].append(bboxes[valid_boxes_idx])
+
+        return img_id2bbox
+
+    def _nms_dets(self, img_id2bbox):
+        # 1. NMS on each image-class
+        # 2. Limit the number of detections to MAX_PER_IMAGE if requested
+        max_per_img = self.max_per_img
+        nms_thresh = self.nms_thresh
+
+        for img_id in img_id2bbox:
+            box = img_id2bbox[img_id]  # list of np.array of shape [N, 6], 6 is [label, score, x1, y1, x2, y2]
+            box = np.concatenate(box, axis=0)
+            nms_dets = nms(box, nms_thresh)
+            if max_per_img > 0:
+                if len(nms_dets) > max_per_img:
+                    keep = np.argsort(-nms_dets[:, 1])[:max_per_img]
+                    nms_dets = nms_dets[keep]
+
+            img_id2bbox[img_id] = nms_dets
+
+        return img_id2bbox
+
+    def _reformat_results(self, img_id2bbox):
+        """reformat results"""
+        im_ids = img_id2bbox.keys()
+        results = []
+        for img_id in im_ids:  # output by original im_id order
+            if len(img_id2bbox[img_id]) == 0:
+                bbox = np.array([[-1., 0., 0., 0., 0., 0.]])  # edge case: no detections
+                bbox_num = np.array([0])
+            else:
+                # np.array of shape [N, 6], 6 is [label, score, x1, y1, x2, y2]
+                bbox = img_id2bbox[img_id]
+                bbox_num = np.array([len(bbox)])
+            res = dict(
+                im_id=np.array([[img_id]]),
+                bbox=bbox,
+                bbox_num=bbox_num
+            )
+            results.append(res)
+        return results
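With `AnnoCropper` in place, chip generation can be exercised end to end on a synthetic record; a minimal sketch using the parameters from `configs/datasets/sniper_coco_detection.yml` (the fake record follows the schema documented in `crop_anno_records`):
```python
import numpy as np
from ppdet.data.crop_utils.annotation_cropper import AnnoCropper

# Parameters mirror configs/datasets/sniper_coco_detection.yml.
cropper = AnnoCropper(
    image_target_sizes=[2000, 1000],
    valid_box_ratio_ranges=[[-1, 0.1], [0.08, -1]],
    chip_target_size=512,
    chip_target_stride=200)

# One fake 3000x4000 image with a single small ground-truth box; the record
# keys follow the schema documented in crop_anno_records.
records = [{
    'im_file': 'fake_image1.jpg',
    'im_id': np.array([1]),
    'h': 3000,
    'w': 4000,
    'gt_bbox': np.array([[10., 10., 130., 110.]], dtype=np.float32),
    'is_crowd': np.zeros((1, 1), dtype=np.int32),
    'gt_class': np.zeros((1, 1), dtype=np.int32),
}]

chip_records = cropper.crop_anno_records(records)
# Each output record is one chip, with gt boxes shifted into chip coordinates.
print(len(chip_records), chip_records[0]['chip'], chip_records[0]['gt_bbox'])
```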
diff --git a/ppdet/data/crop_utils/chip_box_utils.py b/ppdet/data/crop_utils/chip_box_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6e81a1654a269e2a28837bc884dc75c21d98ee4
--- /dev/null
+++ b/ppdet/data/crop_utils/chip_box_utils.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+def bbox_area(boxes):
+    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+
+def intersection_over_box(chips, boxes):
+    """
+    intersection area over box area
+    :param chips: Cx4
+    :param boxes: Bx4
+    :return: iob, CxB
+    """
+    M = chips.shape[0]
+    N = boxes.shape[0]
+    if M * N == 0:
+        return np.zeros([M, N], dtype='float32')
+
+    box_area = bbox_area(boxes)  # B
+
+    inter_x2y2 = np.minimum(np.expand_dims(chips, 1)[:, :, 2:], boxes[:, 2:])  # CxBx2
+    inter_x1y1 = np.maximum(np.expand_dims(chips, 1)[:, :, :2], boxes[:, :2])  # CxBx2
+    inter_wh = inter_x2y2 - inter_x1y1
+    inter_wh = np.clip(inter_wh, a_min=0, a_max=None)
+    inter_area = inter_wh[:, :, 0] * inter_wh[:, :, 1]  # CxB
+
+    iob = inter_area / np.expand_dims(box_area, 0)
+    return iob
+
+
+def clip_boxes(boxes, im_shape):
+    """
+    Clip boxes to image boundaries.
+    :param boxes: [N, 4]
+    :param im_shape: tuple of 2, [h, w]
+    :return: [N, 4]
+    """
+    # x1 >= 0
+    boxes[:, 0] = np.clip(boxes[:, 0], 0, im_shape[1] - 1)
+    # y1 >= 0
+    boxes[:, 1] = np.clip(boxes[:, 1], 0, im_shape[0] - 1)
+    # x2 < im_shape[1]
+    boxes[:, 2] = np.clip(boxes[:, 2], 1, im_shape[1])
+    # y2 < im_shape[0]
+    boxes[:, 3] = np.clip(boxes[:, 3], 1, im_shape[0])
+    return boxes
+
+
+def transform_chip_box(gt_bbox: 'Gx4', boxes_idx: 'B', chip: '4'):
+    boxes_idx = np.array(boxes_idx)
+    cur_gt_bbox = gt_bbox[boxes_idx].copy()  # Bx4
+    x1, y1, x2, y2 = chip
+    cur_gt_bbox[:, 0] -= x1
+    cur_gt_bbox[:, 1] -= y1
+    cur_gt_bbox[:, 2] -= x1
+    cur_gt_bbox[:, 3] -= y1
+    h = y2 - y1
+    w = x2 - x1
+    cur_gt_bbox = clip_boxes(cur_gt_bbox, (h, w))
+    ws = (cur_gt_bbox[:, 2] - cur_gt_bbox[:, 0]).astype(np.int32)
+    hs = (cur_gt_bbox[:, 3] - cur_gt_bbox[:, 1]).astype(np.int32)
+    valid_idx = (ws >= 2) & (hs >= 2)
+    return cur_gt_bbox[valid_idx], boxes_idx[valid_idx]
+
+
+def find_chips_to_cover_overlaped_boxes(iob, overlap_threshold):
+    chip_ids, box_ids = np.nonzero(iob >= overlap_threshold)
+    chip_id2overlap_box_num = np.bincount(chip_ids)  # 1d array
+    chip_id2overlap_box_num = np.pad(chip_id2overlap_box_num, (0, len(iob) - len(chip_id2overlap_box_num)),
+                                     constant_values=0)
+
+    chosen_chip_ids = []
+    while len(box_ids) > 0:
+        value_counts = np.bincount(chip_ids)  # 1d array
+        max_count_chip_id = np.argmax(value_counts)
+        assert max_count_chip_id not in chosen_chip_ids
+        chosen_chip_ids.append(max_count_chip_id)
+
+        box_ids_in_cur_chip = box_ids[chip_ids == max_count_chip_id]
+        ids_not_in_cur_boxes_mask = np.logical_not(np.isin(box_ids, box_ids_in_cur_chip))
+        chip_ids = chip_ids[ids_not_in_cur_boxes_mask]
+        box_ids = box_ids[ids_not_in_cur_boxes_mask]
+    return chosen_chip_ids, chip_id2overlap_box_num
+
+
+def transform_chip_boxes2image_boxes(chip_boxes, chip, img_h, img_w):
+    chip_boxes = np.array(sorted(chip_boxes, key=lambda item: -item[1]))
+    xmin, ymin, _, _ = chip
+    # transform to original image coordinates
+    chip_boxes[:, 2] += xmin
+    chip_boxes[:, 4] += xmin
+    chip_boxes[:, 3] += ymin
+    chip_boxes[:, 5] += ymin
+    chip_boxes = clip_boxes(chip_boxes, (img_h, img_w))
+    return chip_boxes
+
+
+def nms(dets, thresh):
+    """Apply classic DPM-style greedy NMS."""
+    if dets.shape[0] == 0:
+        return dets[[], :]
+    scores = dets[:, 1]
+    x1 = dets[:, 2]
+    y1 = dets[:, 3]
+    x2 = dets[:, 4]
+    y2 = dets[:, 5]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    ndets = dets.shape[0]
+    suppressed = np.zeros((ndets), dtype=np.int32)
+
+    # i is the current highest-scoring unsuppressed box; each lower-scoring
+    # box j is suppressed if its IoU with box i reaches the threshold.
+    for _i in range(ndets):
+        i = order[_i]
+        if suppressed[i] == 1:
+            continue
+        ix1 = x1[i]
+        iy1 = y1[i]
+        ix2 = x2[i]
+        iy2 = y2[i]
+        iarea = areas[i]
+        for _j in range(_i + 1, ndets):
+            j = order[_j]
+            if suppressed[j] == 1:
+                continue
+            xx1 = max(ix1, x1[j])
+            yy1 = max(iy1, y1[j])
+            xx2 = min(ix2, x2[j])
+            yy2 = min(iy2, y2[j])
+            w = max(0.0, xx2 - xx1 + 1)
+            h = max(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (iarea + areas[j] - inter)
+            if ovr >= thresh:
+                suppressed[j] = 1
+    keep = np.where(suppressed == 0)[0]
+    dets = dets[keep, :]
+    return dets
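The helpers above are easiest to see on toy inputs; a quick sketch (values chosen by hand):
```python
import numpy as np
from ppdet.data.crop_utils.chip_box_utils import intersection_over_box, nms

# Two chips and one box: IoB is intersection area over *box* area, so a box
# fully inside a chip scores 1 regardless of chip size.
chips = np.array([[0., 0., 512., 512.], [400., 0., 912., 512.]])
boxes = np.array([[10., 10., 110., 110.]])
print(intersection_over_box(chips, boxes))  # [[1.], [0.]]

# Detections are [label, score, x1, y1, x2, y2]; greedy NMS keeps the
# higher-scoring of two heavily overlapping boxes.
dets = np.array([[0, 0.9, 10, 10, 110, 110],
                 [0, 0.8, 12, 12, 112, 112],
                 [0, 0.7, 300, 300, 400, 400]], dtype=np.float32)
print(nms(dets, 0.5))  # keeps the rows with scores 0.9 and 0.7
```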
diff --git a/ppdet/data/source/__init__.py b/ppdet/data/source/__init__.py
index 5e9d59718ce57ad8a5dfb38cd8adb90d3eccace3..3854d3d2530b032b3c84d1ab5f2e01ea963c5c70 100644
--- a/ppdet/data/source/__init__.py
+++ b/ppdet/data/source/__init__.py
@@ -18,6 +18,7 @@ from . import widerface
 from . import category
 from . import keypoint_coco
 from . import mot
+from . import sniper_coco
 
 from .coco import *
 from .voc import *
@@ -25,3 +26,4 @@ from .widerface import *
 from .category import *
 from .keypoint_coco import *
 from .mot import *
+from .sniper_coco import SniperCOCODataSet
diff --git a/ppdet/data/source/category.py b/ppdet/data/source/category.py
index 4f85f5260ea0b836339009ef2e1e55841630e1cb..757b10a3222cfbb2f33001a5eac3e4ad7c3c66e9 100644
--- a/ppdet/data/source/category.py
+++ b/ppdet/data/source/category.py
@@ -39,7 +39,7 @@ def get_categories(metric_type, anno_file=None, arch=None):
     if arch == 'keypoint_arch':
         return (None, {'id': 'keypoint'})
 
-    if metric_type.lower() == 'coco' or metric_type.lower() == 'rbox':
+    if metric_type.lower() in ('coco', 'rbox', 'snipercoco'):
         if anno_file and os.path.isfile(anno_file):
             # lazy import pycocotools here
             from pycocotools.coco import COCO
diff --git a/ppdet/data/source/sniper_coco.py b/ppdet/data/source/sniper_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b07e7a31d999d137965c4860a4d8085d0b91465
--- /dev/null
+++ b/ppdet/data/source/sniper_coco.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os +import cv2 +import json +import copy +import numpy as np + +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence + +from ppdet.core.workspace import register, serializable +from ppdet.data.crop_utils.annotation_cropper import AnnoCropper +from .coco import COCODataSet +from .dataset import _make_dataset, _is_valid_file +from ppdet.utils.logger import setup_logger + +logger = setup_logger('sniper_coco_dataset') + + +@register +@serializable +class SniperCOCODataSet(COCODataSet): + """SniperCOCODataSet""" + + def __init__(self, + dataset_dir=None, + image_dir=None, + anno_path=None, + proposals_file=None, + data_fields=['image'], + sample_num=-1, + load_crowd=False, + allow_empty=True, + empty_ratio=1., + is_trainset=True, + image_target_sizes=[2000, 1000], + valid_box_ratio_ranges=[[-1, 0.1],[0.08, -1]], + chip_target_size=500, + chip_target_stride=200, + use_neg_chip=False, + max_neg_num_per_im=8, + max_per_img=-1, + nms_thresh=0.5): + super(SniperCOCODataSet, self).__init__( + dataset_dir=dataset_dir, + image_dir=image_dir, + anno_path=anno_path, + data_fields=data_fields, + sample_num=sample_num, + load_crowd=load_crowd, + allow_empty=allow_empty, + empty_ratio=empty_ratio + ) + self.proposals_file = proposals_file + self.proposals = None + self.anno_cropper = None + self.is_trainset = is_trainset + self.image_target_sizes = image_target_sizes + self.valid_box_ratio_ranges = valid_box_ratio_ranges + self.chip_target_size = chip_target_size + self.chip_target_stride = chip_target_stride + self.use_neg_chip = use_neg_chip + self.max_neg_num_per_im = max_neg_num_per_im + self.max_per_img = max_per_img + self.nms_thresh = nms_thresh + + + def parse_dataset(self): + if not hasattr(self, "roidbs"): + super(SniperCOCODataSet, self).parse_dataset() + if self.is_trainset: + self._parse_proposals() + self._merge_anno_proposals() + self.ori_roidbs = copy.deepcopy(self.roidbs) + self.init_anno_cropper() + self.roidbs = self.generate_chips_roidbs(self.roidbs, self.is_trainset) + + def set_proposals_file(self, file_path): + self.proposals_file = file_path + + def init_anno_cropper(self): + logger.info("Init AnnoCropper...") + self.anno_cropper = AnnoCropper( + image_target_sizes=self.image_target_sizes, + valid_box_ratio_ranges=self.valid_box_ratio_ranges, + chip_target_size=self.chip_target_size, + chip_target_stride=self.chip_target_stride, + use_neg_chip=self.use_neg_chip, + max_neg_num_per_im=self.max_neg_num_per_im, + max_per_img=self.max_per_img, + nms_thresh=self.nms_thresh + ) + + def generate_chips_roidbs(self, roidbs, is_trainset): + if is_trainset: + roidbs = self.anno_cropper.crop_anno_records(roidbs) + else: + roidbs = self.anno_cropper.crop_infer_anno_records(roidbs) + return roidbs + + def _parse_proposals(self): + if self.proposals_file: + self.proposals = {} + logger.info("Parse proposals file:{}".format(self.proposals_file)) + with open(self.proposals_file, 'r') as f: + proposals = json.load(f) + for prop in proposals: + image_id = prop["image_id"] + if image_id not in self.proposals: + self.proposals[image_id] = [] + x, y, w, h = prop["bbox"] + self.proposals[image_id].append([x, y, x + w, y + h]) + + def _merge_anno_proposals(self): + assert self.roidbs + if self.proposals and len(self.proposals.keys()) > 0: + logger.info("merge proposals to annos") + for id, record in enumerate(self.roidbs): + image_id = int(record["im_id"]) + if image_id not in self.proposals.keys(): + logger.info("image id :{} no 
proposals".format(image_id)) + record["proposals"] = np.array(self.proposals.get(image_id, []), dtype=np.float32) + self.roidbs[id] = record + + def get_ori_roidbs(self): + if not hasattr(self, "ori_roidbs"): + return None + return self.ori_roidbs + + def get_roidbs(self): + if not hasattr(self, "roidbs"): + self.parse_dataset() + return self.roidbs + + def set_roidbs(self, roidbs): + self.roidbs = roidbs + + def check_or_download_dataset(self): + return + + def _parse(self): + image_dir = self.image_dir + if not isinstance(image_dir, Sequence): + image_dir = [image_dir] + images = [] + for im_dir in image_dir: + if os.path.isdir(im_dir): + im_dir = os.path.join(self.dataset_dir, im_dir) + images.extend(_make_dataset(im_dir)) + elif os.path.isfile(im_dir) and _is_valid_file(im_dir): + images.append(im_dir) + return images + + def _load_images(self): + images = self._parse() + ct = 0 + records = [] + for image in images: + assert image != '' and os.path.isfile(image), \ + "Image {} not found".format(image) + if self.sample_num > 0 and ct >= self.sample_num: + break + im = cv2.imread(image) + h, w, c = im.shape + rec = {'im_id': np.array([ct]), 'im_file': image, "h": h, "w": w} + self._imid2path[ct] = image + ct += 1 + records.append(rec) + assert len(records) > 0, "No image file found" + return records + + def get_imid2path(self): + return self._imid2path + + def set_images(self, images): + self._imid2path = {} + self.image_dir = images + self.roidbs = self._load_images() + diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index c9c93196643c629c2f516aa30b3c0f48f4128411..ae2d799701cdc448202ce1ab2bef97358970666e 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -233,6 +233,41 @@ class DecodeCache(BaseOperator): MUTEX.release() +@register_op +class SniperDecodeCrop(BaseOperator): + def __init__(self): + super(SniperDecodeCrop, self).__init__() + + def __call__(self, sample, context=None): + if 'image' not in sample: + with open(sample['im_file'], 'rb') as f: + sample['image'] = f.read() + sample.pop('im_file') + + im = sample['image'] + data = np.frombuffer(im, dtype='uint8') + im = cv2.imdecode(data, cv2.IMREAD_COLOR) # BGR mode, but need RGB mode + if 'keep_ori_im' in sample and sample['keep_ori_im']: + sample['ori_image'] = im + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + + chip = sample['chip'] + x1, y1, x2, y2 = [int(xi) for xi in chip] + im = im[max(y1, 0):min(y2, im.shape[0]), + max(x1, 0):min(x2, im.shape[1]), :] + + sample['image'] = im + h = im.shape[0] + w = im.shape[1] + # sample['im_info'] = [h, w, 1.0] + sample['h'] = h + sample['w'] = w + + sample['im_shape'] = np.array(im.shape[:2], dtype=np.float32) + sample['scale_factor'] = np.array([1., 1.], dtype=np.float32) + return sample + + @register_op class Permute(BaseOperator): def __init__(self): diff --git a/ppdet/engine/callbacks.py b/ppdet/engine/callbacks.py index 70dbf33522398666f9bc5c0ba95ba6eddbbeff92..df42a687c6c9cda73306e1b8c4a528489add71db 100644 --- a/ppdet/engine/callbacks.py +++ b/ppdet/engine/callbacks.py @@ -20,15 +20,19 @@ import os import sys import datetime import six +import copy +import json +import paddle import paddle.distributed as dist from ppdet.utils.checkpoint import save_model +from ppdet.metrics import get_infer_results from ppdet.utils.logger import setup_logger logger = setup_logger('ppdet.engine') -__all__ = ['Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer'] +__all__ = ['Callback', 'ComposeCallback', 'LogPrinter', 
'Checkpointer', 'VisualDLWriter', 'SniperProposalsGenerator'] class Callback(object): @@ -47,6 +51,12 @@ class Callback(object): def on_epoch_end(self, status): pass + def on_train_begin(self, status): + pass + + def on_train_end(self, status): + pass + class ComposeCallback(object): def __init__(self, callbacks): @@ -72,6 +82,14 @@ class ComposeCallback(object): for c in self._callbacks: c.on_epoch_end(status) + def on_train_begin(self, status): + for c in self._callbacks: + c.on_train_begin(status) + + def on_train_end(self, status): + for c in self._callbacks: + c.on_train_end(status) + class LogPrinter(Callback): def __init__(self, model): @@ -256,3 +274,62 @@ class VisualDLWriter(Callback): map_value[0], self.vdl_mAP_step) self.vdl_mAP_step += 1 + + +class SniperProposalsGenerator(Callback): + def __init__(self, model): + super(SniperProposalsGenerator, self).__init__(model) + ori_dataset = self.model.dataset + self.dataset = self._create_new_dataset(ori_dataset) + self.loader = self.model.loader + self.cfg = self.model.cfg + self.infer_model = self.model.model + + def _create_new_dataset(self, ori_dataset): + dataset = copy.deepcopy(ori_dataset) + # init anno_cropper + dataset.init_anno_cropper() + # generate infer roidbs + ori_roidbs = dataset.get_ori_roidbs() + roidbs = dataset.anno_cropper.crop_infer_anno_records(ori_roidbs) + # set new roidbs + dataset.set_roidbs(roidbs) + + return dataset + + def _eval_with_loader(self, loader): + results = [] + with paddle.no_grad(): + self.infer_model.eval() + for step_id, data in enumerate(loader): + outs = self.infer_model(data) + for key in ['im_shape', 'scale_factor', 'im_id']: + outs[key] = data[key] + for key, value in outs.items(): + if hasattr(value, 'numpy'): + outs[key] = value.numpy() + + results.append(outs) + + return results + + def on_train_end(self, status): + self.loader.dataset = self.dataset + results = self._eval_with_loader(self.loader) + results = self.dataset.anno_cropper.aggregate_chips_detections(results) + # sniper + proposals = [] + clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} + for outs in results: + batch_res = get_infer_results(outs, clsid2catid) + start = 0 + for i, im_id in enumerate(outs['im_id']): + bbox_num = outs['bbox_num'] + end = start + bbox_num[i] + bbox_res = batch_res['bbox'][start:end] \ + if 'bbox' in batch_res else None + if bbox_res: + proposals += bbox_res + logger.info("save proposals in {}".format(self.cfg.proposals_path)) + with open(self.cfg.proposals_path, 'w') as f: + json.dump(proposals, f) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index d1bd1ed0ae34f5711ed29e96ea2b6bf662ad85d7..f6e12de606a83746a6b33e69b03a735403d6a2e2 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -35,12 +35,13 @@ from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.visualizer import visualize_results, save_result from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval -from ppdet.metrics import RBoxMetric, JDEDetMetric +from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric +from ppdet.data.source.sniper_coco import SniperCOCODataSet from ppdet.data.source.category import get_categories import ppdet.utils.stats as stats from ppdet.utils import profiler -from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter +from .callbacks import 
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index d1bd1ed0ae34f5711ed29e96ea2b6bf662ad85d7..f6e12de606a83746a6b33e69b03a735403d6a2e2 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -35,12 +35,13 @@ from ppdet.core.workspace import create
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
 from ppdet.utils.visualizer import visualize_results, save_result
 from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
-from ppdet.metrics import RBoxMetric, JDEDetMetric
+from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
+from ppdet.data.source.sniper_coco import SniperCOCODataSet
 from ppdet.data.source.category import get_categories
 import ppdet.utils.stats as stats
 from ppdet.utils import profiler
 
-from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
+from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
 from .export_utils import _dump_infer_config, _prune_input_spec
 
 from ppdet.utils.logger import setup_logger
@@ -137,6 +138,8 @@ class Trainer(object):
             self._callbacks = [LogPrinter(self), Checkpointer(self)]
             if self.cfg.get('use_vdl', False):
                 self._callbacks.append(VisualDLWriter(self))
+            if self.cfg.get('save_proposals', False):
+                self._callbacks.append(SniperProposalsGenerator(self))
             self._compose_callback = ComposeCallback(self._callbacks)
         elif self.mode == 'eval':
             self._callbacks = [LogPrinter(self)]
@@ -155,7 +158,7 @@ class Trainer(object):
             self._metrics = []
             return
         classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
-        if self.cfg.metric == 'COCO':
+        if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
             # TODO: bias should be unified
             bias = self.cfg['bias'] if 'bias' in self.cfg else 0
             output_eval = self.cfg['output_eval'] \
@@ -170,22 +173,38 @@ class Trainer(object):
             # when do validation in train, annotation file should be get from
             # EvalReader instead of self.dataset(which is TrainReader)
             anno_file = self.dataset.get_anno()
+            dataset = self.dataset
             if self.mode == 'train' and validate:
                 eval_dataset = self.cfg['EvalDataset']
                 eval_dataset.check_or_download_dataset()
                 anno_file = eval_dataset.get_anno()
+                dataset = eval_dataset
 
             IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
-            self._metrics = [
-                COCOMetric(
-                    anno_file=anno_file,
-                    clsid2catid=clsid2catid,
-                    classwise=classwise,
-                    output_eval=output_eval,
-                    bias=bias,
-                    IouType=IouType,
-                    save_prediction_only=save_prediction_only)
-            ]
+            if self.cfg.metric == "COCO":
+                self._metrics = [
+                    COCOMetric(
+                        anno_file=anno_file,
+                        clsid2catid=clsid2catid,
+                        classwise=classwise,
+                        output_eval=output_eval,
+                        bias=bias,
+                        IouType=IouType,
+                        save_prediction_only=save_prediction_only)
+                ]
+            elif self.cfg.metric == "SNIPERCOCO":  # sniper
+                self._metrics = [
+                    SNIPERCOCOMetric(
+                        anno_file=anno_file,
+                        dataset=dataset,
+                        clsid2catid=clsid2catid,
+                        classwise=classwise,
+                        output_eval=output_eval,
+                        bias=bias,
+                        IouType=IouType,
+                        save_prediction_only=save_prediction_only
+                    )
+                ]
         elif self.cfg.metric == 'RBOX':
             # TODO: bias should be unified
             bias = self.cfg['bias'] if 'bias' in self.cfg else 0
@@ -342,6 +361,8 @@ class Trainer(object):
             self._flops(self.loader)
         profiler_options = self.cfg.get('profiler_options', None)
 
+        self._compose_callback.on_train_begin(self.status)
+
         for epoch_id in range(self.start_epoch, self.cfg.epoch):
             self.status['mode'] = 'train'
             self.status['epoch_id'] = epoch_id
@@ -424,6 +445,8 @@ class Trainer(object):
             if self.use_ema:
                 self.model.set_dict(weight)
 
+        self._compose_callback.on_train_end(self.status)
+
     def _eval_with_loader(self, loader):
         sample_num = 0
         tic = time.time()
@@ -479,6 +502,7 @@ class Trainer(object):
         self.model.eval()
         if self.cfg.get('print_flops', False):
             self._flops(loader)
+        results = []
         for step_id, data in enumerate(loader):
             self.status['step_id'] = step_id
             # forward
@@ -489,7 +513,12 @@ class Trainer(object):
             for key, value in outs.items():
                 if hasattr(value, 'numpy'):
                     outs[key] = value.numpy()
+            results.append(outs)
 
+        # sniper
+        if type(self.dataset) == SniperCOCODataSet:
+            results = self.dataset.anno_cropper.aggregate_chips_detections(results)
+        for outs in results:
             batch_res = get_infer_results(outs, clsid2catid)
             bbox_num = outs['bbox_num']
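
Both the eval loop above and the metric below delegate to `anno_cropper.aggregate_chips_detections` to turn per-chip detections back into per-image detections. The real helper also rescales boxes and NMS-merges across overlapping chips; one core geometric step, offsetting boxes by the chip origin, can be illustrated stand-alone (this is not the PaddleDetection API, just a sketch of the idea):

```python
import numpy as np

def chip_to_image(boxes, chip_x1y1):
    """Shift (N, 4) x1,y1,x2,y2 boxes from chip coords to image coords."""
    offset = np.tile(np.asarray(chip_x1y1, dtype=np.float32), 2)
    return boxes + offset

chip_boxes = np.array([[10., 20., 50., 80.]])
print(chip_to_image(chip_boxes, (200, 300)))  # [[210. 320. 250. 380.]]
```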
diff --git a/ppdet/metrics/metrics.py b/ppdet/metrics/metrics.py
index 65b18efd82eb6f47fa9e7c2da550663693b00e59..f9913b7fb81713d87e26e176f80b7d26848808d5 100644
--- a/ppdet/metrics/metrics.py
+++ b/ppdet/metrics/metrics.py
@@ -37,6 +37,7 @@ __all__ = [
     'WiderFaceMetric',
     'get_infer_results',
     'RBoxMetric',
+    'SNIPERCOCOMetric'
 ]
 
 COCO_SIGMAS = np.array([
@@ -395,3 +396,37 @@ class RBoxMetric(Metric):
 
     def get_results(self):
         return {'bbox': [self.detection_map.get_map()]}
+
+
+class SNIPERCOCOMetric(COCOMetric):
+    def __init__(self, anno_file, **kwargs):
+        super(SNIPERCOCOMetric, self).__init__(anno_file, **kwargs)
+        self.dataset = kwargs["dataset"]
+        self.chip_results = []
+
+    def reset(self):
+        # only bbox and mask evaluation are supported currently
+        self.results = {'bbox': [], 'mask': [], 'segm': [], 'keypoint': []}
+        self.eval_results = {}
+        self.chip_results = []
+
+    def update(self, inputs, outputs):
+        outs = {}
+        # outputs: Tensor -> numpy.ndarray
+        for k, v in outputs.items():
+            outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v
+
+        im_id = inputs['im_id']
+        outs['im_id'] = im_id.numpy() if isinstance(im_id,
+                                                    paddle.Tensor) else im_id
+
+        self.chip_results.append(outs)
+
+
+    def accumulate(self):
+        results = self.dataset.anno_cropper.aggregate_chips_detections(self.chip_results)
+        for outs in results:
+            infer_results = get_infer_results(outs, self.clsid2catid, bias=self.bias)
+            self.results['bbox'] += infer_results['bbox'] if 'bbox' in infer_results else []
+
+        super(SNIPERCOCOMetric, self).accumulate()
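
The script added next derives the chip parameters used by the `SniperCOCODataSet` configs. Besides its CLI form, it can be called programmatically; a minimal sketch, assuming it runs from the PaddleDetection repo root and the VisDrone annotation file exists at the path used in the dataset configs:

```python
# Minimal programmatic use of sniper_anno_stats (function from the file below).
import json
import sys

sys.path.insert(0, "tools")  # tools/ is a plain directory, not a package
from sniper_params_stats import sniper_anno_stats

stats = sniper_anno_stats("FasterRCNN",
                          "dataset/VisDrone2019_coco/annotations/train.json")
# The dict maps one-to-one onto the SniperCOCODataSet fields:
# image_target_sizes, valid_box_ratio_ranges, chip_target_size,
# chip_target_stride.
print(json.dumps(stats, indent=2))
```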
diff --git a/tools/sniper_params_stats.py b/tools/sniper_params_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..358aa63c5fa46812cccafda3cd2b42d65f89f02c
--- /dev/null
+++ b/tools/sniper_params_stats.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import json
+import logging
+import numpy as np
+
+from ppdet.utils.logger import setup_logger
+logger = setup_logger('sniper_params_stats')
+
+
+def get_default_params(architecture):
+    """get_default_params"""
+    if architecture == "FasterRCNN":
+        anchor_range = np.array([64., 512.])  # for frcnn-fpn
+        # anchor_range = np.array([16., 373.])  # for yolov3
+        # anchor_range = np.array([32., 373.])  # for yolov3
+        default_crop_size = 1536  # a multiple of 32 for frcnn-fpn
+        default_max_bbox_size = 352
+    elif architecture == "YOLOv3":
+        anchor_range = np.array([32., 373.])  # for yolov3
+        default_crop_size = 800  # a multiple of 32 for yolov3
+        default_max_bbox_size = 352
+    else:
+        raise NotImplementedError
+
+    return anchor_range, default_crop_size, default_max_bbox_size
+
+
+def get_box_ratios(anno_file):
+    """
+    get_box_ratios
+    :param anno_file: coco annotation file
+    :return: size_ratio: (box_long_size / pic_long_size)
+    """
+    coco_dict = json.load(open(anno_file))
+    image_list = coco_dict['images']
+    anno_list = coco_dict['annotations']
+
+    image_id2hw = {}
+    for im_dict in image_list:
+        im_id = im_dict['id']
+        h, w = im_dict['height'], im_dict['width']
+        image_id2hw[im_id] = (h, w)
+
+    box_ratios = []
+    for a_dict in anno_list:
+        im_id = a_dict['image_id']
+        im_h, im_w = image_id2hw[im_id]
+        bbox = a_dict['bbox']
+        x1, y1, w, h = bbox
+        pic_long = max(im_h, im_w)
+        box_long = max(w, h)
+        box_ratios.append(box_long / pic_long)
+
+    return np.array(box_ratios)
+
+
+def get_target_size_and_valid_box_ratios(anchor_range, box_ratio_p2, box_ratio_p98):
+    """get_target_size_and_valid_box_ratios"""
+    anchor_better_low, anchor_better_high = anchor_range  # (60., 512.)
+    anchor_center = np.sqrt(anchor_better_high * anchor_better_low)
+
+    anchor_log_range = np.log10(anchor_better_high) - np.log10(anchor_better_low)
+    box_ratio_log_range = np.log10(box_ratio_p98) - np.log10(box_ratio_p2)
+    logger.info("anchor_log_range:{}, box_ratio_log_range:{}".format(anchor_log_range, box_ratio_log_range))
+
+    box_cut_num = int(np.ceil(box_ratio_log_range / anchor_log_range))
+    box_ratio_log_window = box_ratio_log_range / box_cut_num
+    logger.info("box_cut_num:{}, box_ratio_log_window:{}".format(box_cut_num, box_ratio_log_window))
+
+    image_target_sizes = []
+    valid_ratios = []
+    for i in range(box_cut_num):
+        # # method1: align center
+        # box_ratio_log_center = np.log10(p2) + 0.5 * box_ratio_log_window + i * box_ratio_log_window
+        # box_ratio_center = np.power(10, box_ratio_log_center)
+        # scale = anchor_center / box_ratio_center
+        # method2: align left low
+        box_ratio_low = np.power(10, np.log10(box_ratio_p2) + i * box_ratio_log_window)
+        image_target_size = anchor_better_low / box_ratio_low
+
+        image_target_sizes.append(int(image_target_size))
+        valid_ratio = anchor_range / image_target_size
+        valid_ratios.append(valid_ratio.tolist())
+
+        logger.info("Box cut {}".format(i))
+        logger.info("box_ratio_low: {}".format(box_ratio_low))
+        logger.info("image_target_size: {}".format(image_target_size))
+        logger.info("valid_ratio: {}".format(valid_ratio))
+
+    return image_target_sizes, valid_ratios
+
+
+def get_valid_ranges(valid_ratios):
+    """
+    get_valid_ranges
+    :param valid_ratios: per-scale [low, high] valid box ratios
+    :return: valid ranges, open-ended (-1) at the extremes
+    """
+    valid_ranges = []
+    if len(valid_ratios) == 1:
+        valid_ranges.append([-1, -1])
+    else:
+        for i, vratio in enumerate(valid_ratios):
+            if i == 0:
+                valid_ranges.append([-1, vratio[1]])
+            elif i == len(valid_ratios) - 1:
+                valid_ranges.append([vratio[0], -1])
+            else:
+                valid_ranges.append(vratio)
+    return valid_ranges
+
+
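+# Worked example of the windowing above (illustrative numbers): with
+# anchor_range = [64., 512.] the anchor log range is log10(512/64) ~= 0.90;
+# if the low/high box-ratio percentiles are 0.01 and 0.4, the box-ratio log
+# range is log10(0.4/0.01) ~= 1.60, so box_cut_num = ceil(1.60/0.90) = 2
+# scales are derived, each targeting image size anchor_low / box_ratio_low.
+
+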
+def get_percentile(a_array, low_percent, high_percent):
+    """
+    get the low/high percentiles of a_array
+    :param low_percent: lower percentile, e.g. 8
+    :param high_percent: upper percentile, e.g. 92
+    :return: the two percentile values
+    """
+    array_p0 = min(a_array)
+    array_p100 = max(a_array)
+    array_plow = np.percentile(a_array, low_percent)
+    array_phigh = np.percentile(a_array, high_percent)
+    logger.info(
+        "array_percentile(0): {}, array_percentile low({}): {}, "
+        "array_percentile high({}): {}, array_percentile 100: {}".format(
+            array_p0, low_percent, array_plow, high_percent, array_phigh, array_p100))
+    return array_plow, array_phigh
+
+
+def sniper_anno_stats(architecture, anno_file):
+    """
+    sniper_anno_stats
+    :param anno_file: coco annotation file
+    :return: dict of SNIPER chip parameters
+    """
+
+    anchor_range, default_crop_size, default_max_bbox_size = get_default_params(architecture)
+
+    box_ratios = get_box_ratios(anno_file)
+
+    box_ratio_p8, box_ratio_p92 = get_percentile(box_ratios, 8, 92)
+
+    image_target_sizes, valid_box_ratios = get_target_size_and_valid_box_ratios(anchor_range, box_ratio_p8, box_ratio_p92)
+
+    valid_ranges = get_valid_ranges(valid_box_ratios)
+
+    crop_size = min(default_crop_size, min([item for item in image_target_sizes]))
+    crop_size = int(np.ceil(crop_size / 32.) * 32.)
+    crop_stride = max(min(default_max_bbox_size, crop_size), crop_size - default_max_bbox_size)
+    logger.info("Result".center(100, '-'))
+    logger.info("image_target_sizes: {}".format(image_target_sizes))
+    logger.info("valid_box_ratio_ranges: {}".format(valid_ranges))
+    logger.info("chip_target_size: {}, chip_target_stride: {}".format(crop_size, crop_stride))
+
+    return {
+        "image_target_sizes": image_target_sizes,
+        "valid_box_ratio_ranges": valid_ranges,
+        "chip_target_size": crop_size,
+        "chip_target_stride": crop_stride
+    }
+
+
+if __name__ == "__main__":
+    architecture, anno_file = sys.argv[1], sys.argv[2]
+    sniper_anno_stats(architecture, anno_file)
diff --git a/tools/train.py b/tools/train.py
index 08a3f26f91df94505c3125734b0f7cd09c002713..878aa60fac7d89effc2a5b38832a56105936baa5 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -88,6 +88,17 @@ def parse_args():
         help="The option of profiler, which should be in "
         "format \"key1=value1;key2=value2;key3=value3\"."
         "please see ppdet/utils/profiler.py for detail.")
+    parser.add_argument(
+        '--save_proposals',
+        action='store_true',
+        default=False,
+        help='Whether to save the train proposals')
+    parser.add_argument(
+        '--proposals_path',
+        type=str,
+        default="sniper/proposals.json",
+        help='Path of the saved train proposals file')
+
     args = parser.parse_args()
     return args
 
@@ -125,6 +136,8 @@ def main():
     cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
     cfg['save_prediction_only'] = FLAGS.save_prediction_only
     cfg['profiler_options'] = FLAGS.profiler_options
+    cfg['save_proposals'] = FLAGS.save_proposals
+    cfg['proposals_path'] = FLAGS.proposals_path
     merge_config(FLAGS.opt)
 
     # disable npu in config by default
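
Finally, with `--save_proposals` set, `SniperProposalsGenerator` writes its aggregated detections to `--proposals_path` at the end of training. A quick sanity check of that file; the record keys are assumed from the COCO-style output of `get_infer_results`:

```python
import json

# Default path from the --proposals_path flag added above.
with open("sniper/proposals.json") as f:
    proposals = json.load(f)

# Each record should be a COCO-style detection dict, e.g.
# {"image_id": ..., "category_id": ..., "bbox": [x, y, w, h], "score": ...};
# indexing proposals[0] assumes the file is non-empty.
print(len(proposals), proposals[0])
```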