Unverified commit 31a8da43, author: wangxinxin08, committer: GitHub

[Dygraph] ppyolo final (#2098)

* add ppyolo in dygraph

* fix bugs

* modify configs

* polish docs, modify ema and add configs

* modify link of model

* fix bugs in ema

* modify code according to review

* modify minicoco configs
Parent 8059452f
English | [简体中文](README_cn.md)
# PP-YOLO
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo](#Model_Zoo)
- [Getting Started](#Getting_Start)
- [Future Work](#Future_Work)
- [Appendix](#Appendix)
## Introduction
[PP-YOLO](https://arxiv.org/abs/2007.12099) is an optimized model based on YOLOv3 in PaddleDetection, whose accuracy (mAP on COCO) and inference speed are both better than those of [YOLOv4](https://arxiv.org/abs/2004.10934). PaddlePaddle 2.0.0rc1 (available via pip) or the [Daily Version](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/install/Tables.html#whl-release) is required to run PP-YOLO.
PP-YOLO reaches 45.9% mAP (IoU=0.5:0.95) on the COCO test-dev2017 dataset. FP32 inference speed on a single V100 is 72.9 FPS, and FP16 inference speed with TensorRT on a single V100 is 155.6 FPS.
<div align="center">
<img src="../../../docs/images/ppyolo_map_fps.png" width=500 />
</div>
PP-YOLO improves the accuracy and speed of YOLOv3 with the following methods:
- Better backbone: ResNet50vd-DCN
- Larger training batch size: 8 GPUs with a mini-batch size of 24 on each GPU
- [Drop Block](https://arxiv.org/abs/1810.12890)
- [Exponential Moving Average](https://www.investopedia.com/terms/e/ema.asp)
- [IoU Loss](https://arxiv.org/pdf/1902.09630.pdf)
- [Grid Sensitive](https://arxiv.org/abs/2004.10934)
- [Matrix NMS](https://arxiv.org/pdf/2003.10152.pdf)
- [CoordConv](https://arxiv.org/abs/1807.03247)
- [Spatial Pyramid Pooling](https://arxiv.org/abs/1406.4729)
- Better ImageNet pretrained weights
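To make the exponential moving average item in the list above concrete, here is a minimal sketch of keeping a smoothed "shadow" copy of the weights during training. The class name `SimpleEMA`, the plain-dict weights, and the warm-up rule `min(decay, (1 + step) / (10 + step))` are assumptions chosen for illustration; this is not PaddleDetection's `ModelEMA` implementation.

```python
class SimpleEMA:
    """Hypothetical EMA tracker over a dict of float weights (illustration only)."""

    def __init__(self, decay, weights, use_thres_step=False):
        self.decay = decay
        self.use_thres_step = use_thres_step
        self.step = 0
        # Start the shadow copy from the current weights.
        self.shadow = dict(weights)

    def update(self, weights):
        decay = self.decay
        if self.use_thres_step:
            # Assumed warm-up: use a smaller decay during the first steps.
            decay = min(self.decay, (1 + self.step) / (10 + self.step))
        for name, value in weights.items():
            self.shadow[name] = decay * self.shadow[name] + (1 - decay) * value
        self.step += 1

    def apply(self):
        # Return a copy of the smoothed weights, e.g. for saving a checkpoint.
        return dict(self.shadow)
```

In this sketch `update` would be called after every training step and `apply` when a checkpoint is written, which mirrors how the `use_ema` and `ema_decay` options are described in this commit.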
## Model Zoo
### PP-YOLO
| Model | GPU number | images/GPU | backbone | input shape | Box AP<sup>val</sup> | Box AP<sup>test</sup> | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
|:------------------------:|:----------:|:----------:|:----------:| :----------:| :------------------: | :-------------------: | :------------: | :---------------------: | :------: | :-----: |
| PP-YOLO | 8 | 24 | ResNet50vd | 608 | 44.8 | 45.2 | 72.9 | 155.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 608 | 45.3 | 45.9 | 72.9 | 155.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
**Notes:**
- PP-YOLO is trained on the COCO train2017 dataset and evaluated on the val2017 and test-dev2017 datasets. Box AP<sup>test</sup> is the evaluation result of `mAP(IoU=0.5:0.95)`.
- PP-YOLO is trained on 8 GPUs with a mini-batch size of 24 on each GPU. If the number of GPUs or the mini-batch size is changed, the learning rate and the number of iterations should be adjusted according to the [FAQ](../../../docs/FAQ.md).
- PP-YOLO inference speed is tested on a single Tesla V100 with a batch size of 1, CUDA 10.2, and CUDNN 7.5.1; TensorRT mode uses TensorRT 5.1.2.2.
- PP-YOLO FP32 inference speed is measured with the inference model exported by `tools/export_model.py` and benchmarked by running `deploy/python/infer.py` with `--run_benchmark`. None of the reported results include the time cost of data reading and post-processing (NMS), which matches the testing method of [YOLOv4(AlexeyAB)](https://github.com/AlexeyAB/darknet).
- Compared with the FP32 test above, the TensorRT FP16 speed test also excludes the time cost of bounding-box decoding (`yolo_box`), which means that data reading, bounding-box decoding, and post-processing (NMS) are all excluded (again matching the testing method of [YOLOv4(AlexeyAB)](https://github.com/AlexeyAB/darknet)).
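As a rough illustration of the learning-rate adjustment mentioned in the notes above, the sketch below applies a linear scaling rule relative to the reference setup of 8 GPUs with 24 images each. The linear rule and the function name `scale_learning_rate` are assumptions for illustration; the FAQ remains the authoritative guide. The default `base_lr=0.01` is the value in the optimizer configs of this commit.

```python
def scale_learning_rate(num_gpus, images_per_gpu,
                        base_lr=0.01, base_gpus=8, base_images_per_gpu=24):
    """Scale the learning rate in proportion to the total batch size.

    Assumes a linear scaling rule (hypothetical helper, not a PaddleDetection API).
    """
    base_total = base_gpus * base_images_per_gpu
    total = num_gpus * images_per_gpu
    return base_lr * total / base_total


# Halving the total batch size halves the learning rate under this rule.
print(scale_learning_rate(num_gpus=8, images_per_gpu=12))
```

The minicoco config in this commit is consistent with this rule: it sets `batch_size: 12` together with `base_lr: 0.005`.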
## Getting Started
### 1. Training
Train PP-YOLO on 8 GPUs with the following command (all commands should be run under the PaddleDetection dygraph directory by default):
```bash
python -m paddle.distributed.launch --log_dir=./ppyolo_dygraph/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml &>ppyolo_dygraph.log 2>&1 &
```
### 2. Evaluation
Evaluate PP-YOLO on the COCO val2017 dataset on a single GPU with the following commands:
```bash
# use weights released in PaddleDetection model zoo
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams
# use saved checkpoint in training
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=output/ppyolo_r50vd_dcn_1x_coco/model_final
```
To evaluate on the COCO test-dev2017 dataset, use `configs/ppyolo/ppyolo_test.yml`. Download the COCO test-dev2017 dataset from [COCO dataset download](https://cocodataset.org/#download), decompress it to the paths configured by `EvalReader.dataset` in `configs/ppyolo/ppyolo_test.yml`, and run evaluation with the following commands:
```bash
# use weights released in PaddleDetection model zoo
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_test.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams
# use saved checkpoint in training
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_test.yml -o weights=output/ppyolo_r50vd_dcn_1x_coco/model_final
```
Evaluation results are saved in `bbox.json`. Compress it into a `zip` package and upload it to [COCO dataset evaluation](https://competitions.codalab.org/competitions/20794#participate) for evaluation.
**NOTE:** `configs/ppyolo/ppyolo_test.yml` is only for evaluation on the COCO test-dev2017 dataset; it cannot be used for training or for evaluation on COCO val2017.
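The compression step described above can be sketched in a few lines of Python. The helper name `pack_results` and the archive name `bbox.zip` are hypothetical; the evaluation server may impose its own file-naming rules, so check its instructions before uploading.

```python
import os
import zipfile


def pack_results(json_path="bbox.json", zip_path="bbox.zip"):
    """Compress the evaluation result file into a zip archive (hypothetical helper)."""
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        # Store only the file name, not the directories leading to it.
        archive.write(json_path, arcname=os.path.basename(json_path))
    return zip_path
```

Calling `pack_results()` from the directory that contains `bbox.json` writes `bbox.zip` next to it.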
### 3. Inference
Run inference on a single GPU with the following commands. Use `--infer_img` to run inference on a single image and `--infer_dir` to run inference on all images in a directory.
```bash
# inference on a single image
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams --infer_img=../demo/000000014439_640x640.jpg
# inference on all images in the directory
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams --infer_dir=../demo
```
### 4. Inference deployment and benchmark
For inference deployment or benchmarking, export the model with `tools/export_model.py` and run inference with the Paddle inference library using the following commands:
```bash
# export model, model will be saved in output/ppyolo by default
python tools/export_model.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams
# inference with Paddle Inference library
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/ppyolo_r50vd_dcn_1x_coco --image_file=../demo/000000014439_640x640.jpg --use_gpu=True
```
Benchmark testing for PP-YOLO uses a model without data reading and post-processing (NMS). Export the model with `--exclude_nms` to prune the NMS part for benchmark testing, using the following commands:
```bash
# export model, --exclude_nms to prune NMS part, model will be saved in output/ppyolo by default
python tools/export_model.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams --exclude_nms
# FP32 benchmark
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/ppyolo_r50vd_dcn_1x_coco --image_file=../demo/000000014439_640x640.jpg --use_gpu=True --run_benchmark=True
# TensorRT FP16 benchmark
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/ppyolo_r50vd_dcn_1x_coco --image_file=../demo/000000014439_640x640.jpg --use_gpu=True --run_benchmark=True --run_mode=trt_fp16
```
## Future Work
1. More PP-YOLO tiny models
2. PP-YOLO models with more backbones
## Appendix
Optimizing method and ablation experiments of PP-YOLO compared with YOLOv3.
| NO. | Model | Box AP<sup>val</sup> | Box AP<sup>test</sup> | Params(M) | FLOPs(G) | V100 FP32 FPS |
| :--: | :--------------------------- | :------------------: |:--------------------: | :-------: | :------: | :-----------: |
| A | YOLOv3-DarkNet53 | 38.9 | - | 59.13 | 65.52 | 58.2 |
| B | YOLOv3-ResNet50vd-DCN | 39.1 | - | 43.89 | 44.71 | 79.2 |
| C | B + LB + EMA + DropBlock | 41.4 | - | 43.89 | 44.71 | 79.2 |
| D | C + IoU Loss | 41.9 | - | 43.89 | 44.71 | 79.2 |
| E | D + IoU Aware | 42.5 | - | 43.90 | 44.71 | 74.9 |
| F | E + Grid Sensitive | 42.8 | - | 43.90 | 44.71 | 74.8 |
| G | F + Matrix NMS | 43.5 | - | 43.90 | 44.71 | 74.8 |
| H | G + CoordConv | 44.0 | - | 43.93 | 44.76 | 74.1 |
| I | H + SPP | 44.3 | 45.2 | 44.93 | 45.12 | 72.9 |
| J | I + Better ImageNet Pretrain | 44.8 | 45.2 | 44.93 | 45.12 | 72.9 |
| K | J + 2x Scheduler | 45.3 | 45.9 | 44.93 | 45.12 | 72.9 |
**Notes:**
- Accuracy and inference speed are measured with an input shape of 608.
- All models are trained on the COCO train2017 dataset and evaluated on the val2017 and test-dev2017 datasets. `Box AP` is the evaluation result of `mAP(IoU=0.5:0.95)`.
- Inference speed is tested on a single Tesla V100 with a batch size of 1, following the test method and environment configuration in the benchmark above.
- [YOLOv3-DarkNet53](../yolov3/yolov3_darknet53_270e_coco.yml) with an mAP of 38.9 is the optimized YOLOv3 model in PaddleDetection; see [Model Zoo](../../../docs/MODEL_ZOO.md) for details.
简体中文 | [English](README.md)
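# PP-YOLO Model
## Contents
- [Introduction](#简介)
- [Model Zoo and Baselines](#模型库与基线)
- [Usage](#使用说明)
- [Future Work](#未来工作)
- [Appendix](#附录)
## Introduction
[PP-YOLO](https://arxiv.org/abs/2007.12099) is PaddleDetection's optimized and improved version of YOLOv3. Its accuracy (mAP on the COCO dataset) and inference speed both exceed those of [YOLOv4](https://arxiv.org/abs/2004.10934). It requires PaddlePaddle 2.0.0rc1 (installable via pip) or a suitable [develop version](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/install/Tables.html#whl-release).
PP-YOLO reaches 45.9% accuracy on the [COCO](http://cocodataset.org) test-dev2017 dataset. FP32 inference speed on a single V100 is 72.9 FPS, and FP16 inference speed with TensorRT enabled on a V100 is 155.6 FPS.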
<div align="center">
<img src="../../../docs/images/ppyolo_map_fps.png" width=500 />
</div>
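PP-YOLO optimizes and improves the accuracy and speed of YOLOv3 in the following ways:
- Better backbone: ResNet50vd-DCN
- Larger training batch size: 8 GPUs with batch_size=24 per GPU, with the learning rate and number of iterations adjusted accordingly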
- [Drop Block](https://arxiv.org/abs/1810.12890)
- [Exponential Moving Average](https://www.investopedia.com/terms/e/ema.asp)
- [IoU Loss](https://arxiv.org/pdf/1902.09630.pdf)
- [Grid Sensitive](https://arxiv.org/abs/2004.10934)
- [Matrix NMS](https://arxiv.org/pdf/2003.10152.pdf)
- [CoordConv](https://arxiv.org/abs/1807.03247)
- [Spatial Pyramid Pooling](https://arxiv.org/abs/1406.4729)
- Better pretrained model
## Model Zoo
### PP-YOLO Models
| Model | GPU number | images/GPU | backbone | input shape | Box AP<sup>val</sup> | Box AP<sup>test</sup> | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
|:------------------------:|:-------:|:-------------:|:----------:| :-------:| :------------------: | :-------------------: | :------------: | :---------------------: | :------: | :------: |
| PP-YOLO | 8 | 24 | ResNet50vd | 608 | 44.8 | 45.2 | 72.9 | 155.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 608 | 45.3 | 45.9 | 72.9 | 155.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
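**Notes:**
- PP-YOLO is trained on COCO train2017 and tested on val2017 and test-dev2017. Box AP<sup>test</sup> is the `mAP(IoU=0.5:0.95)` evaluation result.
- PP-YOLO is trained on 8 GPUs with a batch size of 24 per GPU. If a different number of GPUs or batch size is used, adjust the learning rate and number of iterations according to the [FAQ](../../../docs/FAQ.md).
- PP-YOLO inference speed is tested on a single V100 with batch size=1, using CUDA 10.2 and CUDNN 7.5.1; the TensorRT speed test uses TensorRT 5.1.2.2.
- The FP32 inference speed of PP-YOLO is measured by exporting the model with `tools/export_model.py` and then benchmarking it with the Paddle inference library through the `--run_benchmark` option of `deploy/python/infer.py`. All reported figures exclude data preprocessing and post-processing of the model output (NMS), consistent with the test method of [YOLOv4(AlexeyAB)](https://github.com/AlexeyAB/darknet).
- Compared with FP32, the TensorRT FP16 speed test additionally excludes the time of the `yolo_box` (bbox decoding) part, that is, it excludes data preprocessing, bbox decoding, and NMS (consistent with the test method of [YOLOv4(AlexeyAB)](https://github.com/AlexeyAB/darknet)).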
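## Usage
### 1. Training
Start training on 8 GPUs with the following command (all commands below are run from the PaddleDetection root directory by default). Use the `--eval` flag to enable evaluation alternating with training.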
```bash
python -m paddle.distributed.launch --log_dir=./ppyolo_dygraph/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml &>ppyolo_dygraph.log 2>&1 &
```
### 2. Evaluation
Evaluate the model on the COCO val2017 dataset on a single GPU with the following commands:
```bash
# use weights released by PaddleDetection
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams
# use a checkpoint saved during training
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=output/ppyolo_r50vd_dcn_1x_coco/model_final
```
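`configs/ppyolo/ppyolo_test.yml` is provided for evaluating on the COCO test-dev2017 dataset. First download the test-dev2017 dataset from the [COCO dataset download page](https://cocodataset.org/#download), decompress it to the path configured by `EvalReader.dataset` in `configs/ppyolo/ppyolo_test.yml`, and then run evaluation with the following commands: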
```bash
# use weights released by PaddleDetection
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_test.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams
# use a checkpoint saved during training
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_test.yml -o weights=output/ppyolo_r50vd_dcn_1x_coco/model_final
```
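Evaluation results are saved in `bbox.json`. Compress it into a zip package and submit it for evaluation through the [COCO dataset evaluation page](https://competitions.codalab.org/competitions/20794#participate).
**Note:** `configs/ppyolo/ppyolo_test.yml` is only for evaluation on the COCO test-dev dataset; it is not for training or for evaluation on COCO val2017.
### 3. Inference
Run inference on images on a single GPU with the following commands. Use `--infer_img` to specify an image path, or `--infer_dir` to specify a directory and run inference on all images in it.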
```bash
# inference on a single image
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams --infer_img=../demo/000000014439_640x640.jpg
# inference on all images in the directory
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams --infer_dir=../demo
```
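### 4. Inference deployment and benchmark
PP-YOLO deployment and inference benchmarking require exporting the model with `tools/export_model.py` and then deploying and running inference with the Paddle inference library, which can be started with the following commands.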
```bash
# export the model; saved to the output/ppyolo directory by default
python tools/export_model.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams
# inference with the Paddle inference library
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/ppyolo_r50vd_dcn_1x_coco --image_file=../demo/000000014439_640x640.jpg --use_gpu=True
```
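The PP-YOLO benchmark measures only the network itself, excluding data preprocessing and post-processing of the network output (NMS). When exporting the model, specify `--exclude_nms` to prune the post-processing NMS part from the model. Export the model and run the benchmark with the following commands.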
```bash
# export the model; --exclude_nms prunes the NMS part from the model; saved to the output/ppyolo directory by default
python tools/export_model.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams --exclude_nms
# FP32 benchmark
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/ppyolo_r50vd_dcn_1x_coco --image_file=../demo/000000014439_640x640.jpg --use_gpu=True --run_benchmark=True
# TensorRT FP16 benchmark
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/ppyolo_r50vd_dcn_1x_coco --image_file=../demo/000000014439_640x640.jpg --use_gpu=True --run_benchmark=True --run_mode=trt_fp16
```
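## Future Work
1. Release the PP-YOLO-tiny model
2. Release PP-YOLO models with more backbones
## Appendix
The ablation results for the optimizations of PP-YOLO relative to YOLOv3 are shown in the table below.
| NO. | Model | Box AP<sup>val</sup> | Box AP<sup>test</sup> | Params(M) | FLOPs(G) | V100 FP32 FPS |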
| :--: | :--------------------------- | :------------------: | :-------------------: | :-------: | :------: | :-----------: |
| A | YOLOv3-DarkNet53 | 38.9 | - | 59.13 | 65.52 | 58.2 |
| B | YOLOv3-ResNet50vd-DCN | 39.1 | - | 43.89 | 44.71 | 79.2 |
| C | B + LB + EMA + DropBlock | 41.4 | - | 43.89 | 44.71 | 79.2 |
| D | C + IoU Loss | 41.9 | - | 43.89 | 44.71 | 79.2 |
| E | D + IoU Aware | 42.5 | - | 43.90 | 44.71 | 74.9 |
| F | E + Grid Sensitive | 42.8 | - | 43.90 | 44.71 | 74.8 |
| G | F + Matrix NMS | 43.5 | - | 43.90 | 44.71 | 74.8 |
| H | G + CoordConv | 44.0 | - | 43.93 | 44.76 | 74.1 |
| I | H + SPP | 44.3 | 45.2 | 44.93 | 45.12 | 72.9 |
| J | I + Better ImageNet Pretrain | 44.8 | 45.2 | 44.93 | 45.12 | 72.9 |
| K | J + 2x Scheduler | 45.3 | 45.9 | 44.93 | 45.12 | 72.9 |
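**Notes:**
- Accuracy and inference speed are measured with an input image size of 608.
- Box AP is the `mAP(IoU=0.5:0.95)` result of training on COCO train2017 and evaluating on val2017 and test-dev2017.
- Inference speed is measured on a single V100 with batch size=1 using the benchmark method above, in an environment with CUDA 10.2 and CUDNN 7.5.1.
- [YOLOv3-DarkNet53](../yolov3/yolov3_darknet53_270e_coco.yml) with an accuracy of 38.9 is the YOLOv3 model optimized in PaddleDetection; see the [Model Zoo](../../../docs/MODEL_ZOO.md).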
epoch: 405
LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 243
    - 324
  - !LinearWarmup
    start_factor: 0.
    steps: 4000
OptimizerBuilder:
  clip_grad_by_norm: 35.
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2
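The schedule in the config above combines a piecewise decay with a linear warm-up. The sketch below shows one way such a schedule could be evaluated. It assumes that `milestones` are counted in epochs, that the warm-up `steps` are counted in iterations, and that the warm-up overrides the first iterations; `steps_per_epoch` is a hypothetical parameter and the function is not part of PaddleDetection.

```python
def learning_rate(step, steps_per_epoch, base_lr=0.01, gamma=0.1,
                  milestones=(243, 324), warmup_steps=4000, start_factor=0.0):
    """Return the learning rate at a given iteration (illustrative sketch only)."""
    if step < warmup_steps:
        # Linear warm-up from base_lr * start_factor up to base_lr.
        alpha = step / warmup_steps
        return base_lr * (start_factor * (1 - alpha) + alpha)
    epoch = step // steps_per_epoch
    # Multiply by gamma once for every milestone epoch already reached.
    passed = sum(1 for milestone in milestones if epoch >= milestone)
    return base_lr * gamma ** passed
```

With the default values the rate ramps up over the first 4000 iterations, stays at 0.01 until epoch 243, and is then divided by 10 at epochs 243 and 324.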
epoch: 811
LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 649
    - 730
  - !LinearWarmup
    start_factor: 0.
    steps: 4000
OptimizerBuilder:
  clip_grad_by_norm: 35.
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo_r50vd_dcn/model_final
load_static_weights: true
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
YOLOv3:
  backbone: ResNet
  neck: PPYOLOFPN
  yolo_head: YOLOv3Head
  post_process: BBoxPostProcess
ResNet:
  depth: 50
  variant: d
  return_idx: [1, 2, 3]
  dcn_v2_stages: [3]
  freeze_at: -1
  freeze_norm: false
  norm_decay: 0.
PPYOLOFPN:
  feat_channels: [2048, 1280, 640]
  coord_conv: true
  drop_block: true
  block_size: 3
  keep_prob: 0.9
  spp: true
YOLOv3Head:
  anchors: [[10, 13], [16, 30], [33, 23],
            [30, 61], [62, 45], [59, 119],
            [116, 90], [156, 198], [373, 326]]
  anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  loss: YOLOv3Loss
  iou_aware: true
  iou_aware_factor: 0.4
YOLOv3Loss:
  ignore_thresh: 0.7
  downsample: [32, 16, 8]
  label_smooth: false
  scale_x_y: 1.05
  iou_loss: IouLoss
  iou_aware_loss: IouAwareLoss
IouLoss:
  loss_weight: 2.5
  loss_square: true
IouAwareLoss:
  loss_weight: 1.0
BBoxPostProcess:
  decode:
    name: YOLOBox
    conf_thresh: 0.005
    downsample_ratio: 32
    clip_bbox: true
    scale_x_y: 1.05
  nms:
    name: MatrixNMS
    keep_top_k: 100
    score_threshold: 0.01
    post_threshold: 0.01
    nms_top_k: -1
    normalized: false
    background_label: -1
worker_num: 2
TrainReader:
  inputs_def:
    num_max_boxes: 50
  sample_transforms:
    - DecodeOp: {}
    - MixupOp: {alpha: 1.5, beta: 1.5}
    - RandomDistortOp: {}
    - RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]}
    - RandomCropOp: {}
    - RandomFlipOp: {}
  batch_transforms:
    - BatchRandomResizeOp: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608], random_size: True, random_interp: True, keep_ratio: False}
    - NormalizeBoxOp: {}
    - PadBoxOp: {num_max_boxes: 50}
    - BboxXYXY2XYWHOp: {}
    - NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
    - PermuteOp: {}
    - Gt2YoloTargetOp: {anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]], anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]], downsample_ratios: [32, 16, 8]}
  batch_size: 24
  shuffle: true
  drop_last: true
  mixup_epoch: 25000
EvalReader:
  sample_transforms:
    - DecodeOp: {}
    - ResizeOp: {target_size: [608, 608], keep_ratio: False, interp: 2}
    - NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
    - PermuteOp: {}
  batch_size: 8
  drop_empty: false
TestReader:
  inputs_def:
    image_shape: [3, 608, 608]
  sample_transforms:
    - DecodeOp: {}
    - ResizeOp: {target_size: [608, 608], keep_ratio: False, interp: 2}
    - NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
    - PermuteOp: {}
  batch_size: 1
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/ppyolo_r50vd_dcn.yml',
  './_base_/optimizer_1x.yml',
  './_base_/ppyolo_reader.yml',
]
snapshot_epoch: 16
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/ppyolo_r50vd_dcn.yml',
  './_base_/optimizer_1x.yml',
  './_base_/ppyolo_reader.yml',
]
snapshot_epoch: 8
use_ema: false
TrainReader:
  batch_size: 12
TrainDataset:
  !COCODataSet
    image_dir: train2017
    # refer to https://github.com/giddyyupp/coco-minitrain
    anno_path: annotations/instances_minitrain2017.json
    dataset_dir: dataset/coco
    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
epoch: 192
LearningRate:
  base_lr: 0.005
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 153
    - 173
  - !LinearWarmup
    start_factor: 0.
    steps: 4000
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/ppyolo_r50vd_dcn.yml',
  './_base_/optimizer_2x.yml',
  './_base_/ppyolo_reader.yml',
]
snapshot_epoch: 16
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/ppyolo_r50vd_dcn.yml',
  './_base_/ppyolo_1x.yml',
  './_base_/ppyolo_reader.yml',
]
snapshot_epoch: 16
EvalDataset:
  !COCODataSet
    image_dir: test2017
    anno_path: annotations/image_info_test-dev2017.json
    dataset_dir: dataset/coco
@@ -23,6 +23,7 @@ import paddle
 from paddle.distributed import ParallelEnv
 from ppdet.utils.checkpoint import save_model
+from ppdet.optimizer import ModelEMA
 from ppdet.utils.logger import setup_logger
 logger = setup_logger(__name__)
@@ -135,6 +136,15 @@ class LogPrinter(Callback):
 class Checkpointer(Callback):
     def __init__(self, model):
         super(Checkpointer, self).__init__(model)
+        cfg = self.model.cfg
+        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
+        if self.use_ema:
+            self.ema = ModelEMA(
+                cfg['ema_decay'], self.model.model, use_thres_step=True)
+    def on_step_end(self, status):
+        if self.use_ema:
+            self.ema.update(self.model.model)
     def on_epoch_end(self, status):
         assert self.model.mode == 'train', \
@@ -147,5 +157,10 @@ class Checkpointer(Callback):
                                         self.model.cfg.filename)
                 save_name = str(
                     epoch_id) if epoch_id != end_epoch - 1 else "model_final"
+                if self.use_ema:
+                    state_dict = self.ema.apply()
+                    save_model(state_dict, self.model.optimizer, save_dir,
+                               save_name, epoch_id + 1)
+                else:
-                save_model(self.model.model, self.model.optimizer, save_dir,
-                           save_name, epoch_id + 1)
+                    save_model(self.model.model, self.model.optimizer, save_dir,
+                               save_name, epoch_id + 1)
@@ -44,8 +44,9 @@ class YOLOv3(BaseArch):
         return loss
     def get_pred(self):
+        yolo_head_outs = self.yolo_head.get_outputs(self.yolo_head_outs)
         bbox, bbox_num = self.post_process(
-            self.yolo_head_outs, self.yolo_head.mask_anchors,
+            yolo_head_outs, self.yolo_head.mask_anchors,
             self.inputs['im_shape'], self.inputs['scale_factor'])
         outs = {
             "bbox": bbox,
......
@@ -74,7 +74,7 @@ class ConvNormLayer(nn.Layer):
             padding=(filter_size - 1) // 2,
             groups=groups,
             weight_attr=ParamAttr(
-                learning_rate=lr, name=name + '_weights'),
+                learning_rate=lr, name=name + "_weights"),
             bias_attr=False,
             name=name)
......
@@ -7,6 +7,13 @@ from ppdet.core.workspace import register
 from ..backbones.darknet import ConvBNLayer
+def _de_sigmoid(x, eps=1e-7):
+    x = paddle.clip(x, eps, 1. / eps)
+    x = paddle.clip(1. / x - 1., eps, 1. / eps)
+    x = -paddle.log(x)
+    return x
 @register
 class YOLOv3Head(nn.Layer):
     __shared__ = ['num_classes']
@@ -17,16 +24,24 @@ class YOLOv3Head(nn.Layer):
                           [59, 119], [116, 90], [156, 198], [373, 326]],
                  anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                  num_classes=80,
-                 loss='YOLOv3Loss'):
+                 loss='YOLOv3Loss',
+                 iou_aware=False,
+                 iou_aware_factor=0.4):
         super(YOLOv3Head, self).__init__()
         self.num_classes = num_classes
         self.loss = loss
+        self.iou_aware = iou_aware
+        self.iou_aware_factor = iou_aware_factor
         self.parse_anchor(anchors, anchor_masks)
         self.num_outputs = len(self.anchors)
         self.yolo_outputs = []
         for i in range(len(self.anchors)):
+            if self.iou_aware:
+                num_filters = self.num_outputs * (self.num_classes + 6)
+            else:
-            num_filters = self.num_outputs * (self.num_classes + 5)
+                num_filters = self.num_outputs * (self.num_classes + 5)
             name = 'yolo_output.{}'.format(i)
             yolo_output = self.add_sublayer(
@@ -62,3 +77,28 @@ class YOLOv3Head(nn.Layer):
     def get_loss(self, inputs, targets):
         return self.loss(inputs, targets, self.anchors)
+    def get_outputs(self, outputs):
+        if self.iou_aware:
+            y = []
+            for i, out in enumerate(outputs):
+                na = len(self.anchors[i])
+                ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
+                b, c, h, w = x.shape
+                no = c // na
+                x = x.reshape((b, na, no, h, w))
+                ioup = ioup.reshape((b, na, 1, h, w))
+                obj = x[:, :, 4:5, :, :]
+                ioup = F.sigmoid(ioup)
+                obj = F.sigmoid(obj)
+                obj_t = (obj**(1 - self.iou_aware_factor)) * (
+                    ioup**self.iou_aware_factor)
+                obj_t = _de_sigmoid(obj_t)
+                loc_t = x[:, :, :4, :, :]
+                cls_t = x[:, :, 5:, :, :]
+                y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
+                y_t = y_t.reshape((b, -1, h, w))
+                y.append(y_t)
+            return y
+        else:
+            return outputs
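The `get_outputs` change above fuses the objectness score with the predicted IoU and then maps the result back to a logit with `_de_sigmoid`. The scalar sketch below reproduces that arithmetic with the standard library only, so it can be checked without Paddle. The function names `sigmoid`, `de_sigmoid`, and `fuse_objectness` are illustrative stand-ins and work on single floats rather than tensors.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def de_sigmoid(p, eps=1e-7):
    """Inverse of sigmoid with the same clipping as _de_sigmoid in the diff."""
    p = min(max(p, eps), 1.0 / eps)
    inverse_odds = min(max(1.0 / p - 1.0, eps), 1.0 / eps)
    return -math.log(inverse_odds)


def fuse_objectness(obj_logit, iou_logit, iou_aware_factor=0.4):
    """Combine objectness and predicted IoU, then return the fused logit."""
    obj = sigmoid(obj_logit)
    ioup = sigmoid(iou_logit)
    fused = obj ** (1 - iou_aware_factor) * ioup ** iou_aware_factor
    return de_sigmoid(fused)
```

Returning a logit rather than a probability lets the downstream decoding step apply its own sigmoid unchanged, which appears to be the reason for the inverse mapping.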
@@ -612,7 +612,6 @@ class MultiClassNMS(object):
 @register
 @serializable
 class MatrixNMS(object):
-    __op__ = ops.matrix_nms
     __append_doc__ = True
     def __init__(self,
@@ -634,6 +633,19 @@ class MatrixNMS(object):
         self.gaussian_sigma = gaussian_sigma
         self.background_label = background_label
+    def __call__(self, bbox, score):
+        return ops.matrix_nms(
+            bboxes=bbox,
+            scores=score,
+            score_threshold=self.score_threshold,
+            post_threshold=self.post_threshold,
+            nms_top_k=self.nms_top_k,
+            keep_top_k=self.keep_top_k,
+            use_gaussian=self.use_gaussian,
+            gaussian_sigma=self.gaussian_sigma,
+            background_label=self.background_label,
+            normalized=self.normalized)
 @register
 @serializable
......
@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import paddle
 import paddle.nn.functional as F
 from ppdet.core.workspace import register, serializable
 from .iou_loss import IouLoss
@@ -33,27 +34,15 @@ class IouAwareLoss(IouLoss):
         max_width (int): max width of input to support random shape input
     """
-    def __init__(
-            self,
-            loss_weight=1.0,
-            giou=False,
-            diou=False,
-            ciou=False, ):
+    def __init__(self, loss_weight=1.0, giou=False, diou=False, ciou=False):
         super(IouAwareLoss, self).__init__(
             loss_weight=loss_weight, giou=giou, diou=diou, ciou=ciou)
-    def __call__(self, ioup, pbox, gbox, anchor, downsample, scale=1.):
-        b = pbox.shape[0]
-        ioup = ioup.reshape((b, -1))
-        pbox = decode_yolo(pbox, anchor, downsample)
-        gbox = decode_yolo(gbox, anchor, downsample)
-        pbox = xywh2xyxy(pbox).reshape((b, -1, 4))
-        gbox = xywh2xyxy(gbox).reshape((b, -1, 4))
+    def __call__(self, ioup, pbox, gbox):
         iou = bbox_iou(
             pbox, gbox, giou=self.giou, diou=self.diou, ciou=self.ciou)
         iou.stop_gradient = True
+        ioup = F.sigmoid(ioup)
-        loss_iou_aware = F.binary_cross_entropy_with_logits(
-            ioup, iou, reduction='none')
+        loss_iou_aware = (-iou * paddle.log(ioup)).sum(-2, keepdim=True)
         loss_iou_aware = loss_iou_aware * self.loss_weight
         return loss_iou_aware
@@ -50,12 +50,7 @@ class IouLoss(object):
         self.ciou = ciou
         self.loss_square = loss_square
-    def __call__(self, pbox, gbox, anchor, downsample, scale=1.):
-        b = pbox.shape[0]
-        pbox = decode_yolo(pbox, anchor, downsample)
-        gbox = decode_yolo(gbox, anchor, downsample)
-        pbox = xywh2xyxy(pbox).reshape((b, -1, 4))
-        gbox = xywh2xyxy(gbox).reshape((b, -1, 4))
+    def __call__(self, pbox, gbox):
         iou = bbox_iou(
             pbox, gbox, giou=self.giou, diou=self.diou, ciou=self.ciou)
         if self.loss_square:
......
@@ -26,6 +26,12 @@ from ..utils import decode_yolo, xywh2xyxy, iou_similarity
 __all__ = ['YOLOv3Loss']
+def bbox_transform(pbox, anchor, downsample):
+    pbox = decode_yolo(pbox, anchor, downsample)
+    pbox = xywh2xyxy(pbox)
+    return pbox
 @register
 class YOLOv3Loss(nn.Layer):
@@ -50,11 +56,16 @@ class YOLOv3Loss(nn.Layer):
         self.iou_aware_loss = iou_aware_loss
     def obj_loss(self, pbox, gbox, pobj, tobj, anchor, downsample):
-        b, h, w, na = pbox.shape[:4]
+        # pbox
         pbox = decode_yolo(pbox, anchor, downsample)
-        pbox = pbox.reshape((b, -1, 4))
         pbox = xywh2xyxy(pbox)
-        gbox = xywh2xyxy(gbox)
+        pbox = paddle.concat(pbox, axis=-1)
+        b = pbox.shape[0]
+        pbox = pbox.reshape((b, -1, 4))
+        # gbox
+        gxy = gbox[:, :, 0:2] - gbox[:, :, 2:4] * 0.5
+        gwh = gbox[:, :, 0:2] + gbox[:, :, 2:4] * 0.5
+        gbox = paddle.concat([gxy, gwh], axis=-1)
         iou = iou_similarity(pbox, gbox)
         iou.stop_gradient = True
@@ -86,57 +97,69 @@ class YOLOv3Loss(nn.Layer):
             pcls, tcls, reduction='none')
         return loss_cls
-    def yolov3_loss(self, x, t, gt_box, anchor, downsample, scale=1.,
+    def yolov3_loss(self, p, t, gt_box, anchor, downsample, scale=1.,
                     eps=1e-10):
         na = len(anchor)
-        b, c, h, w = x.shape
-        no = c // na
-        x = x.reshape((b, na, no, h, w)).transpose((0, 3, 4, 1, 2))
-        xy, wh, obj = x[:, :, :, :, 0:2], x[:, :, :, :, 2:4], x[:, :, :, :, 4:5]
+        b, c, h, w = p.shape
         if self.iou_aware_loss:
-            ioup, pcls = x[:, :, :, :, 5:6], x[:, :, :, :, 6:]
-        else:
-            pcls = x[:, :, :, :, 5:]
-        t = t.transpose((0, 3, 4, 1, 2))
-        txy, twh, tscale = t[:, :, :, :, 0:2], t[:, :, :, :, 2:4], t[:, :, :, :,
-                                                                     4:5]
+            ioup, p = p[:, 0:na, :, :], p[:, na:, :, :]
+            ioup = ioup.unsqueeze(-1)
+        p = p.reshape((b, na, -1, h, w)).transpose((0, 1, 3, 4, 2))
+        x, y = p[:, :, :, :, 0:1], p[:, :, :, :, 1:2]
+        w, h = p[:, :, :, :, 2:3], p[:, :, :, :, 3:4]
+        obj, pcls = p[:, :, :, :, 4:5], p[:, :, :, :, 5:]
+        t = t.transpose((0, 1, 3, 4, 2))
+        tx, ty = t[:, :, :, :, 0:1], t[:, :, :, :, 1:2]
+        tw, th = t[:, :, :, :, 2:3], t[:, :, :, :, 3:4]
+        tscale = t[:, :, :, :, 4:5]
         tobj, tcls = t[:, :, :, :, 5:6], t[:, :, :, :, 6:]
         tscale_obj = tscale * tobj
         loss = dict()
+        x = scale * F.sigmoid(x) - 0.5 * (scale - 1.)
+        y = scale * F.sigmoid(y) - 0.5 * (scale - 1.)
         if abs(scale - 1.) < eps:
-            loss_xy = tscale_obj * F.binary_cross_entropy_with_logits(
-                xy, txy, reduction='none')
+            loss_x = F.binary_cross_entropy(x, tx, reduction='none')
+            loss_y = F.binary_cross_entropy(y, ty, reduction='none')
+            loss_xy = tscale_obj * (loss_x + loss_y)
else: else:
xy = scale * F.sigmoid(xy) - 0.5 * (scale - 1.) loss_x = paddle.abs(x - tx)
loss_xy = tscale_obj * paddle.abs(xy - txy) loss_y = paddle.abs(y - ty)
loss_xy = tscale_obj * (loss_x + loss_y)
loss_xy = loss_xy.sum([1, 2, 3, 4]).mean() loss_xy = loss_xy.sum([1, 2, 3, 4]).mean()
loss_wh = tscale_obj * paddle.abs(wh - twh)
loss_w = paddle.abs(w - tw)
loss_h = paddle.abs(h - th)
loss_wh = tscale_obj * (loss_w + loss_h)
loss_wh = loss_wh.sum([1, 2, 3, 4]).mean() loss_wh = loss_wh.sum([1, 2, 3, 4]).mean()
loss['loss_loc'] = loss_xy + loss_wh loss['loss_xy'] = loss_xy
loss['loss_wh'] = loss_wh
x[:, :, :, :, 0:2] = scale * F.sigmoid(x[:, :, :, :, 0:2]) - 0.5 * (
scale - 1.)
box, tbox = x[:, :, :, :, 0:4], t[:, :, :, :, 0:4]
if self.iou_loss is not None: if self.iou_loss is not None:
# box and tbox will not change though they are modified in self.iou_loss function, so no need to clone # warn: do not modify x, y, w, h in place
loss_iou = self.iou_loss(box, tbox, anchor, downsample, scale) box, tbox = [x, y, w, h], [tx, ty, tw, th]
loss_iou = loss_iou * tscale_obj.reshape((b, -1)) pbox = bbox_transform(box, anchor, downsample)
loss_iou = loss_iou.sum(-1).mean() gbox = bbox_transform(tbox, anchor, downsample)
loss_iou = self.iou_loss(pbox, gbox)
loss_iou = loss_iou * tscale_obj
loss_iou = loss_iou.sum([1, 2, 3, 4]).mean()
loss['loss_iou'] = loss_iou loss['loss_iou'] = loss_iou
if self.iou_aware_loss is not None: if self.iou_aware_loss is not None:
# box and tbox will not change though they are modified in self.iou_aware_loss function, so no need to clone box, tbox = [x, y, w, h], [tx, ty, tw, th]
loss_iou_aware = self.iou_aware_loss(ioup, box, tbox, anchor, pbox = bbox_transform(box, anchor, downsample)
downsample, scale) gbox = bbox_transform(tbox, anchor, downsample)
loss_iou_aware = loss_iou_aware * tobj.reshape((b, -1)) loss_iou_aware = self.iou_aware_loss(ioup, pbox, gbox)
loss_iou_aware = loss_iou_aware.sum(-1).mean() loss_iou_aware = loss_iou_aware * tobj
loss_iou_aware = loss_iou_aware.sum([1, 2, 3, 4]).mean()
loss['loss_iou_aware'] = loss_iou_aware loss['loss_iou_aware'] = loss_iou_aware
box = [x, y, w, h]
loss_obj = self.obj_loss(box, gt_box, obj, tobj, anchor, downsample) loss_obj = self.obj_loss(box, gt_box, obj, tobj, anchor, downsample)
loss_obj = loss_obj.sum(-1).mean() loss_obj = loss_obj.sum(-1).mean()
loss['loss_obj'] = loss_obj loss['loss_obj'] = loss_obj
...@@ -152,7 +175,8 @@ class YOLOv3Loss(nn.Layer): ...@@ -152,7 +175,8 @@ class YOLOv3Loss(nn.Layer):
yolo_losses = dict() yolo_losses = dict()
for x, t, anchor, downsample in zip(inputs, gt_targets, anchors, for x, t, anchor, downsample in zip(inputs, gt_targets, anchors,
self.downsample): self.downsample):
yolo_loss = self.yolov3_loss(x, t, gt_box, anchor, downsample) yolo_loss = self.yolov3_loss(x, t, gt_box, anchor, downsample,
self.scale_x_y)
for k, v in yolo_loss.items(): for k, v in yolo_loss.items():
if k in yolo_losses: if k in yolo_losses:
yolo_losses[k] += v yolo_losses[k] += v
......
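Review note on the `scale` argument now passed as `self.scale_x_y`: the new code applies `scale * F.sigmoid(x) - 0.5 * (scale - 1.)` to the x and y logits before computing the loss. A minimal scalar sketch of that mapping follows; the helper names and the value `1.05` are illustrative assumptions, not taken from this commit.

```python
import math


def sigmoid(v):
    """Plain logistic function."""
    return 1.0 / (1.0 + math.exp(-v))


def grid_sensitive(logit, scale=1.0):
    """Mirror `scale * F.sigmoid(x) - 0.5 * (scale - 1.)` for one scalar logit."""
    return scale * sigmoid(logit) - 0.5 * (scale - 1.0)


# With scale == 1 the mapping is the ordinary sigmoid, bounded by (0, 1).
plain = grid_sensitive(0.0, scale=1.0)

# With scale > 1 (1.05 is a hypothetical value) the output range widens
# symmetrically around 0.5, so offsets at the cell border become reachable.
centre = grid_sensitive(0.0, scale=1.05)
upper = grid_sensitive(50.0, scale=1.05)
lower = grid_sensitive(-50.0, scale=1.05)
```

Because the scaled output can leave the interval [0, 1], binary cross-entropy no longer applies, which is consistent with the diff falling back to an L1 loss whenever `scale` differs from 1.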
...@@ -18,6 +18,7 @@ import paddle.nn.functional as F ...@@ -18,6 +18,7 @@ import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ..backbones.darknet import ConvBNLayer from ..backbones.darknet import ConvBNLayer
import numpy as np
class YoloDetBlock(nn.Layer): class YoloDetBlock(nn.Layer):
...@@ -62,6 +63,101 @@ class YoloDetBlock(nn.Layer): ...@@ -62,6 +63,101 @@ class YoloDetBlock(nn.Layer):
return route, tip return route, tip
class SPP(nn.Layer):
def __init__(self, ch_in, ch_out, k, pool_size, norm_type, name):
super(SPP, self).__init__()
self.pool = []
for size in pool_size:
pool = self.add_sublayer(
'{}.pool1'.format(name),
nn.MaxPool2D(
kernel_size=size,
stride=1,
padding=size // 2,
ceil_mode=False))
self.pool.append(pool)
self.conv = ConvBNLayer(
ch_in, ch_out, k, padding=k // 2, norm_type=norm_type, name=name)
def forward(self, x):
outs = [x]
for pool in self.pool:
outs.append(pool(x))
y = paddle.concat(outs, axis=1)
y = self.conv(y)
return y
class DropBlock(nn.Layer):
def __init__(self, block_size, keep_prob, name):
super(DropBlock, self).__init__()
self.block_size = block_size
self.keep_prob = keep_prob
self.name = name
def forward(self, x):
if not self.training or self.keep_prob == 1:
return x
else:
gamma = (1. - self.keep_prob) / (self.block_size**2)
for s in x.shape[2:]:
gamma *= s / (s - self.block_size + 1)
matrix = paddle.cast(paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
mask_inv = F.max_pool2d(
matrix, self.block_size, stride=1, padding=self.block_size // 2)
mask = 1. - mask_inv
y = x * mask * (mask.numel() / mask.sum())
return y
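Review note on `DropBlock.forward`: the seed probability `gamma` is derived from `keep_prob`, `block_size`, and the spatial size of the feature map. The sketch below reproduces only that arithmetic in plain Python; the function name and the example shape are assumptions made for illustration.

```python
def dropblock_gamma(keep_prob, block_size, spatial_shape):
    """Per-position seed probability, following the loop in DropBlock.forward."""
    gamma = (1.0 - keep_prob) / (block_size ** 2)
    for s in spatial_shape:
        # Compensate for the border positions that a full block cannot cover.
        gamma *= s / (s - block_size + 1)
    return gamma


# Hypothetical 10x10 feature map with the defaults read by PPYOLOFPN
# (block_size=3, keep_prob=0.9).
gamma = dropblock_gamma(0.9, 3, (10, 10))
```

Each sampled seed is then dilated to a `block_size` x `block_size` square by the stride-1 max pooling, so dividing by `block_size**2` keeps the expected dropped fraction close to `1 - keep_prob`.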
class CoordConv(nn.Layer):
def __init__(self, ch_in, ch_out, filter_size, padding, norm_type, name):
super(CoordConv, self).__init__()
self.conv = ConvBNLayer(
ch_in + 2,
ch_out,
filter_size=filter_size,
padding=padding,
norm_type=norm_type,
name=name)
def forward(self, x):
b = x.shape[0]
h = x.shape[2]
w = x.shape[3]
gx = paddle.arange(w, dtype='float32') / (w - 1.) * 2.0 - 1.
gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w])
gx.stop_gradient = True
gy = paddle.arange(h, dtype='float32') / (h - 1.) * 2.0 - 1.
gy = gy.reshape([1, 1, h, 1]).expand([b, 1, h, w])
gy.stop_gradient = True
y = paddle.concat([x, gx, gy], axis=1)
y = self.conv(y)
return y
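Review note on `CoordConv.forward`: the two extra input channels are coordinates normalised to [-1, 1] along the width and the height. A plain-Python sketch of one such axis, assuming a hypothetical helper name:

```python
def coord_channel(n):
    """n evenly spaced values in [-1, 1], mirroring arange(n) / (n - 1.) * 2.0 - 1."""
    return [i / (n - 1.0) * 2.0 - 1.0 for i in range(n)]


# A hypothetical feature map that is 5 cells wide.
gx_row = coord_channel(5)
```

The formula divides by `n - 1`, so a feature map that is a single cell wide or high has no well-defined coordinate under this scheme.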
class PPYOLODetBlock(nn.Layer):
def __init__(self, cfg, name):
super(PPYOLODetBlock, self).__init__()
self.conv_module = nn.Sequential()
for idx, (conv_name, layer, args, kwargs) in enumerate(cfg[:-1]):
kwargs.update(name='{}.{}'.format(name, conv_name))
self.conv_module.add_sublayer(conv_name, layer(*args, **kwargs))
conv_name, layer, args, kwargs = cfg[-1]
kwargs.update(name='{}.{}'.format(name, conv_name))
self.tip = layer(*args, **kwargs)
def forward(self, inputs):
route = self.conv_module(inputs)
tip = self.tip(route)
return route, tip
@register @register
@serializable @serializable
class YOLOv3FPN(nn.Layer): class YOLOv3FPN(nn.Layer):
...@@ -114,3 +210,101 @@ class YOLOv3FPN(nn.Layer): ...@@ -114,3 +210,101 @@ class YOLOv3FPN(nn.Layer):
route = F.interpolate(route, scale_factor=2.) route = F.interpolate(route, scale_factor=2.)
return yolo_feats return yolo_feats
@register
@serializable
class PPYOLOFPN(nn.Layer):
__shared__ = ['norm_type']
def __init__(self,
feat_channels=[2048, 1280, 640],
norm_type='bn',
**kwargs):
super(PPYOLOFPN, self).__init__()
assert len(feat_channels) > 0, "feat_channels length should > 0"
self.feat_channels = feat_channels
self.num_blocks = len(feat_channels)
# parse kwargs
self.coord_conv = kwargs.get('coord_conv', False)
self.drop_block = kwargs.get('drop_block', False)
if self.drop_block:
self.block_size = kwargs.get('block_size', 3)
self.keep_prob = kwargs.get('keep_prob', 0.9)
self.spp = kwargs.get('spp', False)
if self.coord_conv:
ConvLayer = CoordConv
else:
ConvLayer = ConvBNLayer
if self.drop_block:
dropblock_cfg = [[
'dropblock', DropBlock, [self.block_size, self.keep_prob],
dict()
]]
else:
dropblock_cfg = []
self.yolo_blocks = []
self.routes = []
for i, ch_in in enumerate(self.feat_channels):
channel = 64 * (2**self.num_blocks) // (2**i)
base_cfg = [
# name of layer, Layer, args
['conv0', ConvLayer, [ch_in, channel, 1]],
['conv1', ConvBNLayer, [channel, channel * 2, 3]],
['conv2', ConvLayer, [channel * 2, channel, 1]],
['conv3', ConvBNLayer, [channel, channel * 2, 3]],
['route', ConvLayer, [channel * 2, channel, 1]],
['tip', ConvLayer, [channel, channel * 2, 3]]
]
for conf in base_cfg:
filter_size = conf[-1][-1]
conf.append(dict(padding=filter_size // 2, norm_type=norm_type))
if i == 0:
if self.spp:
pool_size = [5, 9, 13]
spp_cfg = [[
'spp', SPP,
[channel * (len(pool_size) + 1), channel, 1], dict(
pool_size=pool_size, norm_type=norm_type)
]]
else:
spp_cfg = []
cfg = base_cfg[0:3] + spp_cfg + base_cfg[
3:4] + dropblock_cfg + base_cfg[4:6]
else:
cfg = base_cfg[0:2] + dropblock_cfg + base_cfg[2:6]
name = 'yolo_block.{}'.format(i)
yolo_block = self.add_sublayer(name, PPYOLODetBlock(cfg, name))
self.yolo_blocks.append(yolo_block)
if i < self.num_blocks - 1:
name = 'yolo_transition.{}'.format(i)
route = self.add_sublayer(
name,
ConvBNLayer(
ch_in=channel,
ch_out=channel // 2,
filter_size=1,
stride=1,
padding=0,
norm_type=norm_type,
name=name))
self.routes.append(route)
def forward(self, blocks):
assert len(blocks) == self.num_blocks
blocks = blocks[::-1]
yolo_feats = []
for i, block in enumerate(blocks):
if i > 0:
block = paddle.concat([route, block], axis=1)
route, tip = self.yolo_blocks[i](block)
yolo_feats.append(tip)
if i < self.num_blocks - 1:
route = self.routes[i](route)
route = F.interpolate(route, scale_factor=2.)
return yolo_feats
\ No newline at end of file
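Review note on the channel schedule in `PPYOLOFPN.__init__`: each level uses `channel = 64 * (2**self.num_blocks) // (2**i)`. The sketch below evaluates only that expression; the function name is an assumption.

```python
def fpn_channels(num_blocks):
    """Base channel width per FPN level, deepest level first."""
    return [64 * (2 ** num_blocks) // (2 ** i) for i in range(num_blocks)]


# Three levels, matching the length of the default feat_channels list.
channels = fpn_channels(3)
```

Every transition layer halves `channel` before upsampling, so the routed feature has the same width as the base channel of the level it is concatenated into.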
...@@ -1209,13 +1209,11 @@ def matrix_nms(bboxes, ...@@ -1209,13 +1209,11 @@ def matrix_nms(bboxes,
use_gaussian, 'keep_top_k', keep_top_k, 'normalized', use_gaussian, 'keep_top_k', keep_top_k, 'normalized',
normalized) normalized)
out, index, rois_num = core.ops.matrix_nms(bboxes, scores, *attrs) out, index, rois_num = core.ops.matrix_nms(bboxes, scores, *attrs)
if return_index: if not return_index:
if return_rois_num: index = None
return out, index, rois_num if not return_rois_num:
return out, index rois_num = None
if return_rois_num: return out, rois_num, index
return out, rois_num
return out
else: else:
helper = LayerHelper('matrix_nms', **locals()) helper = LayerHelper('matrix_nms', **locals())
output = helper.create_variable_for_type_inference(dtype=bboxes.dtype) output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
...@@ -1242,13 +1240,11 @@ def matrix_nms(bboxes, ...@@ -1242,13 +1240,11 @@ def matrix_nms(bboxes,
outputs=outputs) outputs=outputs)
output.stop_gradient = True output.stop_gradient = True
if return_index: if not return_index:
if return_rois_num: index = None
return output, index, rois_num if not return_rois_num:
return output, index rois_num = None
if return_rois_num: return output, rois_num, index
return output, rois_num
return output
def bipartite_match(dist_matrix, def bipartite_match(dist_matrix,
......
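Review note on the `matrix_nms` return change: both branches now always return three values in the order `out, rois_num, index`, with `None` for anything not requested. The old code returned one to three values and placed `index` before `rois_num`. A minimal sketch of the new convention, using a hypothetical helper and placeholder strings instead of tensors:

```python
def pack_nms_outputs(out, index, rois_num, return_index, return_rois_num):
    """Hypothetical helper mirroring the new tail of matrix_nms."""
    if not return_index:
        index = None
    if not return_rois_num:
        rois_num = None
    # Fixed arity and fixed order, regardless of the flags.
    return out, rois_num, index


both = pack_nms_outputs("out", "index", "rois_num", True, True)
only_out = pack_nms_outputs("out", "index", "rois_num", False, False)
```

Callers written against the old variable-length return need to unpack three values and swap the positions of `index` and `rois_num`.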
...@@ -22,10 +22,12 @@ import math ...@@ -22,10 +22,12 @@ import math
def xywh2xyxy(box): def xywh2xyxy(box):
out = paddle.zeros_like(box) x, y, w, h = box
out[:, :, 0:2] = box[:, :, 0:2] - box[:, :, 2:4] / 2 x1 = x - w * 0.5
out[:, :, 2:4] = box[:, :, 0:2] + box[:, :, 2:4] / 2 y1 = y - h * 0.5
return out x2 = x + w * 0.5
y2 = y + h * 0.5
return [x1, y1, x2, y2]
def make_grid(h, w, dtype): def make_grid(h, w, dtype):
...@@ -37,27 +39,27 @@ def decode_yolo(box, anchor, downsample_ratio): ...@@ -37,27 +39,27 @@ def decode_yolo(box, anchor, downsample_ratio):
"""decode yolo box """decode yolo box
Args: Args:
box (Tensor): pred with the shape [b, h, w, na, 4] box (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
anchor (list): anchor with the shape [na, 2] anchor (list): anchor with the shape [na, 2]
downsample_ratio (int): downsample ratio, default 32 downsample_ratio (int): downsample ratio, default 32
scale (float): scale, default 1. scale (float): scale, default 1.
Return: Return:
box (Tensor): decoded box, with the shape [b, h, w, na, 4] box (list): decoded box, [x, y, w, h], all have the shape [b, na, h, w, 1]
""" """
h, w, na = box.shape[1:4] x, y, w, h = box
grid = make_grid(h, w, box.dtype).reshape((1, h, w, 1, 2)) na, grid_h, grid_w = x.shape[1:4]
box[:, :, :, :, 0:2] = box[:, :, :, :, :2] + grid grid = make_grid(grid_h, grid_w, x.dtype).reshape((1, 1, grid_h, grid_w, 2))
box[:, :, :, :, 0] = box[:, :, :, :, 0] / w x1 = (x + grid[:, :, :, :, 0:1]) / grid_w
box[:, :, :, :, 1] = box[:, :, :, :, 1] / h y1 = (y + grid[:, :, :, :, 1:2]) / grid_h
anchor = paddle.to_tensor(anchor) anchor = paddle.to_tensor(anchor)
anchor = paddle.cast(anchor, box.dtype) anchor = paddle.cast(anchor, x.dtype)
anchor = anchor.reshape((1, 1, 1, na, 2)) anchor = anchor.reshape((1, na, 1, 1, 2))
box[:, :, :, :, 2:4] = paddle.exp(box[:, :, :, :, 2:4]) * anchor w1 = paddle.exp(w) * anchor[:, :, :, :, 0:1] / (downsample_ratio * grid_w)
box[:, :, :, :, 2] = box[:, :, :, :, 2] / (downsample_ratio * w) h1 = paddle.exp(h) * anchor[:, :, :, :, 1:2] / (downsample_ratio * grid_h)
box[:, :, :, :, 3] = box[:, :, :, :, 3] / (downsample_ratio * h)
return box return [x1, y1, w1, h1]
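Review note on the rewritten `decode_yolo` and `xywh2xyxy`: both now operate on a list of four tensors and return new values instead of writing into slices. The scalar sketch below follows the same arithmetic for one grid cell. The function names are hypothetical, and it assumes that channel 0 of `make_grid` holds the column index, which this diff does not show.

```python
import math


def decode_cell(tx, ty, tw, th, col, row, anchor_w, anchor_h,
                grid_w, grid_h, downsample_ratio):
    """Decode one prediction to a normalised centre/size box."""
    x = (tx + col) / grid_w
    y = (ty + row) / grid_h
    # Anchors are in input pixels; downsample_ratio * grid size is the input size.
    w = math.exp(tw) * anchor_w / (downsample_ratio * grid_w)
    h = math.exp(th) * anchor_h / (downsample_ratio * grid_h)
    return x, y, w, h


def xywh_to_xyxy(x, y, w, h):
    """Centre/size to corner format, as in the new xywh2xyxy."""
    return x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5


# Hypothetical 8x8 grid at stride 32 with a 64x32 pixel anchor.
xywh = decode_cell(0.5, 0.5, 0.0, 0.0, col=3, row=2,
                   anchor_w=64, anchor_h=32,
                   grid_w=8, grid_h=8, downsample_ratio=32)
xyxy = xywh_to_xyxy(*xywh)
```

Returning fresh values is what lets `yolov3_loss` reuse `[x, y, w, h]` for the IoU loss, the IoU-aware loss, and the objectness loss without cloning.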
def iou_similarity(box1, box2, eps=1e-9): def iou_similarity(box1, box2, eps=1e-9):
...@@ -87,48 +89,56 @@ def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9): ...@@ -87,48 +89,56 @@ def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
"""calculate the iou of box1 and box2 """calculate the iou of box1 and box2
Args: Args:
box1 (Tensor): box1 with the shape (N, M, 4) box1 (list): [x1, y1, x2, y2], all have the shape [b, na, h, w, 1]
box2 (Tensor): box2 with the shape (N, M, 4) box2 (list): [x1, y1, x2, y2], all have the shape [b, na, h, w, 1]
giou (bool): whether use giou or not, default False giou (bool): whether use giou or not, default False
diou (bool): whether use diou or not, default False diou (bool): whether use diou or not, default False
ciou (bool): whether use ciou or not, default False ciou (bool): whether use ciou or not, default False
eps (float): epsilon to avoid divide by zero eps (float): epsilon to avoid divide by zero
Return: Return:
iou (Tensor): iou of box1 and box2, with the shape (N, M) iou (Tensor): iou of box1 and box2, with the shape [b, na, h, w, 1]
""" """
px1y1, px2y2 = box1[:, :, 0:2], box1[:, :, 2:4] px1, py1, px2, py2 = box1
gx1y1, gx2y2 = box2[:, :, 0:2], box2[:, :, 2:4] gx1, gy1, gx2, gy2 = box2
x1y1 = paddle.maximum(px1y1, gx1y1) x1 = paddle.maximum(px1, gx1)
x2y2 = paddle.minimum(px2y2, gx2y2) y1 = paddle.maximum(py1, gy1)
x2 = paddle.minimum(px2, gx2)
y2 = paddle.minimum(py2, gy2)
# clip each side first: two negative differences would otherwise multiply to a positive overlap
overlap = (x2 - x1).clip(0) * (y2 - y1).clip(0)
area1 = (px2 - px1) * (py2 - py1)
area1 = area1.clip(0)
area2 = (gx2 - gx1) * (gy2 - gy1)
area2 = area2.clip(0)
overlap = (x2y2 - x1y1).clip(0).prod(-1)
area1 = (px2y2 - px1y1).clip(0).prod(-1)
area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
union = area1 + area2 - overlap + eps union = area1 + area2 - overlap + eps
iou = overlap / union iou = overlap / union
if giou or ciou or diou: if giou or ciou or diou:
# convex w, h # convex w, h
cwh = paddle.maximum(px2y2, gx2y2) - paddle.minimum(px1y1, gx1y1) cw = paddle.maximum(px2, gx2) - paddle.minimum(px1, gx1)
if ciou or diou: ch = paddle.maximum(py2, gy2) - paddle.minimum(py1, gy1)
if giou:
c_area = cw * ch + eps
return iou - (c_area - union) / c_area
else:
# convex diagonal squared # convex diagonal squared
c2 = (cwh**2).sum(2) + eps c2 = cw**2 + ch**2 + eps
# center distance # center distance
rho2 = ((px1y1 + px2y2 - gx1y1 - gx2y2)**2).sum(2) / 4 rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4
if diou: if diou:
return iou - rho2 / c2 return iou - rho2 / c2
elif ciou: else:
wh1 = px2y2 - px1y1 w1, h1 = px2 - px1, py2 - py1 + eps
wh2 = gx2y2 - gx1y1 w2, h2 = gx2 - gx1, gy2 - gy1 + eps
w1, h1 = wh1[:, :, 0], wh1[:, :, 1] + eps delta = paddle.atan(w1 / h1) - paddle.atan(w2 / h2)
w2, h2 = wh2[:, :, 0], wh2[:, :, 1] + eps v = (4 / math.pi**2) * paddle.pow(delta, 2)
v = (4 / math.pi**2) * paddle.pow(
paddle.atan(w1 / h1) - paddle.atan(w2 / h2), 2)
alpha = v / (1 + eps - iou + v) alpha = v / (1 + eps - iou + v)
alpha.stop_gradient = True alpha.stop_gradient = True
return iou - (rho2 / c2 + v * alpha) return iou - (rho2 / c2 + v * alpha)
else:
c_area = cwh.prod(2) + eps
return iou - (c_area - union) / c_area
else: else:
return iou return iou
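Review note on the rewritten `bbox_iou`: the branches for GIoU, DIoU, and CIoU can be checked by hand with a scalar version. The sketch below follows the structure of the new code for one pair of corner-format boxes; the function name is hypothetical. It clips each side length before multiplying, as the removed `(x2y2 - x1y1).clip(0).prod(-1)` line did, so that boxes disjoint on both axes get zero overlap.

```python
import math


def bbox_iou_scalar(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
    """IoU family for two (x1, y1, x2, y2) boxes given as plain floats."""
    px1, py1, px2, py2 = box1
    gx1, gy1, gx2, gy2 = box2
    x1, y1 = max(px1, gx1), max(py1, gy1)
    x2, y2 = min(px2, gx2), min(py2, gy2)
    overlap = max(x2 - x1, 0.0) * max(y2 - y1, 0.0)
    area1 = max(px2 - px1, 0.0) * max(py2 - py1, 0.0)
    area2 = max(gx2 - gx1, 0.0) * max(gy2 - gy1, 0.0)
    union = area1 + area2 - overlap + eps
    iou = overlap / union
    if not (giou or diou or ciou):
        return iou
    # Width and height of the smallest box enclosing both inputs.
    cw = max(px2, gx2) - min(px1, gx1)
    ch = max(py2, gy2) - min(py1, gy1)
    if giou:
        c_area = cw * ch + eps
        return iou - (c_area - union) / c_area
    c2 = cw ** 2 + ch ** 2 + eps  # squared diagonal of the enclosing box
    rho2 = ((px1 + px2 - gx1 - gx2) ** 2 + (py1 + py2 - gy1 - gy2) ** 2) / 4
    if diou:
        return iou - rho2 / c2
    w1, h1 = px2 - px1, py2 - py1 + eps
    w2, h2 = gx2 - gx1, gy2 - gy1 + eps
    v = (4 / math.pi ** 2) * (math.atan(w1 / h1) - math.atan(w2 / h2)) ** 2
    alpha = v / (1 + eps - iou + v)
    return iou - (rho2 / c2 + v * alpha)


same = bbox_iou_scalar((0, 0, 1, 1), (0, 0, 1, 1))
half = bbox_iou_scalar((0, 0, 2, 2), (1, 0, 3, 2))
apart = bbox_iou_scalar((0, 0, 1, 1), (2, 2, 3, 3))
apart_giou = bbox_iou_scalar((0, 0, 1, 1), (2, 2, 3, 3), giou=True)
```

In the tensor version `alpha` has `stop_gradient = True`; the scalar sketch has no gradients, so that detail does not appear here.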
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import math import math
import copy
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
...@@ -202,7 +203,7 @@ class OptimizerBuilder(): ...@@ -202,7 +203,7 @@ class OptimizerBuilder():
def __call__(self, learning_rate, params=None): def __call__(self, learning_rate, params=None):
if self.clip_grad_by_norm is not None: if self.clip_grad_by_norm is not None:
grad_clip = nn.GradientClipByGlobalNorm( grad_clip = nn.ClipGradByGlobalNorm(
clip_norm=self.clip_grad_by_norm) clip_norm=self.clip_grad_by_norm)
else: else:
grad_clip = None grad_clip = None
...@@ -223,3 +224,38 @@ class OptimizerBuilder(): ...@@ -223,3 +224,38 @@ class OptimizerBuilder():
weight_decay=regularization, weight_decay=regularization,
grad_clip=grad_clip, grad_clip=grad_clip,
**optim_args) **optim_args)
class ModelEMA(object):
def __init__(self, decay, model, use_thres_step=False):
self.step = 0
self.decay = decay
self.state_dict = dict()
for k, v in model.state_dict().items():
self.state_dict[k] = paddle.zeros_like(v)
self.use_thres_step = use_thres_step
def update(self, model):
if self.use_thres_step:
decay = min(self.decay, (1 + self.step) / (10 + self.step))
else:
decay = self.decay
self._decay = decay
model_dict = model.state_dict()
for k, v in self.state_dict.items():
if '_mean' not in k and '_variance' not in k:
v = decay * v + (1 - decay) * model_dict[k]
v.stop_gradient = True
self.state_dict[k] = v
else:
self.state_dict[k] = model_dict[k]
self.step += 1
def apply(self):
state_dict = dict()
for k, v in self.state_dict.items():
if '_mean' not in k and '_variance' not in k:
v = v / (1 - self._decay**self.step)
v.stop_gradient = True
state_dict[k] = v
return state_dict
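Review note on `ModelEMA`: the shadow weights start at zero, and `apply` divides by `1 - decay**step` to undo the resulting bias. A scalar sketch with a constant decay shows why that correction is exact in the simplest case; the class name is hypothetical.

```python
class ScalarEMA:
    """Zero-initialised exponential moving average of a single float."""

    def __init__(self, decay):
        self.decay = decay
        self.step = 0
        self.value = 0.0  # plays the role of paddle.zeros_like(v)

    def update(self, x):
        self.value = self.decay * self.value + (1 - self.decay) * x
        self.step += 1

    def apply(self):
        # Bias correction; requires at least one update() call beforehand.
        return self.value / (1 - self.decay ** self.step)


ema = ScalarEMA(decay=0.9)
for _ in range(5):
    ema.update(2.0)
raw, corrected = ema.value, ema.apply()
```

After five updates with a constant input, the raw average is still well below the input while the corrected value recovers it. In `ModelEMA`, `_decay` is only assigned inside `update`, so `apply` likewise expects `update` to have run first. With `use_thres_step=True` the decay changes from step to step, so dividing by `1 - self._decay**self.step` is presumably only an approximation in that mode.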
...@@ -23,6 +23,7 @@ import time ...@@ -23,6 +23,7 @@ import time
import re import re
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn as nn
from .download import get_weights_path from .download import get_weights_path
from .logger import setup_logger from .logger import setup_logger
...@@ -169,7 +170,12 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch): ...@@ -169,7 +170,12 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch):
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
save_path = os.path.join(save_dir, save_name) save_path = os.path.join(save_dir, save_name)
if isinstance(model, nn.Layer):
paddle.save(model.state_dict(), save_path + ".pdparams") paddle.save(model.state_dict(), save_path + ".pdparams")
else:
assert isinstance(model,
dict), 'model is not an instance of nn.Layer or dict'
paddle.save(model, save_path + ".pdparams")
state_dict = optimizer.state_dict() state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch state_dict['last_epoch'] = last_epoch
paddle.save(state_dict, save_path + ".pdopt") paddle.save(state_dict, save_path + ".pdopt")
......