未验证 提交 fb2d5549 编写于 作者: W wangxinxin08 提交者: GitHub

fix ppyoloe_r/fcosr reader while eval and refine docs (#7330)

上级 314019c0
......@@ -51,7 +51,7 @@ ${DOTA_ROOT}
x1 y1 x2 y2 x3 y3 x4 y4 class_name difficult
```
### 单尺度切图
#### 单尺度切图
DOTA数据集分辨率较高,因此一般在训练和测试之前对图像进行离线切图,使用单尺度进行切图可以使用以下命令:
``` bash
# 对于有标注的数据进行切图
......@@ -75,7 +75,7 @@ python configs/rotate/tools/prepare_data.py \
```
### 多尺度切图
#### 多尺度切图
使用多尺度进行切图可以使用以下命令:
``` bash
# 对于有标注的数据进行切图
......@@ -98,6 +98,21 @@ python configs/rotate/tools/prepare_data.py \
--image_only
```
### 自定义数据集
旋转框使用标准COCO数据格式,你可以将你的数据集转换成COCO格式以训练模型。COCO标准数据格式的标注信息中包含以下信息:
``` python
'annotations': [
{
'id': 2083, 'category_id': 9, 'image_id': 9008,
'bbox': [x, y, w, h], # 水平框标注
'segmentation': [[x1, y1, x2, y2, x3, y3, x4, y4]], # 旋转框标注
...
}
...
]
```
**需要注意的是`bbox`的标注是水平框标注,`segmentation`为旋转框四个点的标注(顺时针或逆时针均可)。在旋转框训练时`bbox`是可以缺省,一般推荐根据旋转框标注`segmentation`生成。** 在PaddleDetection 2.4及之前的版本,`bbox`为旋转框标注[x, y, w, h, angle],`segmentation`缺省,**目前该格式已不再支持,请下载最新数据集或者转换成标准COCO格式**
## 安装依赖
旋转框检测模型需要依赖外部算子进行训练,评估等。Linux环境下,你可以执行以下命令进行编译安装
```
......
......@@ -51,7 +51,7 @@ For labeled data, each image corresponds to a txt file with the same name, and e
x1 y1 x2 y2 x3 y3 x4 y4 class_name difficult
```
### Slicing data with single scale
#### Slicing data with single scale
The image resolution of DOTA dataset is relatively high, so we usually slice the images before training and testing. To slice the images with a single scale, you can use the command below
``` bash
# slicing labeled data
......@@ -74,7 +74,7 @@ python configs/rotate/tools/prepare_data.py \
```
### Slicing data with multi scale
#### Slicing data with multi scale
To slice the images with multiple scales, you can use the command below
``` bash
# slicing labeled data
......@@ -96,6 +96,20 @@ python configs/rotate/tools/prepare_data.py \
--image_only
```
### Custom Dataset
Rotated object detction uses the standard COCO data format, and you can convert your dataset to COCO format to train the model. The annotations of standard COCO format contains the following information
``` python
'annotations': [
{
'id': 2083, 'category_id': 9, 'image_id': 9008,
'bbox': [x, y, w, h], # horizontal bouding box
'segmentation': [[x1, y1, x2, y2, x3, y3, x4, y4]], # rotated bounding box
...
}
...
]
```
**It should be noted that `bbox` is the horizontal bouding box, and `segmentation` is four points of rotated bounding box (clockwise or counterclockwise). The `bbox` can be empty when training rotated object detector, and it is recommended to generate `bbox` according to `segmentation`**. In PaddleDetection 2.4 and earlier versions, `bbox` represents the rotated bounding box [x, y, w, h, angle] and `segmentation` is empty. **But this format is no longer supported after PaddleDetection 2.5, please download the latest dataset or convert to standard COCO format**.
## Installation
Models of rotated object detection depend on external operators for training, evaluation, etc. In Linux environment, you can execute the following command to compile and install.
```
......
......@@ -33,6 +33,7 @@ EvalReader:
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 2
collate_batch: false
TestReader:
sample_transforms:
......@@ -42,4 +43,4 @@ TestReader:
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 8
batch_size: 2
......@@ -11,7 +11,7 @@
- [引用](#引用)
## 简介
PP-YOLOE-R是一个高效的单阶段Anchor-free旋转框检测模型。基于PP-YOLOE, PP-YOLOE-R以极少的参数量和计算量为代价,引入了一系列有用的设计来提升检测精度。在DOTA 1.0数据集上,PP-YOLOE-R-l和PP-YOLOE-R-x在单尺度训练和测试的情况下分别达到了78.14和78.27 mAP,这超越了几乎所有的旋转框检测模型。通过多尺度训练和测试,PP-YOLOE-R-l和PP-YOLOE-R-x的检测精度进一步提升至80.02和80.73 mAP。在这种情况下,PP-YOLOE-R-x超越了所有的anchor-free方法并且和最先进的anchor-based的两阶段模型精度几乎相当。此外,PP-YOLOE-R-s和PP-YOLOE-R-m通过多尺度训练和测试可以达到79.42和79.71 mAP。考虑到这两个模型的参数量和计算量,其性能也非常卓越。在保持高精度的同时,PP-YOLOE-R避免使用特殊的算子,例如Deformable Convolution或Rotated RoI Align,以使其能轻松地部署在多种多样的硬件上。在1024x1024的输入分辨率下,PP-YOLOE-R-s/m/l/x在RTX 2080 Ti上使用TensorRT FP16分别能达到69.8/55.1/48.3/37.1 FPS,在Tesla V100上分别能达到114.5/86.8/69.7/50.7 FPS。更多细节可以参考我们的技术报告
PP-YOLOE-R是一个高效的单阶段Anchor-free旋转框检测模型。基于PP-YOLOE, PP-YOLOE-R以极少的参数量和计算量为代价,引入了一系列有用的设计来提升检测精度。在DOTA 1.0数据集上,PP-YOLOE-R-l和PP-YOLOE-R-x在单尺度训练和测试的情况下分别达到了78.14和78.27 mAP,这超越了几乎所有的旋转框检测模型。通过多尺度训练和测试,PP-YOLOE-R-l和PP-YOLOE-R-x的检测精度进一步提升至80.02和80.73 mAP。在这种情况下,PP-YOLOE-R-x超越了所有的anchor-free方法并且和最先进的anchor-based的两阶段模型精度几乎相当。此外,PP-YOLOE-R-s和PP-YOLOE-R-m通过多尺度训练和测试可以达到79.42和79.71 mAP。考虑到这两个模型的参数量和计算量,其性能也非常卓越。在保持高精度的同时,PP-YOLOE-R避免使用特殊的算子,例如Deformable Convolution或Rotated RoI Align,以使其能轻松地部署在多种多样的硬件上。在1024x1024的输入分辨率下,PP-YOLOE-R-s/m/l/x在RTX 2080 Ti上使用TensorRT FP16分别能达到69.8/55.1/48.3/37.1 FPS,在Tesla V100上分别能达到114.5/86.8/69.7/50.7 FPS。更多细节可以参考我们的[**技术报告**](https://arxiv.org/abs/2211.02386)
<div align="center">
<img src="../../../docs/images/ppyoloe_r_map_fps.png" width=500 />
......@@ -26,22 +26,23 @@ PP-YOLOE-R相较于PP-YOLOE做了以下几点改动:
## 模型库
| 模型 | Backbone | mAP | V100 TRT FP16 (FPS) | RTX 2080 Ti TRT FP16 (FPS) |学习率策略 | 角度表示 | 数据增广 | GPU数目 | 每GPU图片数目 | 模型下载 | 配置文件 |
|:---:|:--------:|:----:|:--------------------:|:------------:|:--------------------:|:-----:|:--------:|:-------:|:------:|:-----------:|:------:|
| PP-YOLOE-R-s | CRN-s | 73.82 | 114.5 | 69.8 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota.yml) |
| PP-YOLOE-R-s | CRN-s | 79.42 | 114.5 | 69.8 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota_ms.yml) |
| PP-YOLOE-R-m | CRN-m | 77.64 | 86.8 | 55.1 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_m_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_m_3x_dota.yml) |
| PP-YOLOE-R-m | CRN-m | 79.71 | 86.8 | 55.1 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_m_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_m_3x_dota_ms.yml) |
| PP-YOLOE-R-l | CRN-l | 78.14 | 69.7 | 48.3 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml) |
| PP-YOLOE-R-l | CRN-l | 80.02 | 69.7 | 48.3 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota_ms.yml) |
| PP-YOLOE-R-x | CRN-x | 78.28 | 50.7 | 37.1 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_x_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_x_3x_dota.yml) |
| PP-YOLOE-R-x | CRN-x | 80.73 | 50.7 | 37.1 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_x_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_x_3x_dota_ms.yml) |
| 模型 | Backbone | mAP | V100 TRT FP16 (FPS) | RTX 2080 Ti TRT FP16 (FPS) | Params (M) | FLOPs (G) | 学习率策略 | 角度表示 | 数据增广 | GPU数目 | 每GPU图片数目 | 模型下载 | 配置文件 |
|:---:|:--------:|:----:|:--------------------:|:------------------------:|:----------:|:---------:|:--------:|:----------:|:-------:|:------:|:-----------:|:--------:|:------:|
| PP-YOLOE-R-s | CRN-s | 73.82 | 114.5 | 69.8 | 8.09 | 43.46 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota.yml) |
| PP-YOLOE-R-s | CRN-s | 79.42 | 114.5 | 69.8 | 8.09 | 43.46 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota_ms.yml) |
| PP-YOLOE-R-m | CRN-m | 77.64 | 86.8 | 55.1 | 23.96 |127.00 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_m_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_m_3x_dota.yml) |
| PP-YOLOE-R-m | CRN-m | 79.71 | 86.8 | 55.1 | 23.96 |127.00 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_m_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_m_3x_dota_ms.yml) |
| PP-YOLOE-R-l | CRN-l | 78.14 | 69.7 | 48.3 | 53.29 |281.65 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml) |
| PP-YOLOE-R-l | CRN-l | 80.02 | 69.7 | 48.3 | 53.29 |281.65 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota_ms.yml) |
| PP-YOLOE-R-x | CRN-x | 78.28 | 50.7 | 37.1 | 100.27|529.82 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_x_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_x_3x_dota.yml) |
| PP-YOLOE-R-x | CRN-x | 80.73 | 50.7 | 37.1 | 100.27|529.82 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_x_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_x_3x_dota_ms.yml) |
**注意:**
- 如果**GPU卡数**或者**batch size**发生了改变,你需要按照公式 **lr<sub>new</sub> = lr<sub>default</sub> * (batch_size<sub>new</sub> * GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> * GPU_number<sub>default</sub>)** 调整学习率。
- 模型库中的模型默认使用单尺度训练单尺度测试。如果数据增广一栏标明MS,意味着使用多尺度训练和多尺度测试。如果数据增广一栏标明RR,意味着使用RandomRotate数据增广进行训练。
- CRN表示在PP-YOLOE中提出的CSPRepResNet
- PP-YOLOE-R的参数量和计算量是在重参数化之后计算得到,输入图像的分辨率为1024x1024
- 速度测试使用TensorRT 8.2.3在DOTA测试集中测试2000张图片计算平均值得到。参考速度测试以复现[速度测试](#速度测试)
## 使用说明
......@@ -81,7 +82,7 @@ zip -r submit.zip submit
```
### 速度测试
速度测试需要确保**TensorRT版本大于8.2, PaddlePaddle版本大于2.4.0rc0**。使用Paddle Inference且使用TensorRT进行测速,执行以下命令:
可以使用Paddle模式或者Paddle-TRT模式进行测速。当使用Paddle-TRT模式测速时,需要确保**TensorRT版本大于8.2, PaddlePaddle版本大于2.4.0rc0**。使用Paddle-TRT进行测速,可以执行以下命令:
``` bash
# 导出模型
......@@ -90,10 +91,18 @@ python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota
# 速度测试
CUDA_VISIBLE_DEVICES=0 python configs/rotate/tools/inference_benchmark.py --model_dir output_inference/ppyoloe_r_crn_l_3x_dota/ --image_dir /path/to/dota/test/dir --run_mode trt_fp16
```
当只使用Paddle进行测速,可以执行以下命令:
``` bash
# 导出模型
python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams
# 速度测试
CUDA_VISIBLE_DEVICES=0 python configs/rotate/tools/inference_benchmark.py --model_dir output_inference/ppyoloe_r_crn_l_3x_dota/ --image_dir /path/to/dota/test/dir --run_mode paddle
```
## 预测部署
**使用Paddle Inference但不使用TensorRT**进行部署,执行以下命令:
**使用Paddle**进行部署,执行以下命令:
``` bash
# 导出模型
python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams
......@@ -102,7 +111,7 @@ python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota
python deploy/python/infer.py --image_file demo/P0072__1.0__0___0.png --model_dir=output_inference/ppyoloe_r_crn_l_3x_dota --run_mode=paddle --device=gpu
```
**使用Paddle Inference且使用TensorRT**进行部署,执行以下命令:
**使用Paddle-TRT进行部署**,执行以下命令:
```
# 导出模型
python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams trt=True
......@@ -112,7 +121,7 @@ python deploy/python/infer.py --image_file demo/P0072__1.0__0___0.png --model_di
```
**注意:**
- 使用Paddle-TRT使用确保PaddlePaddle版本大于2.4.0rc且TensorRT版本大于8.2.
- 使用Paddle-TRT使用确保**PaddlePaddle版本大于2.4.0rc且TensorRT版本大于8.2**.
## 附录
......@@ -131,6 +140,13 @@ PP-YOLOE-R消融实验
## 引用
```
@article{wang2022pp,
title={PP-YOLOE-R: An Efficient Anchor-Free Rotated Object Detector},
author={Wang, Xinxin and Wang, Guanzhong and Dang, Qingqing and Liu, Yi and Hu, Xiaoguang and Yu, Dianhai},
journal={arXiv preprint arXiv:2211.02386},
year={2022}
}
@article{xu2022pp,
title={PP-YOLOE: An evolved version of YOLO},
author={Xu, Shangliang and Wang, Xinxin and Lv, Wenyu and Chang, Qinyao and Cui, Cheng and Deng, Kaipeng and Wang, Guanzhong and Dang, Qingqing and Wei, Shengyu and Du, Yuning and others},
......
......@@ -11,7 +11,7 @@ English | [简体中文](README.md)
- [Citations](#Citations)
## Introduction
PP-YOLOE-R is an efficient anchor-free rotated object detector. Based on PP-YOLOE, PP-YOLOE-R introduces a bag of useful tricks to improve detection precision at the expense of marginal parameters and computations.PP-YOLOE-R-l and PP-YOLOE-R-x achieve 78.14 and 78.27 mAP respectively on DOTA 1.0 dataset with single-scale training and testing, which outperform almost all other rotated object detectors. With multi-scale training and testing, the detection precision of PP-YOLOE-R-l and PP-YOLOE-R-x is further improved to 80.02 and 80.73 mAP. In this case, PP-YOLOE-R-x surpasses all anchor-free methods and demonstrates competitive performance to state-of-the-art anchor-based two-stage model. Moreover, PP-YOLOE-R-s and PP-YOLOE-R-m can achieve 79.42 and 79.71 mAP with multi-scale training and testing, which is an excellent result considering the parameters and GLOPS of these two models. While maintaining high precision, PP-YOLOE-R avoids using special operators, such as Deformable Convolution or Rotated RoI Align, to be deployed friendly on various hardware. At the input resolution of 1024$\times$1024, PP-YOLOE-R-s/m/l/x can reach 69.8/55.1/48.3/37.1 FPS on RTX 2080 Ti and 114.5/86.8/69.7/50.7 FPS on Tesla V100 GPU with TensorRT and FP16-precision. For more details, please refer to our technical report.
PP-YOLOE-R is an efficient anchor-free rotated object detector. Based on PP-YOLOE, PP-YOLOE-R introduces a bag of useful tricks to improve detection precision at the expense of marginal parameters and computations.PP-YOLOE-R-l and PP-YOLOE-R-x achieve 78.14 and 78.27 mAP respectively on DOTA 1.0 dataset with single-scale training and testing, which outperform almost all other rotated object detectors. With multi-scale training and testing, the detection precision of PP-YOLOE-R-l and PP-YOLOE-R-x is further improved to 80.02 and 80.73 mAP. In this case, PP-YOLOE-R-x surpasses all anchor-free methods and demonstrates competitive performance to state-of-the-art anchor-based two-stage model. Moreover, PP-YOLOE-R-s and PP-YOLOE-R-m can achieve 79.42 and 79.71 mAP with multi-scale training and testing, which is an excellent result considering the parameters and GLOPS of these two models. While maintaining high precision, PP-YOLOE-R avoids using special operators, such as Deformable Convolution or Rotated RoI Align, to be deployed friendly on various hardware. At the input resolution of 1024$\times$1024, PP-YOLOE-R-s/m/l/x can reach 69.8/55.1/48.3/37.1 FPS on RTX 2080 Ti and 114.5/86.8/69.7/50.7 FPS on Tesla V100 GPU with TensorRT and FP16-precision. For more details, please refer to our [**technical report**](https://arxiv.org/abs/2211.02386).
<div align="center">
<img src="../../../docs/images/ppyoloe_r_map_fps.png" width=500 />
......@@ -25,22 +25,23 @@ Compared with PP-YOLOE, PP-YOLOE-R has made the following changes:
- [ProbIoU Loss](https://arxiv.org/abs/2106.06072)
## Model Zoo
| Model | Backbone | mAP | V100 TRT FP16 (FPS) | RTX 2080 Ti TRT FP16 (FPS) | Lr Scheduler | Angle | Aug | GPU Number | images/GPU | download | config |
|:---:|:--------:|:----:|:--------------------:|:------------:|:--------------------:|:-----:|:--------:|:-------:|:------:|:-----------:|:------:|
| PP-YOLOE-R-s | CRN-s | 73.82 | 114.5 | 69.8 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota.yml) |
| PP-YOLOE-R-s | CRN-s | 79.42 | 114.5 | 69.8 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota_ms.yml) |
| PP-YOLOE-R-m | CRN-m | 77.64 | 86.8 | 55.1 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_m_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_m_3x_dota.yml) |
| PP-YOLOE-R-m | CRN-m | 79.71 | 86.8 | 55.1 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_m_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_m_3x_dota_ms.yml) |
| PP-YOLOE-R-l | CRN-l | 78.14 | 69.7 | 48.3 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml) |
| PP-YOLOE-R-l | CRN-l | 80.02 | 69.7 | 48.3 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota_ms.yml) |
| PP-YOLOE-R-x | CRN-x | 78.28 | 50.7 | 37.1 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_x_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_x_3x_dota.yml) |
| PP-YOLOE-R-x | CRN-x | 80.73 | 50.7 | 37.1 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_x_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_x_3x_dota_ms.yml) |
| Model | Backbone | mAP | V100 TRT FP16 (FPS) | RTX 2080 Ti TRT FP16 (FPS) | Params (M) | FLOPs (G) | Lr Scheduler | Angle | Aug | GPU Number | images/GPU | download | config |
|:-----:|:--------:|:----:|:-------------------:|:--------------------------:|:-----------:|:---------:|:--------:|:-----:|:---:|:----------:|:----------:|:--------:|:------:|
| PP-YOLOE-R-s | CRN-s | 73.82 | 114.5 | 69.8 | 8.09 | 43.46 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota.yml) |
| PP-YOLOE-R-s | CRN-s | 79.42 | 114.5 | 69.8 | 8.09 | 43.46 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota_ms.yml) |
| PP-YOLOE-R-m | CRN-m | 77.64 | 86.8 | 55.1 | 23.96 |127.00 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_m_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_m_3x_dota.yml) |
| PP-YOLOE-R-m | CRN-m | 79.71 | 86.8 | 55.1 | 23.96 |127.00 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_m_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_m_3x_dota_ms.yml) |
| PP-YOLOE-R-l | CRN-l | 78.14 | 69.7 | 48.3 | 53.29 |281.65 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml) |
| PP-YOLOE-R-l | CRN-l | 80.02 | 69.7 | 48.3 | 53.29 |281.65 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota_ms.yml) |
| PP-YOLOE-R-x | CRN-x | 78.28 | 50.7 | 37.1 | 100.27|529.82 | 3x | oc | RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_x_3x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_x_3x_dota.yml) |
| PP-YOLOE-R-x | CRN-x | 80.73 | 50.7 | 37.1 | 100.27|529.82 | 3x | oc | MS+RR | 4 | 2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_x_3x_dota_ms.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rotate/ppyoloe_r/ppyoloe_r_crn_x_3x_dota_ms.yml) |
**Notes:**
- if **GPU number** or **mini-batch size** is changed, **learning rate** should be adjusted according to the formula **lr<sub>new</sub> = lr<sub>default</sub> * (batch_size<sub>new</sub> * GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> * GPU_number<sub>default</sub>)**.
- Models in model zoo is trained and tested with single scale by default. If `MS` is indicated in the data augmentation column, it means that multi-scale training and multi-scale testing are used. If `RR` is indicated in the data augmentation column, it means that RandomRotate data augmentation is used for training.
- CRN denotes CSPRepResNet proposed in PP-YOLOE
- The parameters and GLOPs of PP-YOLOE-R are calculated after re-parameterization, and the resolution of the input image is 1024x1024
- Speed ​​is calculated and averaged by testing 2000 images on the DOTA test dataset. Refer to [Speed testing](#Speed-testing) to reproduce the results.
## Getting Start
......@@ -82,7 +83,7 @@ zip -r submit.zip submit
### Speed testing
To test speed, make sure that **the version of TensorRT is larger than 8.2 and the version of PaddlePaddle is larger than 2.4.0rc**. Using Paddle Inference with TensorRT to test speed, run following command
You can use Paddle mode or Paddle-TRT mode for speed testing. When using Paddle-TRT for speed testing, make sure that **the version of TensorRT is larger than 8.2 and the version of PaddlePaddle is larger than 2.4.0rc**. Using Paddle-TRT to test speed, run following command
``` bash
# export inference model
......@@ -90,11 +91,20 @@ python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota
# speed testing
CUDA_VISIBLE_DEVICES=0 python configs/rotate/tools/inference_benchmark.py --model_dir output_inference/ppyoloe_r_crn_l_3x_dota/ --image_dir /path/to/dota/test/dir --run_mode trt_fp16
```
Using Paddle to test speed, run following command
``` bash
# export inference model
python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams
# speed testing
CUDA_VISIBLE_DEVICES=0 python configs/rotate/tools/inference_benchmark.py --model_dir output_inference/ppyoloe_r_crn_l_3x_dota/ --image_dir /path/to/dota/test/dir --run_mode paddle
```
## Deployment
**Using Paddle Inference without TensorRT** to for deployment, run following command
**Using Paddle** to for deployment, run following command
``` bash
# export inference model
......@@ -104,7 +114,7 @@ python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota
python deploy/python/infer.py --image_file demo/P0072__1.0__0___0.png --model_dir=output_inference/ppyoloe_r_crn_l_3x_dota --run_mode=paddle --device=gpu
```
**Using Paddle Inference with TensorRT** to for deployment, run following command
**Using Paddle-TRT** to for deployment, run following command
``` bash
# export inference model
......@@ -113,6 +123,8 @@ python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota
# inference single image
python deploy/python/infer.py --image_file demo/P0072__1.0__0___0.png --model_dir=output_inference/ppyoloe_r_crn_l_3x_dota --run_mode=trt_fp16 --device=gpu
```
**Notes:**
- When using Paddle-TRT for speed testing, make sure that **the version of TensorRT is larger than 8.2 and the version of PaddlePaddle is larger than 2.4.0rc**
## Appendix
......@@ -129,6 +141,13 @@ Ablation experiments of PP-YOLOE-R
## Citations
```
@article{wang2022pp,
title={PP-YOLOE-R: An Efficient Anchor-Free Rotated Object Detector},
author={Wang, Xinxin and Wang, Guanzhong and Dang, Qingqing and Liu, Yi and Hu, Xiaoguang and Yu, Dianhai},
journal={arXiv preprint arXiv:2211.02386},
year={2022}
}
@article{xu2022pp,
title={PP-YOLOE: An evolved version of YOLO},
author={Xu, Shangliang and Wang, Xinxin and Lv, Wenyu and Chang, Qinyao and Cui, Cheng and Deng, Kaipeng and Wang, Guanzhong and Dang, Qingqing and Wei, Shengyu and Du, Yuning and others},
......
......@@ -33,6 +33,7 @@ EvalReader:
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 2
collate_batch: false
TestReader:
sample_transforms:
......@@ -42,4 +43,4 @@ TestReader:
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 8
batch_size: 2
......@@ -30,9 +30,9 @@ import paddle
import paddle.version as paddle_version
from paddle.inference import Config, create_predictor, PrecisionType, get_trt_runtime_version
TUNED_TRT_DYNAMIC_MODELS = {'DETR'}
def check_version(version='2.2'):
err = "PaddlePaddle version {} or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
......@@ -83,8 +83,8 @@ def decode_image(im_file, im_info):
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return im, im_info
class Resize(object):
class Resize(object):
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
......@@ -128,8 +128,8 @@ class Resize(object):
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class Permute(object):
class Permute(object):
def __init__(self, ):
super(Permute, self).__init__()
......@@ -137,6 +137,7 @@ class Permute(object):
im = im.transpose((2, 0, 1))
return im, im_info
class NormalizeImage(object):
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
......@@ -159,7 +160,6 @@ class NormalizeImage(object):
class PadStride(object):
def __init__(self, stride=0):
self.coarsest_stride = stride
......@@ -190,14 +190,27 @@ def preprocess(im, preprocess_ops):
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, help='directory of inference model')
parser.add_argument('--run_mode', type=str, default='paddle', help='running mode')
parser.add_argument(
'--model_dir', type=str, help='directory of inference model')
parser.add_argument(
'--run_mode', type=str, default='paddle', help='running mode')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--image_dir', type=str, default='/paddle/data/DOTA_1024_ss/test1024/images', help='directory of test images')
parser.add_argument('--warmup_iter', type=int, default=5, help='num of warmup iters')
parser.add_argument('--total_iter', type=int, default=2000, help='num of total iters')
parser.add_argument('--log_iter', type=int, default=50, help='num of log interval')
parser.add_argument('--tuned_trt_shape_file', type=str, default='shape_range_info.pbtxt', help='dynamic shape range info')
parser.add_argument(
'--image_dir',
type=str,
default='/paddle/data/DOTA_1024_ss/test1024/images',
help='directory of test images')
parser.add_argument(
'--warmup_iter', type=int, default=5, help='num of warmup iters')
parser.add_argument(
'--total_iter', type=int, default=2000, help='num of total iters')
parser.add_argument(
'--log_iter', type=int, default=50, help='num of log interval')
parser.add_argument(
'--tuned_trt_shape_file',
type=str,
default='shape_range_info.pbtxt',
help='dynamic shape range info')
args = parser.parse_args()
return args
......@@ -207,11 +220,11 @@ def init_predictor(FLAGS):
yaml_file = os.path.join(model_dir, 'infer_cfg.yml')
with open(yaml_file) as f:
yml_conf = yaml.safe_load(f)
config = Config(
os.path.join(model_dir, 'model.pdmodel'),
os.path.join(model_dir, 'model.pdiparams'))
# initial GPU memory(M), device ID
config.enable_use_gpu(200, 0)
# optimize graph and fuse op
......@@ -227,8 +240,11 @@ def init_predictor(FLAGS):
tuned_trt_shape_file = os.path.join(model_dir, FLAGS.tuned_trt_shape_file)
if run_mode in precision_map.keys():
if arch in TUNED_TRT_DYNAMIC_MODELS and not os.path.exists(tuned_trt_shape_file):
print('dynamic shape range info is saved in {}. After that, rerun the code'.format(tuned_trt_shape_file))
if arch in TUNED_TRT_DYNAMIC_MODELS and not os.path.exists(
tuned_trt_shape_file):
print(
'dynamic shape range info is saved in {}. After that, rerun the code'.
format(tuned_trt_shape_file))
config.collect_shape_range_info(tuned_trt_shape_file)
config.enable_tensorrt_engine(
workspace_size=(1 << 25) * batch_size,
......@@ -239,8 +255,10 @@ def init_predictor(FLAGS):
use_calib_mode=False)
if yml_conf['use_dynamic_shape']:
if arch in TUNED_TRT_DYNAMIC_MODELS and os.path.exists(tuned_trt_shape_file):
config.enable_tuned_tensorrt_dynamic_shape(tuned_trt_shape_file, True)
if arch in TUNED_TRT_DYNAMIC_MODELS and os.path.exists(
tuned_trt_shape_file):
config.enable_tuned_tensorrt_dynamic_shape(tuned_trt_shape_file,
True)
else:
min_input_shape = {
'image': [batch_size, 3, 640, 640],
......@@ -254,9 +272,9 @@ def init_predictor(FLAGS):
'image': [batch_size, 3, 1024, 1024],
'scale_factor': [batch_size, 2]
}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
opt_input_shape)
config.set_trt_dynamic_shape_info(
min_input_shape, max_input_shape, opt_input_shape)
# disable print log when predict
config.disable_glog_info()
# enable shared memory
......@@ -266,6 +284,7 @@ def init_predictor(FLAGS):
predictor = create_predictor(config)
return predictor, yml_conf
def create_preprocess_ops(yml_conf):
preprocess_ops = []
for op_info in yml_conf['Preprocess']:
......@@ -294,8 +313,10 @@ def create_inputs(image_files, preprocess_ops):
im_list.append(im)
im_info_list.append(im_info)
inputs['im_shape'] = np.stack([e['im_shape'] for e in im_info_list], axis=0).astype('float32')
inputs['scale_factor'] = np.stack([e['scale_factor'] for e in im_info_list], axis=0).astype('float32')
inputs['im_shape'] = np.stack(
[e['im_shape'] for e in im_info_list], axis=0).astype('float32')
inputs['scale_factor'] = np.stack(
[e['scale_factor'] for e in im_info_list], axis=0).astype('float32')
inputs['image'] = np.stack(im_list, axis=0).astype('float32')
return inputs
......@@ -318,7 +339,7 @@ def measure_speed(FLAGS):
for name in input_names:
input_tensor = predictor.get_input_handle(name)
input_tensor.copy_from_cpu(inputs[name])
paddle.device.cuda.synchronize()
# start running
start_time = time.perf_counter()
......@@ -334,7 +355,7 @@ def measure_speed(FLAGS):
f'fps: {fps:.1f} img / s, '
f'times per image: {1000 / fps:.1f} ms / img',
flush=True)
if (i + 1) == total_iter:
fps = (i + 1 - warmup_iter) / total_time
print(
......@@ -343,14 +364,10 @@ def measure_speed(FLAGS):
flush=True)
break
if __name__ == '__main__':
FLAGS = parse_args()
check_version('2.4')
check_trt_version('8.2')
if 'trt' in FLAGS.run_mode:
check_trt_version('8.2')
measure_speed(FLAGS)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册