未验证 提交 0f75cb6c 编写于 作者: F Feng Ni 提交者: GitHub

[smalldet] add SAHI slice train eval infer deploy (#6465)

* add slice infer for smalldet

* add slice infer for smalldet

* fix slice infer

* fix doc, test=document_fix

* fix eval and configs

* add slice dataset loader

* fix deploy

* fix docs, test=document_fix

* update docs, test=document_fix
上级 4cecb62c
# PP-YOLOE Smalldet 检测模型 # PP-YOLOE 小目标检测模型
PaddleDetection团队提供了针对VisDrone-DET、DOTA水平框、Xview等小目标场景数据集的基于PP-YOLOE的检测模型,以及提供了一套使用[SAHI](https://github.com/obss/sahi)(Slicing Aided Hyper Inference)工具切图和拼图的方案,用户可以下载模型进行使用。
<img src="https://user-images.githubusercontent.com/82303451/182520025-f6bd1c76-a9f9-4f8c-af9b-b37a403258d8.png" title="VisDrone" alt="VisDrone" width="300"><img src="https://user-images.githubusercontent.com/82303451/182521833-4aa0314c-b3f2-4711-9a65-cabece612737.png" title="VisDrone" alt="VisDrone" width="300"><img src="https://user-images.githubusercontent.com/82303451/182520038-cacd5d09-0b85-475c-8e59-72f1fc48eef8.png" title="DOTA" alt="DOTA" height="168"><img src="https://user-images.githubusercontent.com/82303451/182524123-dcba55a2-ce2d-4ba1-9d5b-eb99cb440715.jpeg" title="Xview" alt="Xview" height="168"> <img src="https://user-images.githubusercontent.com/82303451/182520025-f6bd1c76-a9f9-4f8c-af9b-b37a403258d8.png" title="VisDrone" alt="VisDrone" width="300"><img src="https://user-images.githubusercontent.com/82303451/182521833-4aa0314c-b3f2-4711-9a65-cabece612737.png" title="VisDrone" alt="VisDrone" width="300"><img src="https://user-images.githubusercontent.com/82303451/182520038-cacd5d09-0b85-475c-8e59-72f1fc48eef8.png" title="DOTA" alt="DOTA" height="168"><img src="https://user-images.githubusercontent.com/82303451/182524123-dcba55a2-ce2d-4ba1-9d5b-eb99cb440715.jpeg" title="Xview" alt="Xview" height="168">
## 基础模型:
| 模型 | 数据集 | SLICE_SIZE | OVERLAP_RATIO | 类别数 | mAP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | 下载链接 | 配置文件 | | 模型 | 数据集 | SLICE_SIZE | OVERLAP_RATIO | 类别数 | mAP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | 下载链接 | 配置文件 |
|:---------|:---------------:|:---------------:|:---------------:|:------:|:-----------------------:|:-------------------:|:---------:| :-----: | |:---------|:---------------:|:---------------:|:---------------:|:------:|:-----------------------:|:-------------------:|:---------:| :-----: |
|PP-YOLOE-l| Xview | 400 | 0.25 | 60 | 14.5 | 26.8 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_xview_400_025.pdparams) | [配置文件](./ppyoloe_crn_l_80e_sliced_xview_400_025.yml) | |PP-YOLOE-P2-l| DOTA | 500 | 0.25 | 15 | 53.9 | 78.6 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_p2_crn_l_80e_sliced_DOTA_500_025.pdparams) | [配置文件](./ppyoloe_p2_crn_l_80e_sliced_DOTA_500_025.yml) |
|PP-YOLOE-l| DOTA | 500 | 0.25 | 15 | 46.8 | 72.6 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_dota_500_025.pdparams) | [配置文件](./ppyoloe_crn_l_80e_sliced_DOTA_500_025.yml) | |PP-YOLOE-P2-l| Xview | 400 | 0.25 | 60 | 14.9 | 27.0 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_p2_crn_l_80e_sliced_xview_400_025.pdparams) | [配置文件](./ppyoloe_p2_crn_l_80e_sliced_xview_400_025.yml) |
|PP-YOLOE-l| VisDrone | 500 | 0.25 | 10 | 29.7 | 48.5 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams) | [配置文件](./ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml) | |PP-YOLOE-l| VisDrone-DET| 640 | 0.25 | 10 | 38.5 | 60.2 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams) | [配置文件](./ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml) |
## 原图评估和拼图评估对比:
| 模型 | 数据集 | SLICE_SIZE | OVERLAP_RATIO | 类别数 | mAP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | 下载链接 | 配置文件 |
|:---------|:---------------:|:---------------:|:---------------:|:------:|:-----------------------:|:-------------------:|:---------:| :-----: |
|PP-YOLOE-l| VisDrone-DET| 640 | 0.25 | 10 | 29.7 | 48.5 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams) | [配置文件](./ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml) |
|PP-YOLOE-l (Assembled)| VisDrone-DET| 640 | 0.25 | 10 | 37.2 | 59.4 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams) | [配置文件](./ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml) |
**注意:** **注意:**
- **SLICE_SIZE**表示使用SAHI工具切图后子图的大小(SLICE_SIZE*SLICE_SIZE);**OVERLAP_RATIO**表示切图重叠率。 - 使用[SAHI](https://github.com/obss/sahi)切图工具需要首先安装:`pip install sahi`,参考[installation](https://github.com/obss/sahi/blob/main/README.md#installation)
- **SLICE_SIZE**表示使用SAHI工具切图后子图的边长大小,**OVERLAP_RATIO**表示切图的子图之间的重叠率,DOTA水平框和Xview数据集均是切图后训练,AP指标为切图后的子图val上的指标。
- VisDrone-DET数据集请参照[visdrone](../visdrone),可使用原图训练,也可使用切图后训练。
- PP-YOLOE模型训练过程中使用8 GPUs进行混合精度训练,如果**GPU卡数**或者**batch size**发生了改变,你需要按照公式 **lr<sub>new</sub> = lr<sub>default</sub> * (batch_size<sub>new</sub> * GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> * GPU_number<sub>default</sub>)** 调整学习率。 - PP-YOLOE模型训练过程中使用8 GPUs进行混合精度训练,如果**GPU卡数**或者**batch size**发生了改变,你需要按照公式 **lr<sub>new</sub> = lr<sub>default</sub> * (batch_size<sub>new</sub> * GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> * GPU_number<sub>default</sub>)** 调整学习率。
- 具体使用教程请参考[ppyoloe](../ppyoloe#getting-start) - 常用训练验证部署等步骤请参考[ppyoloe](../ppyoloe#getting-start)
- 自动切图和拼图的推理预测需添加设置`--slice_infer`,具体见下文使用说明。
- Assembled表示自动切图和拼图。
# 使用说明
## 1.训练
首先将你的数据集为COCO数据集格式,然后使用SAHI切图工具进行离线切图,对保存的子图按常规检测模型的训练流程走即可。
也可直接下载PaddleDetection团队提供的切图后的VisDrone-DET、DOTA水平框、Xview数据集。
执行以下指令使用混合精度训练PP-YOLOE
```bash
python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml --amp --eval
```
**注意:**
- 使用默认配置训练需要设置`--amp`以避免显存溢出。
## 2.评估
### 2.1 子图评估:
默认评估方式是子图评估,子图数据集的验证集设置为:
```
EvalDataset:
!COCODataSet
image_dir: val_images_640_025
anno_path: val_640_025.json
dataset_dir: dataset/visdrone_sliced
```
按常规检测模型的评估流程,评估提前切好并存下来的子图上的精度:
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams
```
### 2.2 原图评估:
修改验证集的标注文件路径为原图标注文件:
```
EvalDataset:
!COCODataSet
image_dir: VisDrone2019-DET-val
anno_path: val.json
dataset_dir: dataset/visdrone
```
直接评估原图上的精度:
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams
```
### 2.3 子图拼图评估:
修改验证集的标注文件路径为原图标注文件:
```
# very slow, preferly eval with a determined weights(xx.pdparams)
# if you want to eval during training, change SlicedCOCODataSet to COCODataSet and delete sliced_size and overlap_ratio
EvalDataset:
!SlicedCOCODataSet
image_dir: VisDrone2019-DET-val
anno_path: val.json
dataset_dir: dataset/visdrone
sliced_size: [640, 640]
overlap_ratio: [0.25, 0.25]
```
会在评估过程中自动对原图进行切图最后再重组和融合结果来评估原图上的精度:
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025_slice_infer.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams --slice_infer --combine_method=nms --match_threshold=0.6 --match_metric=ios
```
- 设置`--slice_infer`表示切图预测并拼装重组结果,如果不使用则不写;
- 设置`--slice_size`表示切图的子图尺寸大小,设置`--overlap_ratio`表示子图间重叠率;
- 设置`--combine_method`表示子图结果重组去重的方式,默认是`nms`
- 设置`--match_threshold`表示子图结果重组去重的阈值,默认是0.6;
- 设置`--match_metric`表示子图结果重组去重的度量标准,默认是`ios`表示交小比(两个框交集面积除以更小框的面积),也可以选择交并比`iou`(两个框交集面积除以并集面积),精度效果因数据集而而异,但选择`ios`预测速度会更快一点;
**注意:**
- 设置`--slice_infer`表示切图预测并拼装重组结果,如果不使用则不写,注意需要确保EvalDataset的数据集类是选用的SlicedCOCODataSet而不是COCODataSet;
- 可以自行修改选择合适的子图尺度sliced_size和子图间重叠率overlap_ratio,如:
```
EvalDataset:
!SlicedCOCODataSet
image_dir: VisDrone2019-DET-val
anno_path: val.json
dataset_dir: dataset/visdrone
sliced_size: [480, 480]
overlap_ratio: [0.2, 0.2]
```
- 设置`--combine_method`表示子图结果重组去重的方式,默认是`nms`
- 设置`--match_threshold`表示子图结果重组去重的阈值,默认是0.6;
- 设置`--match_metric`表示子图结果重组去重的度量标准,默认是`ios`表示交小比(两个框交集面积除以更小框的面积),也可以选择交并比`iou`(两个框交集面积除以并集面积),精度效果因数据集而而异,但选择`ios`预测速度会更快一点;
## 3.预测
### 3.1 子图或原图直接预测:
与评估流程基本相同,可以在提前切好并存下来的子图上预测,也可以对原图预测,如:
```bash
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams --infer_img=demo.jpg --draw_threshold=0.25
```
### 3.2 原图自动切图并拼图预测:
也可以对原图进行自动切图并拼图重组来预测原图,如:
```bash
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams --infer_img=demo.jpg --draw_threshold=0.25 --slice_infer --slice_size 640 640 --overlap_ratio 0.25 0.25 --combine_method=nms --match_threshold=0.6 --match_metric=ios
```
- 设置`--slice_infer`表示切图预测并拼装重组结果,如果不使用则不写;
- 设置`--slice_size`表示切图的子图尺寸大小,设置`--overlap_ratio`表示子图间重叠率;
- 设置`--combine_method`表示子图结果重组去重的方式,默认是`nms`
- 设置`--match_threshold`表示子图结果重组去重的阈值,默认是0.6;
- 设置`--match_metric`表示子图结果重组去重的度量标准,默认是`ios`表示交小比(两个框交集面积除以更小框的面积),也可以选择交并比`iou`(两个框交集面积除以并集面积),精度效果因数据集而而异,但选择`ios`预测速度会更快一点;
## 4.部署
### 4.1 导出模型
```bash
# export model
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams
```
### 4.2 使用原图或子图直接推理:
```bash
# deploy infer
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/ppyoloe_crn_l_80e_sliced_visdrone_640_025 --image_file=demo.jpg --device=GPU --threshold=0.25
```
### 4.3 使用原图自动切图并拼图重组结果来推理:
```bash
# deploy slice infer
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/ppyoloe_crn_l_80e_sliced_visdrone_640_025 --image_file=demo.jpg --device=GPU --threshold=0.25 --slice_infer --slice_size 640 640 --overlap_ratio 0.25 0.25 --combine_method=nms --match_threshold=0.6 --match_metric=ios
```
- 设置`--slice_infer`表示切图预测并拼装重组结果,如果不使用则不写;
- 设置`--slice_size`表示切图的子图尺寸大小,设置`--overlap_ratio`表示子图间重叠率;
- 设置`--combine_method`表示子图结果重组去重的方式,默认是`nms`
- 设置`--match_threshold`表示子图结果重组去重的阈值,默认是0.6;
- 设置`--match_metric`表示子图结果重组去重的度量标准,默认是`ios`表示交小比(两个框交集面积除以更小框的面积),也可以选择交并比`iou`(两个框交集面积除以并集面积),精度效果因数据集而而异,但选择`ios`预测速度会更快一点;
# SAHI切图工具使用说明
## 1. 数据集下载
### VisDrone-DET
VisDrone-DET是一个无人机航拍场景的小目标数据集,整理后的COCO格式VisDrone-DET数据集[下载链接](https://bj.bcebos.com/v1/paddledet/data/smalldet/visdrone.zip),切图后的COCO格式数据集[下载链接](https://bj.bcebos.com/v1/paddledet/data/smalldet/visdrone_sliced.zip),检测其中的**10类**,包括 `pedestrian(1), people(2), bicycle(3), car(4), van(5), truck(6), tricycle(7), awning-tricycle(8), bus(9), motor(10)`,原始数据集[下载链接](https://github.com/VisDrone/VisDrone-Dataset)
具体使用和下载请参考[visdrone](../visdrone)
### DOTA水平框:
DOTA是一个大型的遥感影像公开数据集,这里使用**DOTA-v1.0**水平框数据集,切图后整理的COCO格式的DOTA水平框数据集[下载链接](https://bj.bcebos.com/v1/paddledet/data/smalldet/dota_sliced.zip),检测其中的**15类**
包括 `plane(0), baseball-diamond(1), bridge(2), ground-track-field(3), small-vehicle(4), large-vehicle(5), ship(6), tennis-court(7),basketball-court(8), storage-tank(9), soccer-ball-field(10), roundabout(11), harbor(12), swimming-pool(13), helicopter(14)`
图片及原始数据集[下载链接](https://captain-whu.github.io/DOAI2019/dataset.html)
### Xview:
Xview是一个大型的航拍遥感检测数据集,目标极小极多,切图后整理的COCO格式数据集[下载链接](https://bj.bcebos.com/v1/paddledet/data/smalldet/xview_sliced.zip),检测其中的**60类**
具体类别为:
<details>
`Fixed-wing Aircraft(0),
Small Aircraft(1),
Cargo Plane(2),
Helicopter(3),
Passenger Vehicle(4),
Small Car(5),
Bus(6),
Pickup Truck(7),
Utility Truck(8),
Truck(9),
Cargo Truck(10),
Truck w/Box(11),
Truck Tractor(12),
Trailer(13),
Truck w/Flatbed(14),
Truck w/Liquid(15),
Crane Truck(16),
Railway Vehicle(17),
Passenger Car(18),
Cargo Car(19),
Flat Car(20),
Tank car(21),
Locomotive(22),
Maritime Vessel(23),
Motorboat(24),
Sailboat(25),
Tugboat(26),
Barge(27),
Fishing Vessel(28),
Ferry(29),
Yacht(30),
Container Ship(31),
Oil Tanker(32),
Engineering Vehicle(33),
Tower crane(34),
Container Crane(35),
Reach Stacker(36),
Straddle Carrier(37),
Mobile Crane(38),
Dump Truck(39),
Haul Truck(40),
Scraper/Tractor(41),
Front loader/Bulldozer(42),
Excavator(43),
Cement Mixer(44),
Ground Grader(45),
Hut/Tent(46),
Shed(47),
Building(48),
Aircraft Hangar(49),
Damaged Building(50),
Facility(51),
Construction Site(52),
Vehicle Lot(53),
Helipad(54),
Storage Tank(55),
Shipping container lot(56),
Shipping Container(57),
Pylon(58),
Tower(59)
`
</details>
,原始数据集[下载链接](https://challenge.xviewdataset.org/download-links)
## 2. 统计数据集分布
首先统计所用数据集标注框的平均宽高占图片真实宽高的比例分布:
```bash
python slice_tools/box_distribution.py --json_path ../../dataset/DOTA/annotations/train.json --out_img box_distribution.jpg
```
- `--json_path` :待统计数据集COCO 格式 annotation 的json文件路径
- `--out_img` :输出的统计分布图路径
以DOTA数据集的train数据集为例,统计结果打印如下:
```bash
Median of ratio_w is 0.03799439775910364
Median of ratio_h is 0.04074914637387802
all_img with box: 1409
all_ann: 98905
Distribution saved as box_distribution.jpg
```
**注意:**
- 当原始数据集全部有标注框的图片中,**有1/2以上的图片标注框的平均宽高与原图宽高比例小于0.04时**,建议进行切图训练。
## 3. SAHI切图
针对需要切图的数据集,使用[SAHI](https://github.com/obss/sahi)库进行切分:
### 安装SAHI库:
参考[SAHI installation](https://github.com/obss/sahi/blob/main/README.md#installation)进行安装
```bash
pip install sahi
```
### 基于SAHI切图:
```bash
python slice_tools/slice_image.py --image_dir ../../dataset/DOTA/train/ --json_path ../../dataset/DOTA/annotations/train.json --output_dir ../../dataset/dota_sliced --slice_size 500 --overlap_ratio 0.25
```
- `--image_dir`:原始数据集图片文件夹的路径
- `--json_path`:原始数据集COCO格式的json标注文件的路径
- `--output_dir`:切分后的子图及其json标注文件保存的路径
- `--slice_size`:切分以后子图的边长尺度大小(默认切图后为正方形)
- `--overlap_ratio`:切分时的子图之间的重叠率
- 以上述代码为例,切分后的子图文件夹与json标注文件共同保存在`dota_sliced`文件夹下,分别命名为`train_images_500_025``train_500_025.json`
# 引用
```
@article{akyon2022sahi,
title={Slicing Aided Hyper Inference and Fine-tuning for Small Object Detection},
author={Akyon, Fatih Cagatay and Altinuc, Sinan Onur and Temizel, Alptekin},
journal={arXiv preprint arXiv:2202.06934},
year={2022}
}
@inproceedings{xia2018dota,
title={DOTA: A large-scale dataset for object detection in aerial images},
author={Xia, Gui-Song and Bai, Xiang and Ding, Jian and Zhu, Zhen and Belongie, Serge and Luo, Jiebo and Datcu, Mihai and Pelillo, Marcello and Zhang, Liangpei},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={3974--3983},
year={2018}
}
@ARTICLE{9573394,
author={Zhu, Pengfei and Wen, Longyin and Du, Dawei and Bian, Xiao and Fan, Heng and Hu, Qinghua and Ling, Haibin},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={Detection and Tracking Meet Drones Challenge},
year={2021},
volume={},
number={},
pages={1-1},
doi={10.1109/TPAMI.2021.3119563}
}
```
...@@ -3,18 +3,18 @@ num_classes: 15 ...@@ -3,18 +3,18 @@ num_classes: 15
TrainDataset: TrainDataset:
!COCODataSet !COCODataSet
image_dir: DOTA_slice_train/train_images_500_025 image_dir: train_images_500_025
anno_path: DOTA_slice_train/train_500_025.json anno_path: train_500_025.json
dataset_dir: dataset/DOTA dataset_dir: dataset/dota_sliced
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset: EvalDataset:
!COCODataSet !COCODataSet
image_dir: DOTA_slice_val/val_images_500_025 image_dir: val_images_500_025
anno_path: DOTA_slice_val/val_500_025.json anno_path: val_500_025.json
dataset_dir: dataset/DOTA dataset_dir: dataset/dota_sliced
TestDataset: TestDataset:
!ImageFolder !ImageFolder
anno_path: dataset/DOTA/DOTA_slice_val/val_500_025.json anno_path: val_500_025.json
dataset_dir: dataset/DOTA/DOTA_slice_val/val_images_500_025 dataset_dir: dataset/dota_sliced
...@@ -16,5 +16,5 @@ EvalDataset: ...@@ -16,5 +16,5 @@ EvalDataset:
TestDataset: TestDataset:
!ImageFolder !ImageFolder
anno_path: dataset/visdrone_sliced/val_640_025.json anno_path: val_640_025.json
dataset_dir: dataset/visdrone_sliced/val_images_640_025 dataset_dir: dataset/visdrone_sliced
...@@ -5,16 +5,16 @@ TrainDataset: ...@@ -5,16 +5,16 @@ TrainDataset:
!COCODataSet !COCODataSet
image_dir: train_images_400_025 image_dir: train_images_400_025
anno_path: train_400_025.json anno_path: train_400_025.json
dataset_dir: dataset/xview/xview_slic dataset_dir: dataset/xview_sliced
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset: EvalDataset:
!COCODataSet !COCODataSet
image_dir: val_images_400_025 image_dir: val_images_400_025
anno_path: val_400_025.json anno_path: val_400_025.json
dataset_dir: dataset/xview/xview_slic dataset_dir: dataset/xview_sliced
TestDataset: TestDataset:
!ImageFolder !ImageFolder
anno_path: dataset/xview/xview_slic/val_400_025.json anno_path: val_400_025.json
dataset_dir: dataset/xview/xview_slic/val_images_400_025 dataset_dir: dataset/xview_sliced
...@@ -13,9 +13,14 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco ...@@ -13,9 +13,14 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco
depth_mult: 1.0 depth_mult: 1.0
width_mult: 1.0 width_mult: 1.0
TrainReader: TrainReader:
batch_size: 8 batch_size: 8
EvalReader:
batch_size: 1
epoch: 80 epoch: 80
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
...@@ -34,3 +39,10 @@ PPYOLOEHead: ...@@ -34,3 +39,10 @@ PPYOLOEHead:
keep_top_k: 500 keep_top_k: 500
score_threshold: 0.01 score_threshold: 0.01
nms_threshold: 0.6 nms_threshold: 0.6
EvalDataset:
!COCODataSet
image_dir: val_images_640_025
anno_path: val_640_025.json
dataset_dir: dataset/visdrone_sliced
_BASE_: [
'./_base_/visdrone_sliced_640_025_detection.yml',
'../runtime.yml',
'../ppyoloe/_base_/optimizer_300e.yml',
'../ppyoloe/_base_/ppyoloe_crn.yml',
'../ppyoloe/_base_/ppyoloe_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_l_80e_sliced_visdrone_640_025/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
depth_mult: 1.0
width_mult: 1.0
TrainReader:
batch_size: 8
EvalReader:
batch_size: 1 # only support bs=1 when slice infer
epoch: 80
LearningRate:
base_lr: 0.01
schedulers:
- !CosineDecay
max_epochs: 96
- !LinearWarmup
start_factor: 0.
epochs: 1
PPYOLOEHead:
static_assigner_epoch: -1
nms:
name: MultiClassNMS
nms_top_k: 10000
keep_top_k: 500
score_threshold: 0.01
nms_threshold: 0.6
# very slow, preferly eval with a determined weights(xx.pdparams)
# if you want to eval during training, change SlicedCOCODataSet to COCODataSet and delete sliced_size/overlap_ratio
EvalDataset:
!SlicedCOCODataSet
image_dir: VisDrone2019-DET-val
anno_path: val.json
dataset_dir: dataset/visdrone
sliced_size: [640, 640]
overlap_ratio: [0.25, 0.25]
...@@ -7,14 +7,27 @@ _BASE_: [ ...@@ -7,14 +7,27 @@ _BASE_: [
] ]
log_iter: 100 log_iter: 100
snapshot_epoch: 10 snapshot_epoch: 10
weights: output/ppyoloe_crn_l_80e_sliced_DOTA_500_025/model_final weights: output/ppyoloe_p2_crn_l_80e_sliced_DOTA_500_025/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
depth_mult: 1.0 depth_mult: 1.0
width_mult: 1.0 width_mult: 1.0
CSPResNet:
return_idx: [0, 1, 2, 3]
use_alpha: True
CustomCSPPAN:
out_channels: [768, 384, 192, 64]
TrainReader: TrainReader:
batch_size: 8 batch_size: 4
EvalReader:
batch_size: 1
epoch: 80 epoch: 80
LearningRate: LearningRate:
...@@ -27,6 +40,7 @@ LearningRate: ...@@ -27,6 +40,7 @@ LearningRate:
epochs: 1 epochs: 1
PPYOLOEHead: PPYOLOEHead:
fpn_strides: [32, 16, 8, 4]
static_assigner_epoch: -1 static_assigner_epoch: -1
nms: nms:
name: MultiClassNMS name: MultiClassNMS
......
...@@ -7,14 +7,27 @@ _BASE_: [ ...@@ -7,14 +7,27 @@ _BASE_: [
] ]
log_iter: 100 log_iter: 100
snapshot_epoch: 10 snapshot_epoch: 10
weights: output/ppyoloe_crn_l_80e_sliced_xview_400_025/model_final weights: output/ppyoloe_p2_crn_l_80e_sliced_xview_400_025/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
depth_mult: 1.0 depth_mult: 1.0
width_mult: 1.0 width_mult: 1.0
CSPResNet:
return_idx: [0, 1, 2, 3]
use_alpha: True
CustomCSPPAN:
out_channels: [768, 384, 192, 64]
TrainReader: TrainReader:
batch_size: 8 batch_size: 4
EvalReader:
batch_size: 1
epoch: 80 epoch: 80
LearningRate: LearningRate:
...@@ -27,6 +40,7 @@ LearningRate: ...@@ -27,6 +40,7 @@ LearningRate:
epochs: 1 epochs: 1
PPYOLOEHead: PPYOLOEHead:
fpn_strides: [32, 16, 8, 4]
static_assigner_epoch: -1 static_assigner_epoch: -1
nms: nms:
name: MultiClassNMS name: MultiClassNMS
......
...@@ -18,11 +18,13 @@ PaddleDetection团队提供了针对VisDrone-DET小目标数航拍场景的基 ...@@ -18,11 +18,13 @@ PaddleDetection团队提供了针对VisDrone-DET小目标数航拍场景的基
|PP-YOLOE-Alpha-largesize-l| 41.9 | 65.0 | 32.3 | 53.0 | 37.13 | 61.15 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_alpha_largesize_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_l_alpha_largesize_80e_visdrone.yml) | |PP-YOLOE-Alpha-largesize-l| 41.9 | 65.0 | 32.3 | 53.0 | 37.13 | 61.15 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_alpha_largesize_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_l_alpha_largesize_80e_visdrone.yml) |
|PP-YOLOE-P2-Alpha-largesize-l| 41.3 | 64.5 | 32.4 | 53.1 | 37.49 | 51.54 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_p2_alpha_largesize_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_l_p2_alpha_largesize_80e_visdrone.yml) | |PP-YOLOE-P2-Alpha-largesize-l| 41.3 | 64.5 | 32.4 | 53.1 | 37.49 | 51.54 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_p2_alpha_largesize_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_l_p2_alpha_largesize_80e_visdrone.yml) |
## 切图训练:
| 模型 | COCOAPI mAP<sup>val<br>0.5:0.95 | COCOAPI mAP<sup>val<br>0.5 | COCOAPI mAP<sup>test_dev<br>0.5:0.95 | COCOAPI mAP<sup>test_dev<br>0.5 | MatlabAPI mAP<sup>test_dev<br>0.5:0.95 | MatlabAPI mAP<sup>test_dev<br>0.5 | 下载 | 配置文件 | ## 原图评估和拼图评估对比:
|:---------|:------:|:------:| :----: | :------:| :------: | :------:| :----: | :------:|
|PP-YOLOE-l| 29.7 | 48.5 | 23.3 | 39.9 | - | - | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams) | [配置文件](../smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml) | | 模型 | 数据集 | SLICE_SIZE | OVERLAP_RATIO | 类别数 | mAP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | 下载链接 | 配置文件 |
|:---------|:---------------:|:---------------:|:---------------:|:------:|:-----------------------:|:-------------------:|:---------:| :-----: |
|PP-YOLOE-l| VisDrone-DET| 640 | 0.25 | 10 | 29.7 | 48.5 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams) | [配置文件](../smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml) |
|PP-YOLOE-l (Assembled)| VisDrone-DET| 640 | 0.25 | 10 | 37.2 | 59.4 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams) | [配置文件](../smalldet/ppyoloe_crn_l_80e_sliced_visdrone_640_025.yml) |
**注意:** **注意:**
......
...@@ -13,9 +13,14 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco ...@@ -13,9 +13,14 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco
depth_mult: 1.0 depth_mult: 1.0
width_mult: 1.0 width_mult: 1.0
TrainReader: TrainReader:
batch_size: 8 batch_size: 8
EvalReader:
batch_size: 1
epoch: 80 epoch: 80
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
_BASE_: [ _BASE_: [
'ppyoloe_crn_l_80e_visdrone.yml', 'ppyoloe_crn_l_80e_visdrone.yml',
] ]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_l_alpha_largesize_80e_visdrone/model_final weights: output/ppyoloe_crn_l_alpha_largesize_80e_visdrone/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
...@@ -42,7 +44,7 @@ EvalReader: ...@@ -42,7 +44,7 @@ EvalReader:
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2} - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {} - Permute: {}
batch_size: 2 batch_size: 1
TestReader: TestReader:
inputs_def: inputs_def:
......
_BASE_: [ _BASE_: [
'ppyoloe_crn_l_80e_visdrone.yml', 'ppyoloe_crn_l_80e_visdrone.yml',
] ]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_l_p2_alpha_80e_visdrone/model_final weights: output/ppyoloe_crn_l_p2_alpha_80e_visdrone/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
...@@ -8,6 +10,10 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco ...@@ -8,6 +10,10 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco
TrainReader: TrainReader:
batch_size: 4 batch_size: 4
EvalReader:
batch_size: 1
LearningRate: LearningRate:
base_lr: 0.005 base_lr: 0.005
......
_BASE_: [ _BASE_: [
'ppyoloe_crn_l_80e_visdrone.yml', 'ppyoloe_crn_l_80e_visdrone.yml',
] ]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_l_p2_alpha_largesize_80e_visdrone/model_final weights: output/ppyoloe_crn_l_p2_alpha_largesize_80e_visdrone/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
...@@ -49,7 +51,7 @@ EvalReader: ...@@ -49,7 +51,7 @@ EvalReader:
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2} - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {} - Permute: {}
batch_size: 2 batch_size: 1
TestReader: TestReader:
inputs_def: inputs_def:
......
...@@ -13,9 +13,14 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco ...@@ -13,9 +13,14 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco
depth_mult: 0.33 depth_mult: 0.33
width_mult: 0.50 width_mult: 0.50
TrainReader: TrainReader:
batch_size: 8 batch_size: 8
EvalReader:
batch_size: 1
epoch: 80 epoch: 80
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
_BASE_: [ _BASE_: [
'ppyoloe_crn_s_80e_visdrone.yml', 'ppyoloe_crn_s_80e_visdrone.yml',
] ]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_s_p2_alpha_80e_visdrone/model_final weights: output/ppyoloe_crn_s_p2_alpha_80e_visdrone/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
TrainReader: TrainReader:
batch_size: 4 batch_size: 4
EvalReader:
batch_size: 1
LearningRate: LearningRate:
base_lr: 0.005 base_lr: 0.005
......
...@@ -36,7 +36,7 @@ from picodet_postprocess import PicoDetPostProcess ...@@ -36,7 +36,7 @@ from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from visualize import visualize_box_mask from visualize import visualize_box_mask
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms
# Global dictionary # Global dictionary
SUPPORT_MODELS = { SUPPORT_MODELS = {
...@@ -218,6 +218,93 @@ class Detector(object): ...@@ -218,6 +218,93 @@ class Detector(object):
def get_timer(self): def get_timer(self):
return self.det_times return self.det_times
def predict_image_slice(self,
img_list,
slice_size=[640, 640],
overlap_ratio=[0.25, 0.25],
combine_method='nms',
match_threshold=0.6,
match_metric='iou',
visual=True,
save_file=None):
# slice infer only support bs=1
results = []
try:
import sahi
from sahi.slicing import slice_image
except Exception as e:
logger.error(
'sahi not found, plaese install sahi. '
'for example: `pip install sahi`, see https://github.com/obss/sahi.'
)
raise e
num_classes = len(self.pred_config.labels)
for i in range(len(img_list)):
ori_image = img_list[i]
slice_image_result = sahi.slicing.slice_image(
image=ori_image,
slice_height=slice_size[0],
slice_width=slice_size[1],
overlap_height_ratio=overlap_ratio[0],
overlap_width_ratio=overlap_ratio[1])
sub_img_num = len(slice_image_result)
merged_bboxs = []
for _ind in range(sub_img_num):
im = slice_image_result.images[_ind]
self.det_times.preprocess_time_s.start()
inputs = self.preprocess([im]) # should be list
self.det_times.preprocess_time_s.end()
# model prediction
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
# postprocess
self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
shift_amount = slice_image_result.starting_pixels[_ind]
result['boxes'][:, 2:4] = result['boxes'][:, 2:4] + shift_amount
result['boxes'][:, 4:6] = result['boxes'][:, 4:6] + shift_amount
merged_bboxs.append(result['boxes'])
merged_results = {'boxes': []}
if combine_method == 'nms':
final_boxes = multiclass_nms(
np.concatenate(merged_bboxs), num_classes, match_threshold,
match_metric)
merged_results['boxes'] = np.concatenate(final_boxes)
elif combine_method == 'concat':
merged_results['boxes'] = np.concatenate(merged_bboxs)
else:
raise ValueError(
"Now only support 'nms' or 'concat' to fuse detection results."
)
merged_results['boxes_num'] = np.array(
[len(merged_results['boxes'])], dtype=np.int32)
if visual:
visualize(
[ori_image], # should be list
merged_results,
self.pred_config.labels,
output_dir=self.output_dir,
threshold=self.threshold)
results.append(merged_results)
if visual:
print('Test iter {}'.format(i))
if save_file is not None:
Path(self.output_dir).mkdir(exist_ok=True)
self.format_coco_results(image_list, results, save_file=save_file)
results = self.merge_batch_result(results)
return results
def predict_image(self, def predict_image(self,
image_list, image_list,
run_benchmark=False, run_benchmark=False,
...@@ -871,6 +958,16 @@ def main(): ...@@ -871,6 +958,16 @@ def main():
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
save_file = os.path.join(FLAGS.output_dir, save_file = os.path.join(FLAGS.output_dir,
'results.json') if FLAGS.save_results else None 'results.json') if FLAGS.save_results else None
if FLAGS.slice_infer:
detector.predict_image_slice(
img_list,
FLAGS.slice_size,
FLAGS.overlap_ratio,
FLAGS.combine_method,
FLAGS.match_threshold,
FLAGS.match_metric,
save_file=save_file)
else:
detector.predict_image( detector.predict_image(
img_list, FLAGS.run_benchmark, repeats=100, save_file=save_file) img_list, FLAGS.run_benchmark, repeats=100, save_file=save_file)
if not FLAGS.run_benchmark: if not FLAGS.run_benchmark:
......
...@@ -16,6 +16,7 @@ import time ...@@ -16,6 +16,7 @@ import time
import os import os
import ast import ast
import argparse import argparse
import numpy as np
def argsparser(): def argsparser():
...@@ -161,7 +162,39 @@ def argsparser(): ...@@ -161,7 +162,39 @@ def argsparser():
type=bool, type=bool,
default=False, default=False,
help="Whether save detection result to file using coco format") help="Whether save detection result to file using coco format")
parser.add_argument(
"--slice_infer",
action='store_true',
help="Whether to slice the image and merge the inference results for small object detection."
)
parser.add_argument(
'--slice_size',
nargs='+',
type=int,
default=[640, 640],
help="Height of the sliced image.")
parser.add_argument(
"--overlap_ratio",
nargs='+',
type=float,
default=[0.25, 0.25],
help="Overlap height ratio of the sliced image.")
parser.add_argument(
"--combine_method",
type=str,
default='nms',
help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
)
parser.add_argument(
"--match_threshold",
type=float,
default=0.6,
help="Combine method matching threshold.")
parser.add_argument(
"--match_metric",
type=str,
default='iou',
help="Combine method matching metric, choose in ['iou', 'ios'].")
return parser return parser
...@@ -288,3 +321,68 @@ def get_current_memory_mb(): ...@@ -288,3 +321,68 @@ def get_current_memory_mb():
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem = meminfo.used / 1024. / 1024. gpu_mem = meminfo.used / 1024. / 1024.
return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4) return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
def multiclass_nms(bboxs, num_classes, match_threshold=0.6, match_metric='iou'):
final_boxes = []
for c in range(num_classes):
idxs = bboxs[:, 0] == c
if np.count_nonzero(idxs) == 0: continue
r = nms(bboxs[idxs, 1:], match_threshold, match_metric)
final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1))
return final_boxes
def nms(dets, match_threshold=0.6, match_metric='iou'):
""" Apply NMS to avoid detecting too many overlapping bounding boxes.
Args:
dets: shape [N, 5], [score, x1, y1, x2, y2]
match_metric: 'iou' or 'ios'
match_threshold: overlap thresh for match metric.
"""
if dets.shape[0] == 0:
return dets[[], :]
scores = dets[:, 0]
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=np.int)
for _i in range(ndets):
i = order[_i]
if suppressed[i] == 1:
continue
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0.0, xx2 - xx1 + 1)
h = max(0.0, yy2 - yy1 + 1)
inter = w * h
if match_metric == 'iou':
union = iarea + areas[j] - inter
match_value = inter / union
elif match_metric == 'ios':
smaller = min(iarea, areas[j])
match_value = inter / smaller
else:
raise ValueError()
if match_value >= match_threshold:
suppressed[j] = 1
keep = np.where(suppressed == 0)[0]
dets = dets[keep, :]
return dets
...@@ -253,3 +253,126 @@ class COCODataSet(DetDataset): ...@@ -253,3 +253,126 @@ class COCODataSet(DetDataset):
empty_records = self._sample_empty(empty_records, len(records)) empty_records = self._sample_empty(empty_records, len(records))
records += empty_records records += empty_records
self.roidbs = records self.roidbs = records
@register
@serializable
class SlicedCOCODataSet(COCODataSet):
"""Sliced COCODataSet"""
def __init__(
self,
dataset_dir=None,
image_dir=None,
anno_path=None,
data_fields=['image'],
sample_num=-1,
load_crowd=False,
allow_empty=False,
empty_ratio=1.,
repeat=1,
sliced_size=[640, 640],
overlap_ratio=[0.25, 0.25], ):
super(SlicedCOCODataSet, self).__init__(
dataset_dir=dataset_dir,
image_dir=image_dir,
anno_path=anno_path,
data_fields=data_fields,
sample_num=sample_num,
load_crowd=load_crowd,
allow_empty=allow_empty,
empty_ratio=empty_ratio,
repeat=repeat, )
self.sliced_size = sliced_size
self.overlap_ratio = overlap_ratio
def parse_dataset(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
image_dir = os.path.join(self.dataset_dir, self.image_dir)
assert anno_path.endswith('.json'), \
'invalid coco annotation file: ' + anno_path
from pycocotools.coco import COCO
coco = COCO(anno_path)
img_ids = coco.getImgIds()
img_ids.sort()
cat_ids = coco.getCatIds()
records = []
empty_records = []
ct = 0
ct_sub = 0
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
self.cname2cid = dict({
coco.loadCats(catid)[0]['name']: clsid
for catid, clsid in self.catid2clsid.items()
})
if 'annotations' not in coco.dataset:
self.load_image_only = True
logger.warning('Annotation file: {} does not contains ground truth '
'and load image information only.'.format(anno_path))
try:
import sahi
from sahi.slicing import slice_image
except Exception as e:
logger.error(
'sahi not found, plaese install sahi. '
'for example: `pip install sahi`, see https://github.com/obss/sahi.'
)
raise e
sub_img_ids = 0
for img_id in img_ids:
img_anno = coco.loadImgs([img_id])[0]
im_fname = img_anno['file_name']
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
im_path = os.path.join(image_dir,
im_fname) if image_dir else im_fname
is_empty = False
if not os.path.exists(im_path):
logger.warning('Illegal image file: {}, and it will be '
'ignored'.format(im_path))
continue
if im_w < 0 or im_h < 0:
logger.warning('Illegal width: {} or height: {} in annotation, '
'and im_id: {} will be ignored'.format(
im_w, im_h, img_id))
continue
slice_image_result = sahi.slicing.slice_image(
image=im_path,
slice_height=self.sliced_size[0],
slice_width=self.sliced_size[1],
overlap_height_ratio=self.overlap_ratio[0],
overlap_width_ratio=self.overlap_ratio[1])
sub_img_num = len(slice_image_result)
for _ind in range(sub_img_num):
im = slice_image_result.images[_ind]
coco_rec = {
'image': im,
'im_id': np.array([sub_img_ids + _ind]),
'h': im.shape[0],
'w': im.shape[1],
'ori_im_id': np.array([img_id]),
'st_pix': np.array(
slice_image_result.starting_pixels[_ind],
dtype=np.float32),
'is_last': 1 if _ind == sub_img_num - 1 else 0,
} if 'image' in self.data_fields else {}
records.append(coco_rec)
ct_sub += sub_img_num
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
break
assert ct > 0, 'not found any coco record in %s' % (anno_path)
logger.info('{} samples and slice to {} sub_samples in file {}'.format(
ct, ct_sub, anno_path))
if self.allow_empty and len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records))
records += empty_records
self.roidbs = records
...@@ -206,6 +206,55 @@ class ImageFolder(DetDataset): ...@@ -206,6 +206,55 @@ class ImageFolder(DetDataset):
self.image_dir = images self.image_dir = images
self.roidbs = self._load_images() self.roidbs = self._load_images()
def set_slice_images(self,
images,
slice_size=[640, 640],
overlap_ratio=[0.25, 0.25]):
self.image_dir = images
ori_records = self._load_images()
try:
import sahi
from sahi.slicing import slice_image
except Exception as e:
logger.error(
'sahi not found, plaese install sahi. '
'for example: `pip install sahi`, see https://github.com/obss/sahi.'
)
raise e
sub_img_ids = 0
ct = 0
ct_sub = 0
records = []
for i, ori_rec in enumerate(ori_records):
im_path = ori_rec['im_file']
slice_image_result = sahi.slicing.slice_image(
image=im_path,
slice_height=slice_size[0],
slice_width=slice_size[1],
overlap_height_ratio=overlap_ratio[0],
overlap_width_ratio=overlap_ratio[1])
sub_img_num = len(slice_image_result)
for _ind in range(sub_img_num):
im = slice_image_result.images[_ind]
rec = {
'image': im,
'im_id': np.array([sub_img_ids + _ind]),
'h': im.shape[0],
'w': im.shape[1],
'ori_im_id': np.array([ori_rec['im_id'][0]]),
'st_pix': np.array(
slice_image_result.starting_pixels[_ind],
dtype=np.float32),
'is_last': 1 if _ind == sub_img_num - 1 else 0,
} if 'image' in self.data_fields else {}
records.append(rec)
ct_sub += sub_img_num
ct += 1
print('{} samples and slice to {} sub_samples'.format(ct, ct_sub))
self.roidbs = records
def get_label_list(self): def get_label_list(self):
# Only VOC dataset needs label list in ImageFold # Only VOC dataset needs label list in ImageFold
return self.anno_path return self.anno_path
......
...@@ -122,12 +122,15 @@ class Decode(BaseOperator): ...@@ -122,12 +122,15 @@ class Decode(BaseOperator):
sample['image'] = f.read() sample['image'] = f.read()
sample.pop('im_file') sample.pop('im_file')
try:
im = sample['image'] im = sample['image']
data = np.frombuffer(im, dtype='uint8') data = np.frombuffer(im, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
if 'keep_ori_im' in sample and sample['keep_ori_im']: if 'keep_ori_im' in sample and sample['keep_ori_im']:
sample['ori_image'] = im sample['ori_image'] = im
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
except:
im = sample['image']
sample['image'] = im sample['image'] = im
if 'h' not in sample: if 'h' not in sample:
......
...@@ -45,6 +45,7 @@ from ppdet.data.source.category import get_categories ...@@ -45,6 +45,7 @@ from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats import ppdet.utils.stats as stats
from ppdet.utils.fuse_utils import fuse_conv_bn from ppdet.utils.fuse_utils import fuse_conv_bn
from ppdet.utils import profiler from ppdet.utils import profiler
from ppdet.modeling.post_process import multiclass_nms
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator, WandbCallback from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator, WandbCallback
from .export_utils import _dump_infer_config, _prune_input_spec from .export_utils import _dump_infer_config, _prune_input_spec
...@@ -617,6 +618,194 @@ class Trainer(object): ...@@ -617,6 +618,194 @@ class Trainer(object):
with paddle.no_grad(): with paddle.no_grad():
self._eval_with_loader(self.loader) self._eval_with_loader(self.loader)
def _eval_with_loader_slice(self,
loader,
slice_size=[640, 640],
overlap_ratio=[0.25, 0.25],
combine_method='nms',
match_threshold=0.6,
match_metric='iou'):
sample_num = 0
tic = time.time()
self._compose_callback.on_epoch_begin(self.status)
self.status['mode'] = 'eval'
self.model.eval()
if self.cfg.get('print_flops', False):
flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
self._flops(flops_loader)
merged_bboxs = []
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
# forward
if self.use_amp:
with paddle.amp.auto_cast(
enable=self.cfg.use_gpu,
custom_white_list=self.custom_white_list,
custom_black_list=self.custom_black_list,
level=self.amp_level):
outs = self.model(data)
else:
outs = self.model(data)
shift_amount = data['st_pix']
outs['bbox'][:, 2:4] = outs['bbox'][:, 2:4] + shift_amount
outs['bbox'][:, 4:6] = outs['bbox'][:, 4:6] + shift_amount
merged_bboxs.append(outs['bbox'])
if data['is_last'] > 0:
# merge matching predictions
merged_results = {'bbox': []}
if combine_method == 'nms':
final_boxes = multiclass_nms(
np.concatenate(merged_bboxs), self.cfg.num_classes,
match_threshold, match_metric)
merged_results['bbox'] = np.concatenate(final_boxes)
elif combine_method == 'concat':
merged_results['bbox'] = np.concatenate(merged_bboxs)
else:
raise ValueError(
"Now only support 'nms' or 'concat' to fuse detection results."
)
merged_results['im_id'] = np.array([[0]])
merged_results['bbox_num'] = np.array(
[len(merged_results['bbox'])])
merged_bboxs = []
data['im_id'] = data['ori_im_id']
# update metrics
for metric in self._metrics:
metric.update(data, merged_results)
# multi-scale inputs: all inputs have same im_id
if isinstance(data, typing.Sequence):
sample_num += data[0]['im_id'].numpy().shape[0]
else:
sample_num += data['im_id'].numpy().shape[0]
self._compose_callback.on_step_end(self.status)
self.status['sample_num'] = sample_num
self.status['cost_time'] = time.time() - tic
# accumulate metric to log out
for metric in self._metrics:
metric.accumulate()
metric.log()
self._compose_callback.on_epoch_end(self.status)
# reset metric states for metric may performed multiple times
self._reset_metrics()
def evaluate_slice(self,
slice_size=[640, 640],
overlap_ratio=[0.25, 0.25],
combine_method='nms',
match_threshold=0.6,
match_metric='iou'):
with paddle.no_grad():
self._eval_with_loader_slice(self.loader, slice_size, overlap_ratio,
combine_method, match_threshold,
match_metric)
def slice_predict(self,
images,
slice_size=[640, 640],
overlap_ratio=[0.25, 0.25],
combine_method='nms',
match_threshold=0.6,
match_metric='iou',
draw_threshold=0.5,
output_dir='output',
save_results=False):
self.dataset.set_slice_images(images, slice_size, overlap_ratio)
loader = create('TestReader')(self.dataset, 0)
imid2path = self.dataset.get_imid2path()
anno_file = self.dataset.get_anno()
clsid2catid, catid2name = get_categories(
self.cfg.metric, anno_file=anno_file)
# Run Infer
self.status['mode'] = 'test'
self.model.eval()
if self.cfg.get('print_flops', False):
flops_loader = create('TestReader')(self.dataset, 0)
self._flops(flops_loader)
results = [] # all images
merged_bboxs = [] # single image
for step_id, data in enumerate(tqdm(loader)):
self.status['step_id'] = step_id
# forward
outs = self.model(data)
outs['bbox'] = outs['bbox'].numpy() # only in test mode
shift_amount = data['st_pix']
outs['bbox'][:, 2:4] = outs['bbox'][:, 2:4] + shift_amount.numpy()
outs['bbox'][:, 4:6] = outs['bbox'][:, 4:6] + shift_amount.numpy()
merged_bboxs.append(outs['bbox'])
if data['is_last'] > 0:
# merge matching predictions
merged_results = {'bbox': []}
if combine_method == 'nms':
final_boxes = multiclass_nms(
np.concatenate(merged_bboxs), self.cfg.num_classes,
match_threshold, match_metric)
merged_results['bbox'] = np.concatenate(final_boxes)
elif combine_method == 'concat':
merged_results['bbox'] = np.concatenate(merged_bboxs)
else:
raise ValueError(
"Now only support 'nms' or 'concat' to fuse detection results."
)
merged_results['im_id'] = np.array([[0]])
merged_results['bbox_num'] = np.array(
[len(merged_results['bbox'])])
merged_bboxs = []
data['im_id'] = data['ori_im_id']
for key in ['im_shape', 'scale_factor', 'im_id']:
if isinstance(data, typing.Sequence):
outs[key] = data[0][key]
else:
outs[key] = data[key]
for key, value in merged_results.items():
if hasattr(value, 'numpy'):
merged_results[key] = value.numpy()
results.append(merged_results)
# visualize results
for outs in results:
batch_res = get_infer_results(outs, clsid2catid)
bbox_num = outs['bbox_num']
start = 0
for i, im_id in enumerate(outs['im_id']):
image_path = imid2path[int(im_id)]
image = Image.open(image_path).convert('RGB')
image = ImageOps.exif_transpose(image)
self.status['original_image'] = np.array(image.copy())
end = start + bbox_num[i]
bbox_res = batch_res['bbox'][start:end] \
if 'bbox' in batch_res else None
mask_res, segm_res, keypoint_res = None, None, None
image = visualize_results(
image, bbox_res, mask_res, segm_res, keypoint_res,
int(im_id), catid2name, draw_threshold)
self.status['result_image'] = np.array(image.copy())
if self._compose_callback:
self._compose_callback.on_step_end(self.status)
# save image with detection
save_name = self._get_save_image_name(output_dir, image_path)
logger.info("Detection bbox results save in {}".format(
save_name))
image.save(save_name, quality=95)
start = end
def predict(self, def predict(self,
images, images,
draw_threshold=0.5, draw_threshold=0.5,
......
...@@ -617,8 +617,23 @@ class SparsePostProcess(object): ...@@ -617,8 +617,23 @@ class SparsePostProcess(object):
return bbox_pred, bbox_num return bbox_pred, bbox_num
def nms(dets, thresh): def multiclass_nms(bboxs, num_classes, match_threshold=0.6, match_metric='iou'):
"""Apply classic DPM-style greedy NMS.""" final_boxes = []
for c in range(num_classes):
idxs = bboxs[:, 0] == c
if np.count_nonzero(idxs) == 0: continue
r = nms(bboxs[idxs, 1:], match_threshold, match_metric)
final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1))
return final_boxes
def nms(dets, match_threshold=0.6, match_metric='iou'):
""" Apply NMS to avoid detecting too many overlapping bounding boxes.
Args:
dets: shape [N, 5], [score, x1, y1, x2, y2]
match_metric: 'iou' or 'ios'
match_threshold: overlap thresh for match metric.
"""
if dets.shape[0] == 0: if dets.shape[0] == 0:
return dets[[], :] return dets[[], :]
scores = dets[:, 0] scores = dets[:, 0]
...@@ -626,25 +641,12 @@ def nms(dets, thresh): ...@@ -626,25 +641,12 @@ def nms(dets, thresh):
y1 = dets[:, 2] y1 = dets[:, 2]
x2 = dets[:, 3] x2 = dets[:, 3]
y2 = dets[:, 4] y2 = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1) areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1] order = scores.argsort()[::-1]
ndets = dets.shape[0] ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=np.int) suppressed = np.zeros((ndets), dtype=np.int)
# nominal indices
# _i, _j
# sorted indices
# i, j
# temp variables for box i's (the box currently under consideration)
# ix1, iy1, ix2, iy2, iarea
# variables for computing overlap with box j (lower scoring box)
# xx1, yy1, xx2, yy2
# w, h
# inter, ovr
for _i in range(ndets): for _i in range(ndets):
i = order[_i] i = order[_i]
if suppressed[i] == 1: if suppressed[i] == 1:
...@@ -665,8 +667,15 @@ def nms(dets, thresh): ...@@ -665,8 +667,15 @@ def nms(dets, thresh):
w = max(0.0, xx2 - xx1 + 1) w = max(0.0, xx2 - xx1 + 1)
h = max(0.0, yy2 - yy1 + 1) h = max(0.0, yy2 - yy1 + 1)
inter = w * h inter = w * h
ovr = inter / (iarea + areas[j] - inter) if match_metric == 'iou':
if ovr >= thresh: union = iarea + areas[j] - inter
match_value = inter / union
elif match_metric == 'ios':
smaller = min(iarea, areas[j])
match_value = inter / smaller
else:
raise ValueError()
if match_value >= match_threshold:
suppressed[j] = 1 suppressed[j] = 1
keep = np.where(suppressed == 0)[0] keep = np.where(suppressed == 0)[0]
dets = dets[keep, :] dets = dets[keep, :]
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib.pyplot as plt
import json
import numpy as np
import argparse
def median(data):
data.sort()
mid = len(data) // 2
median = (data[mid] + data[~mid]) / 2
return median
def draw_distribution(width, height, out_path):
w_bins = int((max(width) - min(width)) // 10)
h_bins = int((max(height) - min(height)) // 10)
plt.figure()
plt.subplot(221)
plt.hist(width, bins=w_bins, color='green')
plt.xlabel('Width rate *1000')
plt.ylabel('number')
plt.title('Distribution of Width')
plt.subplot(222)
plt.hist(height, bins=h_bins, color='blue')
plt.xlabel('Height rate *1000')
plt.title('Distribution of Height')
plt.savefig(out_path)
print(f'Distribution saved as {out_path}')
plt.show()
def get_ratio_infos(jsonfile, out_img):
allannjson = json.load(open(jsonfile, 'r'))
be_im_id = 1
be_im_w = []
be_im_h = []
ratio_w = []
ratio_h = []
images = allannjson['images']
for i, ann in enumerate(allannjson['annotations']):
if ann['iscrowd']:
continue
x0, y0, w, h = ann['bbox'][:]
if be_im_id == ann['image_id']:
be_im_w.append(w)
be_im_h.append(h)
else:
im_w = images[be_im_id - 1]['width']
im_h = images[be_im_id - 1]['height']
im_m_w = np.mean(be_im_w)
im_m_h = np.mean(be_im_h)
dis_w = im_m_w / im_w
dis_h = im_m_h / im_h
ratio_w.append(dis_w)
ratio_h.append(dis_h)
be_im_id = ann['image_id']
be_im_w = [w]
be_im_h = [h]
im_w = images[be_im_id - 1]['width']
im_h = images[be_im_id - 1]['height']
im_m_w = np.mean(be_im_w)
im_m_h = np.mean(be_im_h)
dis_w = im_m_w / im_w
dis_h = im_m_h / im_h
ratio_w.append(dis_w)
ratio_h.append(dis_h)
mid_w = median(ratio_w)
mid_h = median(ratio_h)
ratio_w = [i * 1000 for i in ratio_w]
ratio_h = [i * 1000 for i in ratio_h]
print(f'Median of ratio_w is {mid_w}')
print(f'Median of ratio_h is {mid_h}')
print('all_img with box: ', len(ratio_h))
print('all_ann: ', len(allannjson['annotations']))
draw_distribution(ratio_w, ratio_h, out_img)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--json_path', type=str, default=None, help="Dataset json path.")
parser.add_argument(
'--out_img',
type=str,
default='box_distribution.jpg',
help="Name of distibution img.")
args = parser.parse_args()
get_ratio_infos(args.json_path, args.out_img)
if __name__ == "__main__":
main()
...@@ -83,6 +83,40 @@ def parse_args(): ...@@ -83,6 +83,40 @@ def parse_args():
default=False, default=False,
help="Enable auto mixed precision eval.") help="Enable auto mixed precision eval.")
# for smalldet slice_infer
parser.add_argument(
"--slice_infer",
action='store_true',
help="Whether to slice the image and merge the inference results for small object detection."
)
parser.add_argument(
'--slice_size',
nargs='+',
type=int,
default=[640, 640],
help="Height of the sliced image.")
parser.add_argument(
"--overlap_ratio",
nargs='+',
type=float,
default=[0.25, 0.25],
help="Overlap height ratio of the sliced image.")
parser.add_argument(
"--combine_method",
type=str,
default='nms',
help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
)
parser.add_argument(
"--match_threshold",
type=float,
default=0.6,
help="Combine method matching threshold.")
parser.add_argument(
"--match_metric",
type=str,
default='iou',
help="Combine method matching metric, choose in ['iou', 'ios'].")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -109,6 +143,14 @@ def run(FLAGS, cfg): ...@@ -109,6 +143,14 @@ def run(FLAGS, cfg):
trainer.load_weights(cfg.weights) trainer.load_weights(cfg.weights)
# training # training
if FLAGS.slice_infer:
trainer.evaluate_slice(
slice_size=FLAGS.slice_size,
overlap_ratio=FLAGS.overlap_ratio,
combine_method=FLAGS.combine_method,
match_threshold=FLAGS.match_threshold,
match_metric=FLAGS.match_metric)
else:
trainer.evaluate() trainer.evaluate()
......
...@@ -81,6 +81,39 @@ def parse_args(): ...@@ -81,6 +81,39 @@ def parse_args():
type=bool, type=bool,
default=False, default=False,
help="Whether to save inference results to output_dir.") help="Whether to save inference results to output_dir.")
parser.add_argument(
"--slice_infer",
action='store_true',
help="Whether to slice the image and merge the inference results for small object detection."
)
parser.add_argument(
'--slice_size',
nargs='+',
type=int,
default=[640, 640],
help="Height of the sliced image.")
parser.add_argument(
"--overlap_ratio",
nargs='+',
type=float,
default=[0.25, 0.25],
help="Overlap height ratio of the sliced image.")
parser.add_argument(
"--combine_method",
type=str,
default='nms',
help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
)
parser.add_argument(
"--match_threshold",
type=float,
default=0.6,
help="Combine method matching threshold.")
parser.add_argument(
"--match_metric",
type=str,
default='iou',
help="Combine method matching metric, choose in ['iou', 'ios'].")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -127,6 +160,18 @@ def run(FLAGS, cfg): ...@@ -127,6 +160,18 @@ def run(FLAGS, cfg):
images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
# inference # inference
if FLAGS.slice_infer:
trainer.slice_predict(
images,
slice_size=FLAGS.slice_size,
overlap_ratio=FLAGS.overlap_ratio,
combine_method=FLAGS.combine_method,
match_threshold=FLAGS.match_threshold,
match_metric=FLAGS.match_metric,
draw_threshold=FLAGS.draw_threshold,
output_dir=FLAGS.output_dir,
save_results=FLAGS.save_results)
else:
trainer.predict( trainer.predict(
images, images,
draw_threshold=FLAGS.draw_threshold, draw_threshold=FLAGS.draw_threshold,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from tqdm import tqdm
def slice_data(image_dir, dataset_json_path, output_dir, slice_size,
overlap_ratio):
try:
from sahi.scripts.slice_coco import slice
except Exception as e:
raise RuntimeError(
'Unable to use sahi to slice images, please install sahi, for example: `pip install sahi`, see https://github.com/obss/sahi'
)
tqdm.write(
f" slicing for slice_size={slice_size}, overlap_ratio={overlap_ratio}")
slice(
image_dir=image_dir,
dataset_json_path=dataset_json_path,
output_dir=output_dir,
slice_size=slice_size,
overlap_ratio=overlap_ratio, )
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--image_dir', type=str, default=None, help="The image folder path.")
parser.add_argument(
'--json_path', type=str, default=None, help="Dataset json path.")
parser.add_argument(
'--output_dir', type=str, default=None, help="Output dir.")
parser.add_argument(
'--slice_size', type=int, default=500, help="slice_size")
parser.add_argument(
'--overlap_ratio', type=float, default=0.25, help="overlap_ratio")
args = parser.parse_args()
slice_data(args.image_dir, args.json_path, args.output_dir, args.slice_size,
args.overlap_ratio)
if __name__ == "__main__":
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册