Unverified · Commit defcfec1 · Author: Kaipeng Deng · Committer: GitHub

Update yolov3 tiny (#1033)

* add PPYOLO
Co-authored-by: Nlongxiang <longxiang@baidu.com>
Parent 38d420bb
...@@ -83,7 +83,7 @@ ...@@ -83,7 +83,7 @@
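The figure below plots COCO mAP against inference speed (FPS) on a single Tesla V100 for representative models of each architecture and backbone. The figure below plots COCO mAP against inference speed (FPS) on a single Tesla V100 for representative models of each architecture and backbone.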
<div align="center"> <div align="center">
<img src="docs/images/map_fps.png" /> <img src="docs/images/map_fps.png" width=800 />
</div> </div>
**Notes:** **Notes:**
...@@ -92,6 +92,12 @@ ...@@ -92,6 +92,12 @@
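- PaddleDetection's enhanced `YOLOv3-ResNet50vd-DCN` reaches a COCO mAP 10.6 absolute percentage points higher than the original, with an inference speed of 61.3 FPS, about 70% faster than the original - PaddleDetection's enhanced `YOLOv3-ResNet50vd-DCN` reaches a COCO mAP 10.6 absolute percentage points higher than the original, with an inference speed of 61.3 FPS, about 70% faster than the original
- All models in the figure are available in the [Model Zoo](#模型库) - All models in the figure are available in the [Model Zoo](#模型库)
The figure below plots COCO mAP against inference speed (FPS) on a single Tesla V100 for PPYOLO, released by PaddleDetection with better accuracy and inference speed than YOLOv4, and for other state-of-the-art object detectors. PPYOLO reaches 45.2% mAP on the [COCO](http://cocodataset.org) test2019 dataset and 72.9 FPS for FP32 inference on a single V100. See [PPYOLO model](configs/ppyolo/README.md) for details.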
<div align="center">
<img src="docs/images/ppyolo_map_fps.png" width=600 />
</div>
## Documentation and Tutorials ## Documentation and Tutorials
### Getting Started ### Getting Started
...@@ -129,6 +135,7 @@ ...@@ -129,6 +135,7 @@
- [Anchor-free models](configs/anchor_free/README.md) - [Anchor-free models](configs/anchor_free/README.md)
- [Face detection models](docs/featured_model/FACE_DETECTION.md) - [Face detection models](docs/featured_model/FACE_DETECTION.md)
- [Enhanced YOLOv3 model](docs/featured_model/YOLOv3_ENHANCEMENT.md): COCO mAP up to 43.6%, compared with 33.0% in the original paper - [Enhanced YOLOv3 model](docs/featured_model/YOLOv3_ENHANCEMENT.md): COCO mAP up to 43.6%, compared with 33.0% in the original paper
- [PPYOLO model](configs/ppyolo/README.md): COCO mAP up to 45.2%, inference speed up to 72.9 FPS on a single Tesla V100
- [Pretrained models for pedestrian detection](docs/featured_model/CONTRIB_cn.md) - [Pretrained models for pedestrian detection](docs/featured_model/CONTRIB_cn.md)
- [Pretrained models for vehicle detection](docs/featured_model/CONTRIB_cn.md) - [Pretrained models for vehicle detection](docs/featured_model/CONTRIB_cn.md)
- [Objects365 2019 Challenge champion model](docs/featured_model/champion_model/CACascadeRCNN.md) - [Objects365 2019 Challenge champion model](docs/featured_model/champion_model/CACascadeRCNN.md)
......
...@@ -96,7 +96,7 @@ Advanced Features: ...@@ -96,7 +96,7 @@ Advanced Features:
The following figure shows the relationship between COCO mAP and FPS on Tesla V100 for representative models of each architecture and backbone. The following figure shows the relationship between COCO mAP and FPS on Tesla V100 for representative models of each architecture and backbone.
<div align="center"> <div align="center">
<img src="docs/images/map_fps.png" /> <img src="docs/images/map_fps.png" width=800 />
</div> </div>
**NOTE:** **NOTE:**
...@@ -105,6 +105,12 @@ The following is the relationship between COCO mAP and FPS on Tesla V100 of repr ...@@ -105,6 +105,12 @@ The following is the relationship between COCO mAP and FPS on Tesla V100 of repr
- The enhanced `YOLOv3-ResNet50vd-DCN` is 10.6 absolute percentage points higher than paper on COCO mAP, and inference speed is nearly 70% faster than the darknet framework - The enhanced `YOLOv3-ResNet50vd-DCN` is 10.6 absolute percentage points higher than paper on COCO mAP, and inference speed is nearly 70% faster than the darknet framework
- All these models can be obtained from the [Model Zoo](#Model-Zoo) - All these models can be obtained from the [Model Zoo](#Model-Zoo)
The following figure shows the relationship between COCO mAP and FPS on Tesla V100 for PPYOLO and other state-of-the-art object detectors. PPYOLO is faster and more accurate than YOLOv4, reaching 45.2% mAP(0.5:0.95) on the COCO test2019 dataset and 72.9 FPS on a single Tesla V100. Please refer to [PPYOLO](configs/ppyolo/README.md) for details.
<div align="center">
<img src="docs/images/ppyolo_map_fps.png" width=600 />
</div>
## Tutorials ## Tutorials
...@@ -146,6 +152,7 @@ The following is the relationship between COCO mAP and FPS on Tesla V100 of repr ...@@ -146,6 +152,7 @@ The following is the relationship between COCO mAP and FPS on Tesla V100 of repr
- [Pretrained models for pedestrian detection](docs/featured_model/CONTRIB.md) - [Pretrained models for pedestrian detection](docs/featured_model/CONTRIB.md)
- [Pretrained models for vehicle detection](docs/featured_model/CONTRIB.md) - [Pretrained models for vehicle detection](docs/featured_model/CONTRIB.md)
- [YOLOv3 enhanced model](docs/featured_model/YOLOv3_ENHANCEMENT.md): Compared to MAP of 33.0% in paper, enhanced YOLOv3 reaches the MAP of 43.6%, and inference speed is improved as well - [YOLOv3 enhanced model](docs/featured_model/YOLOv3_ENHANCEMENT.md): Compared to MAP of 33.0% in paper, enhanced YOLOv3 reaches the MAP of 43.6%, and inference speed is improved as well
- [PPYOLO](configs/ppyolo/README.md): PPYOLO reaches 45.2% mAP on the COCO dataset and 72.9 FPS on a single Tesla V100
- [Objects365 2019 Challenge champion model](docs/featured_model/champion_model/CACascadeRCNN.md) - [Objects365 2019 Challenge champion model](docs/featured_model/champion_model/CACascadeRCNN.md)
- [Best single model of Open Images 2019-Object Detection](docs/featured_model/champion_model/OIDV5_BASELINE_MODEL.md) - [Best single model of Open Images 2019-Object Detection](docs/featured_model/champion_model/OIDV5_BASELINE_MODEL.md)
- [Practical Server-side detection method](configs/rcnn_enhance/README_en.md): Inference speed on single V100 GPU can reach 20FPS when COCO mAP is 47.8%. - [Practical Server-side detection method](configs/rcnn_enhance/README_en.md): Inference speed on single V100 GPU can reach 20FPS when COCO mAP is 47.8%.
......
# PPYOLO Model
## Contents
- [Introduction](#introduction)
- [Model Zoo](#model-zoo)
- [Getting Started](#getting-started)
- [Future Work](#future-work)
- [Appendix](#appendix)
## Introduction
[PPYOLO](https://arxiv.org/abs/2007.12099) is PaddleDetection's optimized and improved version of YOLOv3. Its accuracy (COCO mAP) and inference speed both exceed those of [YOLOv4](https://arxiv.org/abs/2004.10934). It requires PaddlePaddle 1.8.4 (released in mid-August 2020) or a suitable [develop version](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-dev).
PPYOLO reaches 45.2% mAP on the [COCO](http://cocodataset.org) test2019 dataset, 72.9 FPS for FP32 inference on a single V100, and 155.6 FPS for FP16 inference on a V100 with TensorRT enabled.
<div align="center">
<img src="../../docs/images/ppyolo_map_fps.png" width=500 />
</div>
PPYOLO improves the accuracy and speed of YOLOv3 in the following ways:
- Better backbone: ResNet50vd-DCN
- Larger training batch size: 8 GPUs with batch_size=24 per GPU, with the learning rate and number of iterations adjusted accordingly
- [Drop Block](https://arxiv.org/abs/1810.12890)
- [Exponential Moving Average](https://www.investopedia.com/terms/e/ema.asp)
- [IoU Loss](https://arxiv.org/pdf/1902.09630.pdf)
- [Grid Sensitive](https://arxiv.org/abs/2004.10934)
- [Matrix NMS](https://arxiv.org/pdf/2003.10152.pdf)
- [CoordConv](https://arxiv.org/abs/1807.03247)
- [Spatial Pyramid Pooling](https://arxiv.org/abs/1406.4729)
- Better pretrained model
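The Exponential Moving Average item in the list above can be illustrated with a minimal sketch. It assumes the common formulation in which a shadow copy of each weight is updated as `shadow = decay * shadow + (1 - decay) * weight` after every training step. The function name `ema_update` and the plain-dict weight layout are hypothetical and not taken from PaddleDetection; the default decay mirrors the `ema_decay: 0.9998` setting in the PPYOLO config.

```python
def ema_update(shadow, weights, decay=0.9998):
    """Return a new shadow dict after one EMA step.

    `shadow` and `weights` map parameter names to floats
    (a hypothetical layout chosen for illustration).
    """
    return {
        name: decay * shadow[name] + (1.0 - decay) * value
        for name, value in weights.items()
    }


# Illustration: the shadow weight drifts toward the live weight.
shadow = {"w": 0.0}
for _ in range(3):
    shadow = ema_update(shadow, {"w": 1.0}, decay=0.5)
print(shadow["w"])  # 0.875 with decay=0.5
```

With a decay as close to 1 as 0.9998 the shadow weights change very slowly, which is why the averaged weights are usually smoother than the live ones.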
## Model Zoo
| Model | GPUs | Images per GPU | Backbone | Input size | Box AP | V100 FP32 (FPS) | V100 TensorRT FP16 (FPS) | Download | Config |
|:------------------------:|:-------:|:-------------:|:----------:| :-------:| :----: | :------------: | :---------------------: | :------: | :------: |
| YOLOv4(AlexeyAB) | - | - | CSPDarknet | 608 | 43.5 | 62 | 105.5 | [download](https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolov4/yolov4_csdarknet.yml) |
| YOLOv4(AlexeyAB) | - | - | CSPDarknet | 512 | 43.0 | 83 | 138.4 | [download](https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolov4/yolov4_csdarknet.yml) |
| YOLOv4(AlexeyAB) | - | - | CSPDarknet | 416 | 41.2 | 96 | 164.0 | [download](https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolov4/yolov4_csdarknet.yml) |
| YOLOv4(AlexeyAB) | - | - | CSPDarknet | 320 | 38.0 | 123 | 199.0 | [download](https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolov4/yolov4_csdarknet.yml) |
| PPYOLO | 8 | 24 | ResNet50vd | 608 | 45.2 | 72.9 | 155.6 | [download](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
| PPYOLO | 8 | 24 | ResNet50vd | 512 | 44.4 | 89.9 | 188.4 | [download](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
| PPYOLO | 8 | 24 | ResNet50vd | 416 | 42.5 | 109.1 | 215.4 | [download](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
| PPYOLO | 8 | 24 | ResNet50vd | 320 | 39.3 | 132.2 | 242.2 | [download](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
**Notes:**
- PPYOLO is trained on COCO train2017 and tested on test2019.
- PPYOLO is trained on 8 GPUs with a batch size of 24 per GPU. If you train with a different number of GPUs or batch size, adjust the learning rate and number of iterations as described in the [FAQ](../../docs/FAQ.md).
- PPYOLO inference speed is measured on a single V100 with batch size=1, using CUDA 10.2 and CUDNN 7.5.1; TensorRT inference speed is measured with TensorRT 5.1.2.2.
- PPYOLO inference speed figures are benchmark results obtained by exporting the model with `tools/export_model.py` and running it on the Paddle inference library through `deploy/python/infer.py` with the `--run_benchmark` flag. All figures exclude data preprocessing and output postprocessing (NMS), consistent with the method used by [YOLOv4(AlexeyAB)](https://github.com/AlexeyAB/darknet).
- Compared with FP32, the TensorRT FP16 speed test additionally excludes the time spent in `yolo_box` (bbox decoding); that is, it excludes data preprocessing, bbox decoding and NMS, consistent with the method used by [YOLOv4(AlexeyAB)](https://github.com/AlexeyAB/darknet).
- YOLOv4(AlexeyAB) accuracy and V100 FP32 inference speed are the single-V100 figures published in the [YOLOv4 GitHub repository](https://github.com/AlexeyAB/darknet). V100 TensorRT FP16 inference speed was measured on a single V100 using the tkDNN configuration from the [AlexeyAB/darknet](https://github.com/AlexeyAB/darknet) repository.
- The `Download` and `Config` links in the YOLOv4(AlexeyAB) rows point to the YOLOv4 model reproduced in PaddleDetection. Its evaluation accuracy is already aligned and fine-tuning is supported; training accuracy is still being aligned. See [PaddleDetection YOLOv4 model](../yolov4/README.md).
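The note on adjusting the learning rate and iteration count for a different GPU count or batch size can be sketched as follows. This assumes the widely used linear scaling rule (learning rate proportional to the total batch size, iteration count inversely proportional to it); that rule is an assumption here, and the linked FAQ remains the authoritative guidance. The function `scale_schedule` and its arguments are hypothetical, and the reference values in the example are borrowed from the large-batch config in this commit purely for illustration.

```python
def scale_schedule(ref_lr, ref_iters, ref_gpus, ref_batch_per_gpu,
                   gpus, batch_per_gpu):
    """Scale lr and iterations by the ratio of total batch sizes.

    Assumes the linear scaling rule; not taken from PaddleDetection.
    """
    ref_total = ref_gpus * ref_batch_per_gpu
    total = gpus * batch_per_gpu
    ratio = total / ref_total
    new_lr = ref_lr * ratio
    new_iters = int(round(ref_iters / ratio))
    return new_lr, new_iters


# Example: a reference of 8 GPUs x 24 images, moved to 4 GPUs x 12 images.
lr, iters = scale_schedule(0.01, 250000, 8, 24, gpus=4, batch_per_gpu=12)
print(lr, iters)
```

Milestones of a piecewise decay schedule and the warmup length would need the same treatment as the iteration count.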
## Getting Started
### 1. Training
Launch training on 8 GPUs with the following command (all commands below are assumed to run from the PaddleDetection root directory). Pass `--eval` to evaluate periodically during training.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tools/train.py -c configs/ppyolo/ppyolo.yml --eval
```
### 2. Evaluation
Evaluate the model on a single GPU with the following commands.
```bash
# use the weights released by PaddleDetection
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams
# use a checkpoint saved during training
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo.yml -o weights=output/ppyolo/best_model
```
### 3. Inference
Run inference on images with a single GPU using the commands below. Use `--infer_img` to specify an image path, or `--infer_dir` to specify a directory and run inference on every image in it.
```bash
# run inference on a single image
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/ppyolo/ppyolo.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams --infer_img=demo/000000014439_640x640.jpg
# run inference on all images in a directory
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/ppyolo/ppyolo.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams --infer_dir=demo
```
### 4. Inference deployment and benchmark
Deploying PPYOLO and benchmarking its inference require exporting the model with `tools/export_model.py` and then running it on the Paddle inference library, as follows.
```bash
# export the model; it is saved to output/ppyolo by default
python tools/export_model.py -c configs/ppyolo/ppyolo.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams
# run inference with the Paddle inference library
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output/ppyolo --image_file=demo/000000014439_640x640.jpg --use_gpu=True
```
The PPYOLO benchmark measures only the network itself, excluding data preprocessing and output postprocessing (NMS). When exporting the model, pass `--exclude_nms` to strip the NMS postprocessing from the model, then export and benchmark with the following commands.
```bash
# export the model with --exclude_nms to strip NMS; it is saved to output/ppyolo by default
python tools/export_model.py -c configs/ppyolo/ppyolo.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams --exclude_nms
# FP32 benchmark
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output/ppyolo --image_file=demo/000000014439_640x640.jpg --use_gpu=True --run_benchmark=True
# TensorRT FP16 benchmark
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output/ppyolo --image_file=demo/000000014439_640x640.jpg --use_gpu=True --run_benchmark=True --run_mode=trt_fp16
```
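The benchmark procedure can be pictured as a warmup phase followed by a timed loop whose total time is divided by the number of repeats, which matches the `ms = (t2 - t1) * 1000.0 / repeats` line and the `warmup=100, repeats=100` arguments in `deploy/python/infer.py` in this commit. The sketch below is a stand-alone illustration only: `benchmark` and `fake_inference` are hypothetical names, and the workload is a stand-in for a real predictor call.

```python
import time


def benchmark(run_once, warmup=100, repeats=100):
    """Return the average latency in milliseconds over `repeats` calls."""
    for _ in range(warmup):
        run_once()  # warmup calls are executed but not timed
    t1 = time.perf_counter()
    for _ in range(repeats):
        run_once()
    t2 = time.perf_counter()
    return (t2 - t1) * 1000.0 / repeats


calls = []


def fake_inference():
    # Stand-in workload; a real benchmark would run the predictor here.
    calls.append(sum(i * i for i in range(1000)))


ms = benchmark(fake_inference, warmup=10, repeats=20)
fps = 1000.0 / ms if ms > 0 else float("inf")
print("Inference: {:.4f} ms per call, about {:.1f} FPS".format(ms, fps))
```

Skipping postprocessing inside the timed call, as the `run_benchmark` flag does in this commit, keeps the measurement focused on the network forward pass.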
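## Future Work
1. Release the PPYOLO-tiny model
2. Release PPYOLO and PPYOLO-tiny models with more backbones
## Appendix
The table below shows the ablation results for the optimizations PPYOLO applies on top of YOLOv3.
| No. | Model | Box AP | Params (M) | FLOPs (G) | V100 FP32 FPS |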
| :--: | :--------------------------- | :----: | :-------: | :------: | :-----------: |
| A | YOLOv3-DarkNet53 | 38.9 | 59.13 | 65.52 | 58.2 |
| B | YOLOv3-ResNet50vd-DCN | 39.1 | 43.89 | 44.71 | 79.2 |
| C | B + LB + EMA + DropBlock | 41.4 | 43.89 | 44.71 | 79.2 |
| D | C + IoU Loss | 41.9 | 43.89 | 44.71 | 79.2 |
| E | D + IoU Aware | 42.5 | 43.90 | 44.71 | 74.9 |
| F | E + Grid Sensitive | 42.8 | 43.90 | 44.71 | 74.8 |
| G | F + Matrix NMS | 43.5 | 43.90 | 44.71 | 74.8 |
| H | G + CoordConv | 44.0 | 43.93 | 44.76 | 74.1 |
| I | H + SPP | 44.3 | 44.93 | 45.12 | 72.9 |
| J | I + Better ImageNet Pretrain | 44.6 | 44.93 | 45.12 | 72.9 |
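**Notes:**
- Accuracy and inference speed are measured with an input image size of 608
- Box AP is obtained by training on COCO train2017 and evaluating on val2017
- Inference speed is measured on a single V100 with batch size=1, using the benchmark method described above, with CUDA 10.2 and CUDNN 7.5.1
- The [YOLOv3-DarkNet53](../yolov3_darknet.yml) accuracy of 38.9 is for the YOLOv3 model optimized by PaddleDetection; see the [Model Zoo](../../docs/MODEL_ZOO_cn.md)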
architecture: YOLOv3
use_gpu: true
max_iters: 500000
log_smooth_window: 100
log_iter: 100
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998

YOLOv3:
  backbone: ResNet
  yolo_head: YOLOv3Head
  use_fine_grained_loss: true

ResNet:
  norm_type: sync_bn
  freeze_at: 0
  freeze_norm: false
  norm_decay: 0.
  depth: 50
  feature_maps: [3, 4, 5]
  variant: d
  dcn_v2_stages: [5]

YOLOv3Head:
  anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  anchors: [[10, 13], [16, 30], [33, 23],
            [30, 61], [62, 45], [59, 119],
            [116, 90], [156, 198], [373, 326]]
  norm_decay: 0.
  coord_conv: true
  iou_aware: true
  iou_aware_factor: 0.4
  scale_x_y: 1.05
  spp: true
  yolo_loss: YOLOv3Loss
  nms: MatrixNMS
  drop_block: true

YOLOv3Loss:
  batch_size: 24
  ignore_thresh: 0.7
  scale_x_y: 1.05
  label_smooth: false
  use_fine_grained_loss: true
  iou_loss: IouLoss
  iou_aware_loss: IouAwareLoss

IouLoss:
  loss_weight: 2.5
  max_height: 608
  max_width: 608

IouAwareLoss:
  loss_weight: 1.0
  max_height: 608
  max_width: 608

MatrixNMS:
  background_label: -1
  keep_top_k: 100
  normalized: false
  score_threshold: 0.01
  post_threshold: 0.01

LearningRate:
  base_lr: 0.00333
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 400000
    - 450000
  - !LinearWarmup
    start_factor: 0.
    steps: 4000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2

_READER_: 'ppyolo_reader.yml'
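
The `LearningRate` block in the config above combines a `LinearWarmup` scheduler with a `PiecewiseDecay` scheduler. A minimal sketch of the resulting schedule follows, assuming that warmup ramps the rate linearly from `start_factor * base_lr` to `base_lr` over `steps` iterations and that the rate is then multiplied by `gamma` at each milestone. The function `learning_rate` is hypothetical, and exact boundary behaviour in PaddleDetection may differ.

```python
def learning_rate(step, base_lr=0.00333, milestones=(400000, 450000),
                  gamma=0.1, warmup_steps=4000, start_factor=0.0):
    """Learning rate at an iteration: linear warmup, then piecewise decay."""
    if step < warmup_steps:
        alpha = step / warmup_steps
        factor = start_factor * (1.0 - alpha) + alpha
        return base_lr * factor
    # Count how many decay milestones have been reached.
    passed = sum(1 for m in milestones if step >= m)
    return base_lr * gamma ** passed


for step in (0, 2000, 4000, 400000, 450000):
    print(step, learning_rate(step))
```

The default arguments mirror the values in the config above; the other configs in this commit differ only in `base_lr` and the milestone positions.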
architecture: YOLOv3
use_gpu: true
max_iters: 250000
log_smooth_window: 100
log_iter: 100
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo_lb/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998

YOLOv3:
  backbone: ResNet
  yolo_head: YOLOv3Head
  use_fine_grained_loss: true

ResNet:
  norm_type: sync_bn
  freeze_at: 0
  freeze_norm: false
  norm_decay: 0.
  depth: 50
  feature_maps: [3, 4, 5]
  variant: d
  dcn_v2_stages: [5]

YOLOv3Head:
  anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  anchors: [[10, 13], [16, 30], [33, 23],
            [30, 61], [62, 45], [59, 119],
            [116, 90], [156, 198], [373, 326]]
  norm_decay: 0.
  coord_conv: true
  iou_aware: true
  iou_aware_factor: 0.4
  scale_x_y: 1.05
  spp: true
  yolo_loss: YOLOv3Loss
  nms: MatrixNMS
  drop_block: true

YOLOv3Loss:
  batch_size: 24
  ignore_thresh: 0.7
  scale_x_y: 1.05
  label_smooth: false
  use_fine_grained_loss: true
  iou_loss: IouLoss
  iou_aware_loss: IouAwareLoss

IouLoss:
  loss_weight: 2.5
  max_height: 608
  max_width: 608

IouAwareLoss:
  loss_weight: 1.0
  max_height: 608
  max_width: 608

MatrixNMS:
  background_label: -1
  keep_top_k: 100
  normalized: false
  score_threshold: 0.01
  post_threshold: 0.01

LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 150000
    - 200000
  - !LinearWarmup
    start_factor: 0.
    steps: 4000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2

_READER_: 'ppyolo_reader_lb.yml'
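
The `scale_x_y: 1.05` setting in the config above corresponds to the Grid Sensitive technique listed in the PPYOLO README. A minimal sketch follows, assuming the formulation described in the YOLOv4 paper, in which the predicted center offset inside a grid cell is `scale * sigmoid(t) - (scale - 1) / 2` so that box centers can reach the cell borders. The function `grid_sensitive_offset` is hypothetical and does not reproduce PaddleDetection's `yolo_box` operator.

```python
import math


def grid_sensitive_offset(t, scale_x_y=1.05):
    """Center offset within a grid cell for a raw prediction t."""
    sigmoid = 1.0 / (1.0 + math.exp(-t))
    return scale_x_y * sigmoid - 0.5 * (scale_x_y - 1.0)


# With scale 1.0 the offset stays strictly inside (0, 1);
# with scale 1.05 it can slightly overshoot both cell borders.
for t in (-10.0, 0.0, 10.0):
    print(t, grid_sensitive_offset(t, 1.0), grid_sensitive_offset(t, 1.05))
```

The point of the extra range is that a plain sigmoid only reaches 0 or 1 for very large raw predictions, which makes centers lying exactly on a cell border hard to predict.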
TrainReader:
  inputs_def:
    fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
    num_max_boxes: 50
  dataset:
    !COCODataSet
      image_dir: train2017
      anno_path: annotations/instances_train2017.json
      dataset_dir: dataset/coco
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
      with_mixup: True
    - !MixupImage
      alpha: 1.5
      beta: 1.5
    - !ColorDistort {}
    - !RandomExpand
      fill_value: [123.675, 116.28, 103.53]
    - !RandomCrop {}
    - !RandomFlipImage
      is_normalized: false
    - !NormalizeBox {}
    - !PadBox
      num_max_boxes: 50
    - !BboxXYXY2XYWH {}
  batch_transforms:
    - !RandomShape
      sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
      random_inter: True
    - !NormalizeImage
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
      is_scale: True
      is_channel_first: false
    - !Permute
      to_bgr: false
      channel_first: True
    # Gt2YoloTarget is only used when use_fine_grained_loss set as true,
    # this operator will be deleted automatically if use_fine_grained_loss
    # is set as false
    - !Gt2YoloTarget
      anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
      anchors: [[10, 13], [16, 30], [33, 23],
                [30, 61], [62, 45], [59, 119],
                [116, 90], [156, 198], [373, 326]]
      downsample_ratios: [32, 16, 8]
  batch_size: 24
  shuffle: true
  mixup_epoch: 25000
  drop_last: true
  worker_num: 8
  bufsize: 4
  use_process: true

EvalReader:
  inputs_def:
    fields: ['image', 'im_size', 'im_id']
    num_max_boxes: 50
  dataset:
    !COCODataSet
      image_dir: val2017
      anno_path: annotations/instances_val2017.json
      dataset_dir: dataset/coco
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
    - !ResizeImage
      target_size: 608
      interp: 2
    - !NormalizeImage
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
      is_scale: True
      is_channel_first: false
    - !PadBox
      num_max_boxes: 50
    - !Permute
      to_bgr: false
      channel_first: True
  batch_size: 8
  drop_empty: false
  worker_num: 8
  bufsize: 4

TestReader:
  inputs_def:
    image_shape: [3, 608, 608]
    fields: ['image', 'im_size', 'im_id']
  dataset:
    !ImageFolder
      anno_path: annotations/instances_val2017.json
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
    - !ResizeImage
      target_size: 608
      interp: 2
    - !NormalizeImage
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
      is_scale: True
      is_channel_first: false
    - !Permute
      to_bgr: false
      channel_first: True
  batch_size: 1
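
How `anchor_masks`, `anchors` and `downsample_ratios` in the `Gt2YoloTarget` entry of the reader config above fit together can be sketched as follows: each output level has a downsample ratio that fixes its grid size for a given input size, and a mask that selects which anchors that level uses. The helper `describe_levels` is hypothetical and only illustrates the configuration values; it is not the `Gt2YoloTarget` operator.

```python
ANCHORS = [[10, 13], [16, 30], [33, 23],
           [30, 61], [62, 45], [59, 119],
           [116, 90], [156, 198], [373, 326]]
ANCHOR_MASKS = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
DOWNSAMPLE_RATIOS = [32, 16, 8]


def describe_levels(input_size):
    """Return (grid_size, anchors) for each output level."""
    levels = []
    for mask, ratio in zip(ANCHOR_MASKS, DOWNSAMPLE_RATIOS):
        grid = input_size // ratio
        levels.append((grid, [ANCHORS[i] for i in mask]))
    return levels


for grid, anchors in describe_levels(608):
    print("{0}x{0} grid uses anchors {1}".format(grid, anchors))
```

Every entry in the `RandomShape` size list is a multiple of 32, so each training input size divides evenly at all three levels.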
TrainReader:
  inputs_def:
    fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
    num_max_boxes: 50
  dataset:
    !COCODataSet
      image_dir: train2017
      anno_path: annotations/instances_train2017.json
      dataset_dir: dataset/coco
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
      with_mixup: True
    - !MixupImage
      alpha: 1.5
      beta: 1.5
    - !ColorDistort {}
    - !RandomExpand
      fill_value: [123.675, 116.28, 103.53]
    - !RandomCrop {}
    - !RandomFlipImage
      is_normalized: false
    - !NormalizeBox {}
    - !PadBox
      num_max_boxes: 50
    - !BboxXYXY2XYWH {}
  batch_transforms:
    - !RandomShape
      sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
      random_inter: True
    - !NormalizeImage
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
      is_scale: True
      is_channel_first: false
    - !Permute
      to_bgr: false
      channel_first: True
    # Gt2YoloTarget is only used when use_fine_grained_loss set as true,
    # this operator will be deleted automatically if use_fine_grained_loss
    # is set as false
    - !Gt2YoloTarget
      anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
      anchors: [[10, 13], [16, 30], [33, 23],
                [30, 61], [62, 45], [59, 119],
                [116, 90], [156, 198], [373, 326]]
      downsample_ratios: [32, 16, 8]
  batch_size: 24
  shuffle: true
  mixup_epoch: 25000
  drop_last: true
  worker_num: 8
  bufsize: 4
  use_process: true

EvalReader:
  inputs_def:
    fields: ['image', 'im_size', 'im_id']
    num_max_boxes: 50
  dataset:
    !COCODataSet
      image_dir: val2017
      anno_path: annotations/instances_val2017.json
      dataset_dir: dataset/coco
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
    - !ResizeImage
      target_size: 608
      interp: 2
    - !NormalizeImage
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
      is_scale: True
      is_channel_first: false
    - !PadBox
      num_max_boxes: 50
    - !Permute
      to_bgr: false
      channel_first: True
  batch_size: 8
  drop_empty: false
  worker_num: 8
  bufsize: 4

TestReader:
  inputs_def:
    image_shape: [3, 608, 608]
    fields: ['image', 'im_size', 'im_id']
  dataset:
    !ImageFolder
      anno_path: annotations/instances_val2017.json
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
    - !ResizeImage
      target_size: 608
      interp: 2
    - !NormalizeImage
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
      is_scale: True
      is_channel_first: false
    - !Permute
      to_bgr: false
      channel_first: True
  batch_size: 1
architecture: YOLOv3
use_gpu: true
max_iters: 250000
log_smooth_window: 20
log_iter: 20
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_vd_pretrained.tar
weights: output/ppyolo_tiny/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998

YOLOv3:
  backbone: ResNet
  yolo_head: YOLOv3Head
  use_fine_grained_loss: true

ResNet:
  norm_type: sync_bn
  freeze_at: 0
  freeze_norm: false
  norm_decay: 0.
  depth: 18
  feature_maps: [4, 5]
  variant: d

YOLOv3Head:
  anchor_masks: [[3, 4, 5], [0, 1, 2]]
  anchors: [[10, 14], [23, 27], [37, 58],
            [81, 82], [135, 169], [344, 319]]
  norm_decay: 0.
  conv_block_num: 0
  iou_aware: true
  iou_aware_factor: 0.4
  scale_x_y: 1.05
  yolo_loss: YOLOv3Loss
  nms: MatrixNMS
  drop_block: true

YOLOv3Loss:
  batch_size: 32
  ignore_thresh: 0.7
  scale_x_y: 1.05
  label_smooth: false
  use_fine_grained_loss: true
  iou_loss: IouLoss
  iou_aware_loss: IouAwareLoss

IouLoss:
  loss_weight: 2.5
  max_height: 608
  max_width: 608

IouAwareLoss:
  loss_weight: 1.0
  max_height: 608
  max_width: 608

MatrixNMS:
  background_label: -1
  keep_top_k: 100
  normalized: false
  score_threshold: 0.01
  post_threshold: 0.01

LearningRate:
  base_lr: 0.004
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 150000
    - 200000
  - !LinearWarmup
    start_factor: 0.
    steps: 4000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2

_READER_: 'ppyolo_reader.yml'
TrainReader:
  inputs_def:
    fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
    num_max_boxes: 50
  dataset:
    !COCODataSet
      image_dir: train2017
      anno_path: annotations/instances_train2017.json
      dataset_dir: train_data/dataset/coco
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
      with_mixup: True
    - !MixupImage
      alpha: 1.5
      beta: 1.5
    - !ColorDistort {}
    - !RandomExpand
      fill_value: [123.675, 116.28, 103.53]
    - !RandomCrop {}
    - !RandomFlipImage
      is_normalized: false
    - !NormalizeBox {}
    - !PadBox
      num_max_boxes: 50
    - !BboxXYXY2XYWH {}
  batch_transforms:
    - !RandomShape
      sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
      random_inter: True
    - !NormalizeImage
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
      is_scale: True
      is_channel_first: false
    - !Permute
      to_bgr: false
      channel_first: True
    # Gt2YoloTarget is only used when use_fine_grained_loss set as true,
    # this operator will be deleted automatically if use_fine_grained_loss
    # is set as false
    - !Gt2YoloTarget
      anchor_masks: [[3, 4, 5], [0, 1, 2]]
      anchors: [[10, 14], [23, 27], [37, 58],
                [81, 82], [135, 169], [344, 319]]
      downsample_ratios: [32, 16]
  batch_size: 32
  shuffle: true
  mixup_epoch: 500
  drop_last: true
  worker_num: 16
  bufsize: 8
  use_process: true
...@@ -466,7 +466,12 @@ class Detector(): ...@@ -466,7 +466,12 @@ class Detector():
results['masks'] = np_masks results['masks'] = np_masks
return results return results
def predict(self, image, threshold=0.5, warmup=0, repeats=1): def predict(self,
image,
threshold=0.5,
warmup=0,
repeats=1,
run_benchmark=False):
''' '''
Args: Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2 image (str/np.ndarray): path of image/ np.ndarray read by cv2
...@@ -500,7 +505,7 @@ class Detector(): ...@@ -500,7 +505,7 @@ class Detector():
np_masks = np.array(outs[1]) np_masks = np.array(outs[1])
else: else:
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
for i in range(len(inputs)): for i in range(len(input_names)):
input_tensor = self.predictor.get_input_tensor(input_names[i]) input_tensor = self.predictor.get_input_tensor(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]]) input_tensor.copy_from_cpu(inputs[input_names[i]])
...@@ -528,12 +533,15 @@ class Detector(): ...@@ -528,12 +533,15 @@ class Detector():
ms = (t2 - t1) * 1000.0 / repeats ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms)) print("Inference: {} ms per batch image".format(ms))
if reduce(lambda x, y: x * y, np_boxes.shape) < 6: # do not perform postprocess in benchmark mode
print('[WARNNING] No object detected.') results = []
results = {'boxes': np.array([])} if not run_benchmark:
else: if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
results = self.postprocess( print('[WARNNING] No object detected.')
np_boxes, np_masks, im_info, threshold=threshold) results = {'boxes': np.array([])}
else:
results = self.postprocess(
np_boxes, np_masks, im_info, threshold=threshold)
return results return results
...@@ -543,7 +551,11 @@ def predict_image(): ...@@ -543,7 +551,11 @@ def predict_image():
FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode) FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
if FLAGS.run_benchmark: if FLAGS.run_benchmark:
detector.predict( detector.predict(
FLAGS.image_file, FLAGS.threshold, warmup=100, repeats=100) FLAGS.image_file,
FLAGS.threshold,
warmup=100,
repeats=100,
run_benchmark=True)
else: else:
results = detector.predict(FLAGS.image_file, FLAGS.threshold) results = detector.predict(FLAGS.image_file, FLAGS.threshold)
visualize( visualize(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle import fluid import numpy as np
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from ppdet.modeling.ops import MultiClassNMS, MultiClassSoftNMS from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from ppdet.core.workspace import register from ppdet.modeling.ops import MultiClassNMS, MultiClassSoftNMS, MatrixNMS
from ppdet.modeling.ops import DropBlock from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from .iou_aware import get_iou_aware_score from ppdet.core.workspace import register
try: from ppdet.modeling.ops import DropBlock
from collections.abc import Sequence from .iou_aware import get_iou_aware_score
except Exception: try:
from collections import Sequence from collections.abc import Sequence
from ppdet.utils.check import check_version except Exception:
from collections import Sequence
__all__ = ['YOLOv3Head', 'YOLOv4Head'] from ppdet.utils.check import check_version
__all__ = ['YOLOv3Head', 'YOLOv4Head']
@register
class YOLOv3Head(object):
""" @register
Head block for YOLOv3 network class YOLOv3Head(object):
"""
Args: Head block for YOLOv3 network
norm_decay (float): weight decay for normalization layer weights
num_classes (int): number of output classes Args:
anchors (list): anchors conv_block_num (int): number of conv block in each detection block
anchor_masks (list): anchor masks norm_decay (float): weight decay for normalization layer weights
nms (object): an instance of `MultiClassNMS` num_classes (int): number of output classes
""" anchors (list): anchors
__inject__ = ['yolo_loss', 'nms'] anchor_masks (list): anchor masks
__shared__ = ['num_classes', 'weight_prefix_name'] nms (object): an instance of `MultiClassNMS`
"""
def __init__(self, __inject__ = ['yolo_loss', 'nms']
norm_decay=0., __shared__ = ['num_classes', 'weight_prefix_name']
num_classes=80,
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], def __init__(self,
[59, 119], [116, 90], [156, 198], [373, 326]], conv_block_num=2,
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]], norm_decay=0.,
drop_block=False, num_classes=80,
iou_aware=False, anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
iou_aware_factor=0.4, [59, 119], [116, 90], [156, 198], [373, 326]],
block_size=3, anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
keep_prob=0.9, drop_block=False,
yolo_loss="YOLOv3Loss", coord_conv=False,
nms=MultiClassNMS( iou_aware=False,
score_threshold=0.01, iou_aware_factor=0.4,
nms_top_k=1000, block_size=3,
keep_top_k=100, keep_prob=0.9,
nms_threshold=0.45, yolo_loss="YOLOv3Loss",
background_label=-1).__dict__, spp=False,
weight_prefix_name='', nms=MultiClassNMS(
downsample=[32, 16, 8], score_threshold=0.01,
scale_x_y=1.0, nms_top_k=1000,
clip_bbox=True): keep_top_k=100,
check_version('2.0.0') nms_threshold=0.45,
self.norm_decay = norm_decay background_label=-1).__dict__,
self.num_classes = num_classes weight_prefix_name='',
self.anchor_masks = anchor_masks downsample=[32, 16, 8],
self._parse_anchors(anchors) scale_x_y=1.0,
self.yolo_loss = yolo_loss clip_bbox=True):
self.nms = nms self.conv_block_num = conv_block_num
self.prefix_name = weight_prefix_name self.norm_decay = norm_decay
self.drop_block = drop_block self.num_classes = num_classes
self.iou_aware = iou_aware self.anchor_masks = anchor_masks
self.iou_aware_factor = iou_aware_factor self._parse_anchors(anchors)
self.block_size = block_size self.yolo_loss = yolo_loss
self.keep_prob = keep_prob self.nms = nms
if isinstance(nms, dict): self.prefix_name = weight_prefix_name
self.nms = MultiClassNMS(**nms) self.drop_block = drop_block
self.downsample = downsample self.iou_aware = iou_aware
self.scale_x_y = scale_x_y self.coord_conv = coord_conv
self.clip_bbox = clip_bbox self.iou_aware_factor = iou_aware_factor
self.block_size = block_size
def _conv_bn(self, self.keep_prob = keep_prob
input, self.use_spp = spp
ch_out, if isinstance(nms, dict):
filter_size, self.nms = MultiClassNMS(**nms)
stride, self.downsample = downsample
padding, self.scale_x_y = scale_x_y
act='leaky', self.clip_bbox = clip_bbox
is_test=True,
name=None): def _create_tensor_from_numpy(self, numpy_array):
conv = fluid.layers.conv2d( paddle_array = fluid.layers.create_global_var(
input=input, shape=numpy_array.shape, value=0., dtype=numpy_array.dtype)
num_filters=ch_out, fluid.layers.assign(numpy_array, paddle_array)
filter_size=filter_size, return paddle_array
stride=stride,
padding=padding, def _add_coord(self, input, is_test=True):
act=None, if not self.coord_conv:
param_attr=ParamAttr(name=name + ".conv.weights"), return input
bias_attr=False)
# NOTE: here is used for exporting model for TensorRT inference,
bn_name = name + ".bn" # only support batch_size=1 for input shape should be fixed,
bn_param_attr = ParamAttr( # and we create tensor with fixed shape from numpy array
regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale') if is_test and input.shape[2] > 0 and input.shape[3] > 0:
bn_bias_attr = ParamAttr( batch_size = 1
regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset') grid_x = int(input.shape[3])
out = fluid.layers.batch_norm( grid_y = int(input.shape[2])
input=conv, idx_i = np.array(
act=None, [[i / (grid_x - 1) * 2.0 - 1 for i in range(grid_x)]],
param_attr=bn_param_attr, dtype='float32')
bias_attr=bn_bias_attr, gi_np = np.repeat(idx_i, grid_y, axis=0)
moving_mean_name=bn_name + '.mean', gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
moving_variance_name=bn_name + '.var') gi_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1])
if act == 'leaky': x_range = self._create_tensor_from_numpy(gi_np.astype(np.float32))
out = fluid.layers.leaky_relu(x=out, alpha=0.1) x_range.stop_gradient = True
return out y_range = self._create_tensor_from_numpy(
gi_np.transpose([0, 1, 3, 2]).astype(np.float32))
def _detection_block(self, input, channel, is_test=True, name=None): y_range.stop_gradient = True
assert channel % 2 == 0, \
"channel {} cannot be divided by 2 in detection block {}" \ # NOTE: in training mode, H and W is variable for random shape,
.format(channel, name) # implement add_coord with shape as Variable
else:
conv = input input_shape = fluid.layers.shape(input)
for j in range(2): b = input_shape[0]
conv = self._conv_bn( h = input_shape[2]
conv, w = input_shape[3]
channel,
filter_size=1, x_range = fluid.layers.range(0, w, 1, 'float32') / ((w - 1.) / 2.)
stride=1, x_range = x_range - 1.
padding=0, x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
is_test=is_test, x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
name='{}.{}.0'.format(name, j)) x_range.stop_gradient = True
conv = self._conv_bn( y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
conv, y_range.stop_gradient = True
channel * 2,
filter_size=3, return fluid.layers.concat([input, x_range, y_range], axis=1)
stride=1,
padding=1, def _conv_bn(self,
is_test=is_test, input,
name='{}.{}.1'.format(name, j)) ch_out,
if self.drop_block and j == 0 and channel != 512: filter_size,
conv = DropBlock( stride,
conv, padding,
block_size=self.block_size, act='leaky',
keep_prob=self.keep_prob, is_test=True,
is_test=is_test) name=None):
conv = fluid.layers.conv2d(
if self.drop_block and channel == 512: input=input,
conv = DropBlock( num_filters=ch_out,
conv, filter_size=filter_size,
block_size=self.block_size, stride=stride,
keep_prob=self.keep_prob, padding=padding,
is_test=is_test) act=None,
route = self._conv_bn( param_attr=ParamAttr(name=name + ".conv.weights"),
conv, bias_attr=False)
channel,
filter_size=1, bn_name = name + ".bn"
stride=1, bn_param_attr = ParamAttr(
padding=0, regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
is_test=is_test, bn_bias_attr = ParamAttr(
name='{}.2'.format(name)) regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
tip = self._conv_bn( out = fluid.layers.batch_norm(
route, input=conv,
channel * 2, act=None,
filter_size=3, is_test=is_test,
stride=1, param_attr=bn_param_attr,
padding=1, bias_attr=bn_bias_attr,
is_test=is_test, moving_mean_name=bn_name + '.mean',
name='{}.tip'.format(name)) moving_variance_name=bn_name + '.var')
return route, tip
if act == 'leaky':
def _upsample(self, input, scale=2, name=None): out = fluid.layers.leaky_relu(x=out, alpha=0.1)
out = fluid.layers.resize_nearest( return out
input=input, scale=float(scale), name=name)
return out def _spp_module(self, input, is_test=True, name=""):
output1 = input
def _parse_anchors(self, anchors): output2 = fluid.layers.pool2d(
""" input=output1,
Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors pool_size=5,
pool_stride=1,
""" pool_padding=2,
self.anchors = [] ceil_mode=False,
self.mask_anchors = [] pool_type='max')
output3 = fluid.layers.pool2d(
assert len(anchors) > 0, "ANCHORS not set." input=output1,
assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set." pool_size=9,
pool_stride=1,
for anchor in anchors: pool_padding=4,
assert len(anchor) == 2, "anchor {} len should be 2".format(anchor) ceil_mode=False,
self.anchors.extend(anchor) pool_type='max')
output4 = fluid.layers.pool2d(
anchor_num = len(anchors) input=output1,
for masks in self.anchor_masks: pool_size=13,
self.mask_anchors.append([]) pool_stride=1,
for mask in masks: pool_padding=6,
assert mask < anchor_num, "anchor mask index overflow" ceil_mode=False,
self.mask_anchors[-1].extend(anchors[mask]) pool_type='max')
output = fluid.layers.concat(
def _get_outputs(self, input, is_train=True): input=[output1, output2, output3, output4], axis=1)
""" return output
Get YOLOv3 head output
def _detection_block(self,
Args: input,
input (list): List of Variables, output of backbone stages channel,
is_train (bool): whether in train or test mode conv_block_num=2,
is_first=False,
Returns: is_test=True,
outputs (list): Variables of each output layer name=None):
""" assert channel % 2 == 0, \
"channel {} cannot be divided by 2 in detection block {}" \
outputs = [] .format(channel, name)
# get last out_layer_num blocks in reverse order conv = input
out_layer_num = len(self.anchor_masks) for j in range(conv_block_num):
blocks = input[-1:-out_layer_num - 1:-1] conv = self._add_coord(conv, is_test=is_test)
conv = self._conv_bn(
route = None conv,
for i, block in enumerate(blocks): channel,
if i > 0: # perform concat in first 2 detection_block filter_size=1,
block = fluid.layers.concat(input=[route, block], axis=1) stride=1,
route, tip = self._detection_block( padding=0,
block, is_test=is_test,
channel=512 // (2**i), name='{}.{}.0'.format(name, j))
is_test=(not is_train), if self.use_spp and is_first and j == 1:
name=self.prefix_name + "yolo_block.{}".format(i)) conv = self._spp_module(conv, is_test=is_test, name="spp")
conv = self._conv_bn(
# out channel number = mask_num * (5 + class_num) conv,
if self.iou_aware: 512,
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6) filter_size=1,
else: stride=1,
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) padding=0,
with fluid.name_scope('yolo_output'): is_test=is_test,
block_out = fluid.layers.conv2d( name='{}.{}.spp.conv'.format(name, j))
input=tip, conv = self._conv_bn(
num_filters=num_filters, conv,
filter_size=1, channel * 2,
stride=1, filter_size=3,
padding=0, stride=1,
act=None, padding=1,
param_attr=ParamAttr( is_test=is_test,
name=self.prefix_name + name='{}.{}.1'.format(name, j))
"yolo_output.{}.conv.weights".format(i)), if self.drop_block and j == 0 and not is_first:
bias_attr=ParamAttr( conv = DropBlock(
regularizer=L2Decay(0.), conv,
name=self.prefix_name + block_size=self.block_size,
"yolo_output.{}.conv.bias".format(i))) keep_prob=self.keep_prob,
outputs.append(block_out) is_test=is_test)
if i < len(blocks) - 1: if self.drop_block and is_first:
# do not perform upsample in the last detection_block conv = DropBlock(
route = self._conv_bn( conv,
input=route, block_size=self.block_size,
ch_out=256 // (2**i), keep_prob=self.keep_prob,
filter_size=1, is_test=is_test)
stride=1, conv = self._add_coord(conv, is_test=is_test)
padding=0, route = self._conv_bn(
is_test=(not is_train), conv,
name=self.prefix_name + "yolo_transition.{}".format(i)) channel,
# upsample filter_size=1,
route = self._upsample(route) stride=1,
padding=0,
return outputs is_test=is_test,
name='{}.2'.format(name))
def get_loss(self, input, gt_box, gt_label, gt_score, targets): new_route = self._add_coord(route, is_test=is_test)
""" tip = self._conv_bn(
Get final loss of network of YOLOv3. new_route,
channel * 2,
Args: filter_size=3,
input (list): List of Variables, output of backbone stages stride=1,
gt_box (Variable): The ground-truth bounding boxes. padding=1,
gt_label (Variable): The ground-truth class labels. is_test=is_test,
gt_score (Variable): The ground-truth bounding boxes mixup scores. name='{}.tip'.format(name))
targets ([Variables]): List of Variables, the targets for yolo return route, tip
loss calculation.
def _upsample(self, input, scale=2, name=None):
Returns: out = fluid.layers.resize_nearest(
loss (Variable): The loss Variable of YOLOv3 network. input=input, scale=float(scale), name=name)
return out
"""
outputs = self._get_outputs(input, is_train=True) def _parse_anchors(self, anchors):
"""
return self.yolo_loss(outputs, gt_box, gt_label, gt_score, targets, Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors
self.anchors, self.anchor_masks,
self.mask_anchors, self.num_classes, """
self.prefix_name) self.anchors = []
self.mask_anchors = []
def get_prediction(self, input, im_size):
""" assert len(anchors) > 0, "ANCHORS not set."
Get prediction result of YOLOv3 network assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
Args: for anchor in anchors:
input (list): List of Variables, output of backbone stages assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
im_size (Variable): Variable of size([h, w]) of each image self.anchors.extend(anchor)
Returns: anchor_num = len(anchors)
pred (Variable): The prediction result after non-max suppression.
self.mask_anchors.append([])
""" for mask in masks:
assert mask < anchor_num, "anchor mask index overflow"
outputs = self._get_outputs(input, is_train=False) self.mask_anchors[-1].extend(anchors[mask])
boxes = [] def _get_outputs(self, input, is_train=True):
scores = [] """
for i, output in enumerate(outputs): Get YOLOv3 head output
if self.iou_aware:
output = get_iou_aware_score(output, Args:
len(self.anchor_masks[i]), input (list): List of Variables, output of backbone stages
self.num_classes, is_train (bool): whether in train or test mode
self.iou_aware_factor)
scale_x_y = self.scale_x_y if not isinstance( Returns:
self.scale_x_y, Sequence) else self.scale_x_y[i] outputs (list): Variables of each output layer
box, score = fluid.layers.yolo_box( """
x=output,
img_size=im_size, outputs = []
anchors=self.mask_anchors[i],
class_num=self.num_classes, # get last out_layer_num blocks in reverse order
conf_thresh=self.nms.score_threshold, out_layer_num = len(self.anchor_masks)
downsample_ratio=self.downsample[i], blocks = input[-1:-out_layer_num - 1:-1]
name=self.prefix_name + "yolo_box" + str(i),
clip_bbox=self.clip_bbox, route = None
scale_x_y=scale_x_y) for i, block in enumerate(blocks):
boxes.append(box) if i > 0: # perform concat in first 2 detection_block
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1])) block = fluid.layers.concat(input=[route, block], axis=1)
route, tip = self._detection_block(
yolo_boxes = fluid.layers.concat(boxes, axis=1) block,
yolo_scores = fluid.layers.concat(scores, axis=2) channel=64 * (2**out_layer_num) // (2**i),
if type(self.nms) is MultiClassSoftNMS: is_first=i == 0,
yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1]) is_test=(not is_train),
pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores) conv_block_num=self.conv_block_num,
return {'bbox': pred} name=self.prefix_name + "yolo_block.{}".format(i))
# out channel number = mask_num * (5 + class_num)
@register if self.iou_aware:
class YOLOv4Head(YOLOv3Head): num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6)
""" else:
Head block for YOLOv4 network num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
with fluid.name_scope('yolo_output'):
Args: block_out = fluid.layers.conv2d(
anchors (list): anchors input=tip,
anchor_masks (list): anchor masks num_filters=num_filters,
nms (object): an instance of `MultiClassNMS` filter_size=1,
spp_stage (int): apply spp on which stage. stride=1,
num_classes (int): number of output classes padding=0,
downsample (list): downsample ratio for each yolo_head act=None,
scale_x_y (list): scale the center point of bbox at each stage param_attr=ParamAttr(
""" name=self.prefix_name +
__inject__ = ['nms', 'yolo_loss'] "yolo_output.{}.conv.weights".format(i)),
__shared__ = ['num_classes', 'weight_prefix_name'] bias_attr=ParamAttr(
regularizer=L2Decay(0.),
def __init__(self, name=self.prefix_name +
anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55], "yolo_output.{}.conv.bias".format(i)))
[72, 146], [142, 110], [192, 243], [459, 401]], outputs.append(block_out)
anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
nms=MultiClassNMS( if i < len(blocks) - 1:
score_threshold=0.01, # do not perform upsample in the last detection_block
nms_top_k=-1, route = self._conv_bn(
keep_top_k=-1, input=route,
nms_threshold=0.45, ch_out=256 // (2**i),
background_label=-1).__dict__, filter_size=1,
spp_stage=5, stride=1,
num_classes=80, padding=0,
weight_prefix_name='', is_test=(not is_train),
downsample=[8, 16, 32], name=self.prefix_name + "yolo_transition.{}".format(i))
scale_x_y=1.0, # upsample
yolo_loss="YOLOv3Loss", route = self._upsample(route)
iou_aware=False,
iou_aware_factor=0.4, return outputs
clip_bbox=False):
super(YOLOv4Head, self).__init__( def get_loss(self, input, gt_box, gt_label, gt_score, targets):
anchors=anchors, """
anchor_masks=anchor_masks, Get final loss of network of YOLOv3.
nms=nms,
num_classes=num_classes, Args:
weight_prefix_name=weight_prefix_name, input (list): List of Variables, output of backbone stages
downsample=downsample, gt_box (Variable): The ground-truth bounding boxes.
scale_x_y=scale_x_y, gt_label (Variable): The ground-truth class labels.
yolo_loss=yolo_loss, gt_score (Variable): The ground-truth bounding boxes mixup scores.
iou_aware=iou_aware, targets ([Variables]): List of Variables, the targets for yolo
iou_aware_factor=iou_aware_factor, loss calculation.
clip_bbox=clip_bbox)
self.spp_stage = spp_stage Returns:
loss (Variable): The loss Variable of YOLOv3 network.
def _upsample(self, input, scale=2, name=None):
out = fluid.layers.resize_nearest( """
input=input, scale=float(scale), name=name) outputs = self._get_outputs(input, is_train=True)
return out
return self.yolo_loss(outputs, gt_box, gt_label, gt_score, targets,
def max_pool(self, input, size): self.anchors, self.anchor_masks,
pad = [(size - 1) // 2] * 2 self.mask_anchors, self.num_classes,
return fluid.layers.pool2d(input, size, 'max', pool_padding=pad) self.prefix_name)
def spp(self, input): def get_prediction(self, input, im_size, exclude_nms=False):
branch_a = self.max_pool(input, 13) """
branch_b = self.max_pool(input, 9) Get prediction result of YOLOv3 network
branch_c = self.max_pool(input, 5)
out = fluid.layers.concat([branch_a, branch_b, branch_c, input], axis=1) Args:
return out input (list): List of Variables, output of backbone stages
im_size (Variable): Variable of size([h, w]) of each image
def stack_conv(self,
input, Returns:
ch_list=[512, 1024, 512], pred (Variable): The prediction result after non-max suppress.
filter_list=[1, 3, 1],
stride=1, """
name=None):
conv = input outputs = self._get_outputs(input, is_train=False)
for i, (ch_out, f_size) in enumerate(zip(ch_list, filter_list)):
padding = 1 if f_size == 3 else 0 boxes = []
conv = self._conv_bn( scores = []
conv, for i, output in enumerate(outputs):
ch_out=ch_out, if self.iou_aware:
filter_size=f_size, output = get_iou_aware_score(output,
stride=stride, len(self.anchor_masks[i]),
padding=padding, self.num_classes,
name='{}.{}'.format(name, i)) self.iou_aware_factor)
return conv scale_x_y = self.scale_x_y if not isinstance(
self.scale_x_y, Sequence) else self.scale_x_y[i]
def spp_module(self, input, name=None): box, score = fluid.layers.yolo_box(
conv = self.stack_conv(input, name=name + '.stack_conv.0') x=output,
spp_out = self.spp(conv) img_size=im_size,
conv = self.stack_conv(spp_out, name=name + '.stack_conv.1') anchors=self.mask_anchors[i],
return conv class_num=self.num_classes,
conf_thresh=self.nms.score_threshold,
def pan_module(self, input, filter_list, name=None): downsample_ratio=self.downsample[i],
for i in range(1, len(input)): name=self.prefix_name + "yolo_box" + str(i),
ch_out = input[i].shape[1] // 2 clip_bbox=self.clip_bbox,
conv_left = self._conv_bn( scale_x_y=scale_x_y)
input[i], boxes.append(box)
ch_out=ch_out, scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
filter_size=1,
stride=1, yolo_boxes = fluid.layers.concat(boxes, axis=1)
padding=0, yolo_scores = fluid.layers.concat(scores, axis=2)
name=name + '.{}.left'.format(i))
ch_out = input[i - 1].shape[1] // 2 # Only for benchmark, postprocess(NMS) is not needed
conv_right = self._conv_bn( if exclude_nms:
input[i - 1], return {'bbox': yolo_scores}
ch_out=ch_out,
filter_size=1, if type(self.nms) is MultiClassSoftNMS:
stride=1, yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1])
padding=0, pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
name=name + '.{}.right'.format(i)) return {'bbox': pred}
conv_right = self._upsample(conv_right)
pan_out = fluid.layers.concat([conv_left, conv_right], axis=1)
ch_list = [pan_out.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]] @register
input[i] = self.stack_conv( class YOLOv4Head(YOLOv3Head):
pan_out, """
ch_list=ch_list, Head block for YOLOv4 network
filter_list=filter_list,
name=name + '.stack_conv.{}'.format(i)) Args:
return input anchors (list): anchors
anchor_masks (list): anchor masks
def _get_outputs(self, input, is_train=True): nms (object): an instance of `MultiClassNMS`
outputs = [] spp_stage (int): apply spp on which stage.
filter_list = [1, 3, 1, 3, 1] num_classes (int): number of output classes
spp_stage = len(input) - self.spp_stage downsample (list): downsample ratio for each yolo_head
# get last out_layer_num blocks in reverse order scale_x_y (list): scale the center point of bbox at each stage
out_layer_num = len(self.anchor_masks) """
blocks = input[-1:-out_layer_num - 1:-1] __inject__ = ['nms', 'yolo_loss']
blocks[spp_stage] = self.spp_module( __shared__ = ['num_classes', 'weight_prefix_name']
blocks[spp_stage], name=self.prefix_name + "spp_module")
blocks = self.pan_module( def __init__(self,
blocks, anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
filter_list=filter_list, [72, 146], [142, 110], [192, 243], [459, 401]],
name=self.prefix_name + 'pan_module') anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
nms=MultiClassNMS(
# reverse order back to input score_threshold=0.01,
blocks = blocks[::-1] nms_top_k=-1,
keep_top_k=-1,
route = None nms_threshold=0.45,
for i, block in enumerate(blocks): background_label=-1).__dict__,
if i > 0: # perform concat in first 2 detection_block spp_stage=5,
route = self._conv_bn( num_classes=80,
route, weight_prefix_name='',
ch_out=route.shape[1] * 2, downsample=[8, 16, 32],
filter_size=3, scale_x_y=1.0,
stride=2, yolo_loss="YOLOv3Loss",
padding=1, iou_aware=False,
name=self.prefix_name + 'yolo_block.route.{}'.format(i)) iou_aware_factor=0.4,
block = fluid.layers.concat(input=[route, block], axis=1) clip_bbox=False):
ch_list = [block.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]] super(YOLOv4Head, self).__init__(
block = self.stack_conv( anchors=anchors,
block, anchor_masks=anchor_masks,
ch_list=ch_list, nms=nms,
filter_list=filter_list, num_classes=num_classes,
name=self.prefix_name + weight_prefix_name=weight_prefix_name,
'yolo_block.stack_conv.{}'.format(i)) downsample=downsample,
route = block scale_x_y=scale_x_y,
yolo_loss=yolo_loss,
block_out = self._conv_bn( iou_aware=iou_aware,
block, iou_aware_factor=iou_aware_factor,
ch_out=block.shape[1] * 2, clip_bbox=clip_bbox)
filter_size=3, self.spp_stage = spp_stage
stride=1,
padding=1, def _upsample(self, input, scale=2, name=None):
name=self.prefix_name + 'yolo_output.{}.conv.0'.format(i)) out = fluid.layers.resize_nearest(
input=input, scale=float(scale), name=name)
if self.iou_aware: return out
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6)
else: def max_pool(self, input, size):
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) pad = [(size - 1) // 2] * 2
block_out = fluid.layers.conv2d( return fluid.layers.pool2d(input, size, 'max', pool_padding=pad)
input=block_out,
num_filters=num_filters, def spp(self, input):
filter_size=1, branch_a = self.max_pool(input, 13)
stride=1, branch_b = self.max_pool(input, 9)
padding=0, branch_c = self.max_pool(input, 5)
act=None, out = fluid.layers.concat([branch_a, branch_b, branch_c, input], axis=1)
param_attr=ParamAttr(name=self.prefix_name + return out
"yolo_output.{}.conv.1.weights".format(i)),
bias_attr=ParamAttr( def stack_conv(self,
regularizer=L2Decay(0.), input,
name=self.prefix_name + ch_list=[512, 1024, 512],
"yolo_output.{}.conv.1.bias".format(i))) filter_list=[1, 3, 1],
outputs.append(block_out) stride=1,
name=None):
return outputs conv = input
for i, (ch_out, f_size) in enumerate(zip(ch_list, filter_list)):
padding = 1 if f_size == 3 else 0
conv = self._conv_bn(
conv,
ch_out=ch_out,
filter_size=f_size,
stride=stride,
padding=padding,
name='{}.{}'.format(name, i))
return conv
def spp_module(self, input, name=None):
conv = self.stack_conv(input, name=name + '.stack_conv.0')
spp_out = self.spp(conv)
conv = self.stack_conv(spp_out, name=name + '.stack_conv.1')
return conv
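The `spp` helper concatenates the input with three max-pooled copies, which only works if every pooled branch keeps the input's spatial size. A minimal sketch of the usual pooling output-size arithmetic, assuming stride 1 and floor rounding (the helper names `pool_out_size` and `spp_out_channels` are hypothetical and not part of PaddleDetection):

```python
def pool_out_size(in_size, kernel, stride=1, padding=0):
    """Output length of a pooling window along one axis (floor mode)."""
    return (in_size + 2 * padding - kernel) // stride + 1


def spp_out_channels(in_channels, kernels=(5, 9, 13)):
    """Channels after concatenating the input with one pooled copy per kernel."""
    return in_channels * (len(kernels) + 1)


if __name__ == "__main__":
    for k in (5, 9, 13):
        pad = (k - 1) // 2  # same padding rule as max_pool() in the head
        print(k, pad, pool_out_size(19, k, stride=1, padding=pad))
    print(spp_out_channels(512))
```

With an odd kernel and `(size - 1) // 2` padding the spatial size is unchanged, so the concat along `axis=1` multiplies the channel count by four.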
def pan_module(self, input, filter_list, name=None):
for i in range(1, len(input)):
ch_out = input[i].shape[1] // 2
conv_left = self._conv_bn(
input[i],
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0,
name=name + '.{}.left'.format(i))
ch_out = input[i - 1].shape[1] // 2
conv_right = self._conv_bn(
input[i - 1],
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0,
name=name + '.{}.right'.format(i))
conv_right = self._upsample(conv_right)
pan_out = fluid.layers.concat([conv_left, conv_right], axis=1)
ch_list = [pan_out.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]]
input[i] = self.stack_conv(
pan_out,
ch_list=ch_list,
filter_list=filter_list,
name=name + '.stack_conv.{}'.format(i))
return input
def _get_outputs(self, input, is_train=True):
outputs = []
filter_list = [1, 3, 1, 3, 1]
spp_stage = len(input) - self.spp_stage
# get last out_layer_num blocks in reverse order
out_layer_num = len(self.anchor_masks)
blocks = input[-1:-out_layer_num - 1:-1]
blocks[spp_stage] = self.spp_module(
blocks[spp_stage], name=self.prefix_name + "spp_module")
blocks = self.pan_module(
blocks,
filter_list=filter_list,
name=self.prefix_name + 'pan_module')
# reverse order back to input
blocks = blocks[::-1]
route = None
for i, block in enumerate(blocks):
if i > 0: # perform concat in first 2 detection_block
route = self._conv_bn(
route,
ch_out=route.shape[1] * 2,
filter_size=3,
stride=2,
padding=1,
name=self.prefix_name + 'yolo_block.route.{}'.format(i))
block = fluid.layers.concat(input=[route, block], axis=1)
ch_list = [block.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]]
block = self.stack_conv(
block,
ch_list=ch_list,
filter_list=filter_list,
name=self.prefix_name +
'yolo_block.stack_conv.{}'.format(i))
route = block
block_out = self._conv_bn(
block,
ch_out=block.shape[1] * 2,
filter_size=3,
stride=1,
padding=1,
name=self.prefix_name + 'yolo_output.{}.conv.0'.format(i))
if self.iou_aware:
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6)
else:
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
block_out = fluid.layers.conv2d(
input=block_out,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name=self.prefix_name +
"yolo_output.{}.conv.1.weights".format(i)),
bias_attr=ParamAttr(
regularizer=L2Decay(0.),
name=self.prefix_name +
"yolo_output.{}.conv.1.bias".format(i)))
outputs.append(block_out)
return outputs
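Both heads size their final 1x1 convolution as `len(anchor_masks[i]) * (num_classes + 5)`, or `+ 6` when `iou_aware` is set, and the updated `YOLOv3Head` derives each detection block's width from the number of output layers. The arithmetic can be sketched as below; the function names are hypothetical, and reading the extra channel as one IoU prediction per anchor is an assumption based on the `iou_aware` flag:

```python
def head_out_channels(anchor_masks, num_classes, iou_aware=False):
    """Filters of the last conv per output layer.

    Each anchor predicts 4 box terms + 1 objectness + num_classes scores,
    plus (assumed) one IoU term when iou_aware is enabled.
    """
    per_anchor = num_classes + (6 if iou_aware else 5)
    return [len(mask) * per_anchor for mask in anchor_masks]


def yolov3_block_channels(num_output_layers):
    """Mirror of `64 * (2**out_layer_num) // (2**i)` from the new code."""
    return [64 * (2 ** num_output_layers) // (2 ** i)
            for i in range(num_output_layers)]


if __name__ == "__main__":
    masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    print(head_out_channels(masks, 80))
    print(head_out_channels(masks, 80, iou_aware=True))
    print(yolov3_block_channels(3))
    print(yolov3_block_channels(2))
```

With three output layers the block widths match the previously hard-coded `512 // (2**i)`; with two layers (as a tiny variant would use) they start at 256 instead.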
@@ -251,7 +251,9 @@ class BlazeFace(object):
def eval(self, feed_vars): def eval(self, feed_vars):
return self.build(feed_vars, 'eval') return self.build(feed_vars, 'eval')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def is_bbox_normalized(self): def is_bbox_normalized(self):
......
@@ -434,5 +434,7 @@ class CascadeMaskRCNN(object):
return self.build_multi_scale(feed_vars, mask_branch) return self.build_multi_scale(feed_vars, mask_branch)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
return self.build(feed_vars, 'test') assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test')
@@ -331,5 +331,7 @@ class CascadeRCNN(object):
return self.build_multi_scale(feed_vars) return self.build_multi_scale(feed_vars)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
@@ -320,4 +320,6 @@ class CascadeRCNNClsAware(object):
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
@@ -136,5 +136,7 @@ class CornerNetSqueeze(object):
def eval(self, feed_vars): def eval(self, feed_vars):
return self.build(feed_vars, mode='test') return self.build(feed_vars, mode='test')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, mode='test') return self.build(feed_vars, mode='test')
@@ -146,5 +146,7 @@ class EfficientDet(object):
def eval(self, feed_vars): def eval(self, feed_vars):
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
@@ -183,7 +183,9 @@ class FaceBoxes(object):
def eval(self, feed_vars): def eval(self, feed_vars):
return self.build(feed_vars, 'eval') return self.build(feed_vars, 'eval')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def is_bbox_normalized(self): def is_bbox_normalized(self):
......
@@ -244,5 +244,7 @@ class FasterRCNN(object):
return self.build_multi_scale(feed_vars) return self.build_multi_scale(feed_vars)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
@@ -179,5 +179,7 @@ class FCOS(object):
def eval(self, feed_vars): def eval(self, feed_vars):
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
@@ -337,5 +337,7 @@ class MaskRCNN(object):
return self.build_multi_scale(feed_vars, mask_branch) return self.build_multi_scale(feed_vars, mask_branch)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
@@ -125,5 +125,7 @@ class RetinaNet(object):
def eval(self, feed_vars): def eval(self, feed_vars):
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
@@ -134,7 +134,9 @@ class SSD(object):
def eval(self, feed_vars): def eval(self, feed_vars):
return self.build(feed_vars, 'eval') return self.build(feed_vars, 'eval')
def test(self, feed_vars): def test(self, feed_vars, exclude_nms=False):
assert not exclude_nms, "exclude_nms for {} is not support currently".format(
self.__class__.__name__)
return self.build(feed_vars, 'test') return self.build(feed_vars, 'test')
def is_bbox_normalized(self): def is_bbox_normalized(self):
......
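The hunks above give every architecture the same `test(feed_vars, exclude_nms=False)` signature, but only the YOLOv3 architecture forwards the flag; the others reject it with an assertion. A minimal sketch of that pattern, using two made-up stand-in classes (`UnsupportedDetector` and `ForwardingDetector` are illustrative only, and the string "results" replace real Paddle variables):

```python
class UnsupportedDetector(object):
    """Stand-in for an architecture whose post-processing cannot be skipped."""

    def build(self, feed_vars, mode):
        return {"bbox": "nms({})".format(feed_vars["image"])}

    def test(self, feed_vars, exclude_nms=False):
        # Same guard as in the diff: fail early instead of silently ignoring.
        assert not exclude_nms, "exclude_nms for {} is not support currently".format(
            self.__class__.__name__)
        return self.build(feed_vars, "test")


class ForwardingDetector(object):
    """Stand-in for an architecture that passes the flag to its head."""

    def build(self, feed_vars, mode="train", exclude_nms=False):
        raw = "raw({})".format(feed_vars["image"])
        return {"bbox": raw if exclude_nms else "nms({})".format(raw)}

    def test(self, feed_vars, exclude_nms=False):
        return self.build(feed_vars, mode="test", exclude_nms=exclude_nms)


if __name__ == "__main__":
    feed = {"image": "x"}
    print(ForwardingDetector().test(feed, exclude_nms=True))
    try:
        UnsupportedDetector().test(feed, exclude_nms=True)
    except AssertionError as err:
        print("rejected:", err)
```

Note that `assert` statements are removed when Python runs with `-O`, so a guard written this way only protects unoptimized runs.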
@@ -49,7 +49,7 @@ class YOLOv3(object):
self.yolo_head = yolo_head self.yolo_head = yolo_head
self.use_fine_grained_loss = use_fine_grained_loss self.use_fine_grained_loss = use_fine_grained_loss
def build(self, feed_vars, mode='train'): def build(self, feed_vars, mode='train', exclude_nms=False):
im = feed_vars['image'] im = feed_vars['image']
mixed_precision_enabled = mixed_precision_global_state() is not None mixed_precision_enabled = mixed_precision_global_state() is not None
@@ -74,9 +74,9 @@ class YOLOv3(object):
gt_score = feed_vars['gt_score'] gt_score = feed_vars['gt_score']
# Get targets for split yolo loss calculation # Get targets for split yolo loss calculation
# YOLOv3 supports up to 3 output layers currently num_output_layer = len(self.yolo_head.anchor_masks)
targets = [] targets = []
for i in range(3): for i in range(num_output_layer):
k = 'target{}'.format(i) k = 'target{}'.format(i)
if k in feed_vars: if k in feed_vars:
targets.append(feed_vars[k]) targets.append(feed_vars[k])
@@ -88,7 +88,9 @@ class YOLOv3(object):
return loss return loss
else: else:
im_size = feed_vars['im_size'] im_size = feed_vars['im_size']
return self.yolo_head.get_prediction(body_feats, im_size) # exclude_nms only for benchmark, postprocess(NMS) is not needed
return self.yolo_head.get_prediction(
body_feats, im_size, exclude_nms=exclude_nms)
def _inputs_def(self, image_shape, num_max_boxes): def _inputs_def(self, image_shape, num_max_boxes):
im_shape = [None] + image_shape im_shape = [None] + image_shape
@@ -106,11 +108,10 @@ class YOLOv3(object):
if self.use_fine_grained_loss: if self.use_fine_grained_loss:
# yapf: disable # yapf: disable
targets_def = { num_output_layer = len(self.yolo_head.anchor_masks)
'target0': {'shape': [None, 3, 86, 19, 19], 'dtype': 'float32', 'lod_level': 0}, targets_def = {}
'target1': {'shape': [None, 3, 86, 38, 38], 'dtype': 'float32', 'lod_level': 0}, for i in range(num_output_layer):
'target2': {'shape': [None, 3, 86, 76, 76], 'dtype': 'float32', 'lod_level': 0}, targets_def['target{}'.format(i)] = {'shape': [None, 3, None, None, None], 'dtype': 'float32', 'lod_level': 0}
}
# yapf: enable # yapf: enable
downsample = 32 downsample = 32
@@ -139,7 +140,9 @@ class YOLOv3(object):
# will be disabled for YOLOv3 architecture do not calculate loss in # will be disabled for YOLOv3 architecture do not calculate loss in
# eval/infer mode. # eval/infer mode.
if 'im_size' not in fields and self.use_fine_grained_loss: if 'im_size' not in fields and self.use_fine_grained_loss:
fields.extend(['target0', 'target1', 'target2']) num_output_layer = len(self.yolo_head.anchor_masks)
fields.extend(
['target{}'.format(i) for i in range(num_output_layer)])
feed_vars = OrderedDict([(key, fluid.data( feed_vars = OrderedDict([(key, fluid.data(
name=key, name=key,
shape=inputs_def[key]['shape'], shape=inputs_def[key]['shape'],
@@ -158,8 +161,8 @@ class YOLOv3(object):
     def eval(self, feed_vars):
         return self.build(feed_vars, mode='test')
 
-    def test(self, feed_vars):
-        return self.build(feed_vars, mode='test')
+    def test(self, feed_vars, exclude_nms=False):
+        return self.build(feed_vars, mode='test', exclude_nms=exclude_nms)
 
 
 @register
......
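The hunks above replace the three hard-coded targets with a count derived from `self.yolo_head.anchor_masks`, which is what lets a two-head model such as YOLOv3-tiny share the same code path. A minimal standalone sketch of that pattern follows; `build_targets_def` and the example mask lists are hypothetical illustrations, not part of the PaddleDetection API.

```python
def build_targets_def(anchor_masks, anchors_per_layer=3):
    """Return one feed definition per YOLO output layer.

    anchor_masks: one list of anchor indices per output layer
    (hypothetical example values are used below).
    """
    targets_def = {}
    for i in range(len(anchor_masks)):
        # Spatial dimensions are left as None so that any input size works.
        targets_def['target{}'.format(i)] = {
            'shape': [None, anchors_per_layer, None, None, None],
            'dtype': 'float32',
            'lod_level': 0,
        }
    return targets_def


# Assumed masks: two output layers (tiny-style) versus three output layers.
tiny_def = build_targets_def([[3, 4, 5], [0, 1, 2]])
full_def = build_targets_def([[6, 7, 8], [3, 4, 5], [0, 1, 2]])
print(sorted(tiny_def))
print(sorted(full_def))
```

Because the field names are generated from the same count in `build`, `_inputs_def` and `build_inputs`, the three places cannot drift apart when the number of heads changes.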
@@ -54,6 +54,7 @@ class IouAwareLoss(IouLoss):
                  anchors,
                  downsample_ratio,
                  batch_size,
+                 scale_x_y,
                  eps=1.e-10):
         '''
         Args:
''' '''
Args: Args:
@@ -67,9 +68,9 @@ class IouAwareLoss(IouLoss):
         '''
         pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
-                                    batch_size, False)
+                                    batch_size, False, scale_x_y, eps)
         gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
-                                  batch_size, True)
+                                  batch_size, True, scale_x_y, eps)
         iouk = self._iou(pred, gt, ioup, eps)
         iouk.stop_gradient = True
......
@@ -63,6 +63,7 @@ class IouLoss(object):
                  anchors,
                  downsample_ratio,
                  batch_size,
+                 scale_x_y=1.,
                  ioup=None,
                  eps=1.e-10):
         '''
''' '''
@@ -75,9 +76,9 @@ class IouLoss(object):
             eps (float): a small constant that keeps the denominator from being zero
         '''
         pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
-                                    batch_size, False)
+                                    batch_size, False, scale_x_y, eps)
         gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
-                                  batch_size, True)
+                                  batch_size, True, scale_x_y, eps)
         iouk = self._iou(pred, gt, ioup, eps)
         if self.loss_square:
             loss_iou = 1. - iouk * iouk
@@ -145,7 +146,7 @@ class IouLoss(object):
         return diou_term + ciou_term
 
     def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio,
-                        batch_size, is_gt):
+                        batch_size, is_gt, scale_x_y, eps):
         grid_x = int(self._MAX_WI / downsample_ratio)
         grid_y = int(self._MAX_HI / downsample_ratio)
         an_num = len(anchors) // 2
@@ -179,8 +180,11 @@ class IouLoss(object):
             cy.gradient = True
         else:
             dcx_sig = fluid.layers.sigmoid(dcx)
-            cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act
             dcy_sig = fluid.layers.sigmoid(dcy)
+            if (abs(scale_x_y - 1.0) > eps):
+                dcx_sig = scale_x_y * dcx_sig - 0.5 * (scale_x_y - 1)
+                dcy_sig = scale_x_y * dcy_sig - 0.5 * (scale_x_y - 1)
+            cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act
             cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act
 
         anchor_w_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 0]
......
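The `_bbox_transform` change rescales the sigmoid output as `scale_x_y * sigmoid - 0.5 * (scale_x_y - 1)`, which stretches the offset range symmetrically around 0.5 so a predicted center can reach (and slightly pass) the cell border. The sketch below reproduces only that arithmetic in plain Python; `decode_offset` is a hypothetical helper name and `1.05` is an assumed illustrative value, not a setting read from this diff.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def decode_offset(logit, scale_x_y=1.0, eps=1e-10):
    """Map a raw logit to an in-cell offset, optionally rescaled."""
    offset = sigmoid(logit)
    if abs(scale_x_y - 1.0) > eps:
        # Same expression as in the diff: stretch around the midpoint 0.5.
        offset = scale_x_y * offset - 0.5 * (scale_x_y - 1.0)
    return offset


center = decode_offset(0.0, scale_x_y=1.05)     # midpoint stays at 0.5
upper = decode_offset(50.0, scale_x_y=1.05)     # can exceed 1.0
lower = decode_offset(-50.0, scale_x_y=1.05)    # can drop below 0.0
unscaled_upper = decode_offset(50.0)            # plain sigmoid stays within [0, 1]
```

With `scale_x_y` equal to 1 the branch is skipped and the result is the plain sigmoid, which matches the previous behavior.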
@@ -92,7 +92,7 @@ class YOLOv3Loss(object):
         return {'loss': sum(losses)}
 
     def _get_fine_grained_loss(self, outputs, targets, gt_box, batch_size,
-                               num_classes, mask_anchors, ignore_thresh):
+                               num_classes, mask_anchors, ignore_thresh, eps=1.e-10):
         """
         Calculate fine grained YOLOv3 loss
@@ -136,12 +136,25 @@ class YOLOv3Loss(object):
             tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)
             tscale_tobj = tscale * tobj
-            loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
-                x, tx) * tscale_tobj
-            loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
-            loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
-                y, ty) * tscale_tobj
-            loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
+
+            scale_x_y = self.scale_x_y if not isinstance(
+                self.scale_x_y, Sequence) else self.scale_x_y[i]
+
+            if (abs(scale_x_y - 1.0) < eps):
+                loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
+                    x, tx) * tscale_tobj
+                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
+                loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
+                    y, ty) * tscale_tobj
+                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
+            else:
+                dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y - 1.0)
+                dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y - 1.0)
+                loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
+                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
+                loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
+                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
             # NOTE: we refined loss function of (w, h) as L1Loss
             loss_w = fluid.layers.abs(w - tw) * tscale_tobj
             loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
@@ -149,7 +162,7 @@ class YOLOv3Loss(object):
             loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
             if self._iou_loss is not None:
                 loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors,
-                                          downsample, self._batch_size)
+                                          downsample, self._batch_size, scale_x_y)
                 loss_iou = loss_iou * tscale_tobj
                 loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
                 loss_ious.append(fluid.layers.reduce_mean(loss_iou))
@@ -157,14 +170,12 @@ class YOLOv3Loss(object):
             if self._iou_aware_loss is not None:
                 loss_iou_aware = self._iou_aware_loss(
                     ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
-                    self._batch_size)
+                    self._batch_size, scale_x_y)
                 loss_iou_aware = loss_iou_aware * tobj
                 loss_iou_aware = fluid.layers.reduce_sum(
                     loss_iou_aware, dim=[1, 2, 3])
                 loss_iou_awares.append(fluid.layers.reduce_mean(loss_iou_aware))
-            scale_x_y = self.scale_x_y if not isinstance(
-                self.scale_x_y, Sequence) else self.scale_x_y[i]
             loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
                 output, obj, tobj, gt_box, self._batch_size, anchors,
                 num_classes, downsample, self._ignore_thresh, scale_x_y)
...@@ -293,7 +304,7 @@ class YOLOv3Loss(object): ...@@ -293,7 +304,7 @@ class YOLOv3Loss(object):
downsample_ratio=downsample, downsample_ratio=downsample,
clip_bbox=False, clip_bbox=False,
scale_x_y=scale_x_y) scale_x_y=scale_x_y)
# 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
# and gt bbox in each sample # and gt bbox in each sample
if batch_size > 1: if batch_size > 1:
...@@ -322,17 +333,17 @@ class YOLOv3Loss(object): ...@@ -322,17 +333,17 @@ class YOLOv3Loss(object):
pred = fluid.layers.squeeze(pred, axes=[0]) pred = fluid.layers.squeeze(pred, axes=[0])
gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0])) gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
ious.append(fluid.layers.iou_similarity(pred, gt)) ious.append(fluid.layers.iou_similarity(pred, gt))
iou = fluid.layers.stack(ious, axis=0) iou = fluid.layers.stack(ious, axis=0)
# 3. Get iou_mask by IoU between gt bbox and prediction bbox, # 3. Get iou_mask by IoU between gt bbox and prediction bbox,
# Get obj_mask by tobj(holds gt_score), calculate objectness loss # Get obj_mask by tobj(holds gt_score), calculate objectness loss
max_iou = fluid.layers.reduce_max(iou, dim=-1) max_iou = fluid.layers.reduce_max(iou, dim=-1)
iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32") iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
if self.match_score: if self.match_score:
max_prob = fluid.layers.reduce_max(prob, dim=-1) max_prob = fluid.layers.reduce_max(prob, dim=-1)
iou_mask = iou_mask * fluid.layers.cast( iou_mask = iou_mask * fluid.layers.cast(
max_prob <= 0.25, dtype="float32") max_prob <= 0.25, dtype="float32")
output_shape = fluid.layers.shape(output) output_shape = fluid.layers.shape(output)
an_num = len(anchors) // 2 an_num = len(anchors) // 2
iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2], iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
......
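In `_get_fine_grained_loss`, the x/y term keeps sigmoid cross-entropy when `scale_x_y` is 1 and switches to an L1 distance on the rescaled sigmoid otherwise, presumably because the rescaled value can leave the [0, 1] range that cross-entropy targets assume. The scalar sketch below mirrors that branch; `xy_loss` and `bce_with_logits` are hypothetical helpers written for illustration, and the real code operates on tensors and reduces over dimensions 1 to 3.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def bce_with_logits(logit, target):
    # Numerically stable binary cross-entropy computed from a raw logit.
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def xy_loss(logit, target, weight, scale_x_y=1.0, eps=1e-10):
    """Per-element x (or y) loss, following the branch in the diff."""
    if abs(scale_x_y - 1.0) < eps:
        # Unscaled case: cross-entropy between sigmoid(logit) and target.
        return bce_with_logits(logit, target) * weight
    # Scaled case: L1 distance between the rescaled offset and target.
    decoded = scale_x_y * sigmoid(logit) - 0.5 * (scale_x_y - 1.0)
    return abs(decoded - target) * weight


plain = xy_loss(0.0, 0.5, 1.0)                    # cross-entropy branch
scaled = xy_loss(0.0, 0.5, 1.0, scale_x_y=1.05)   # L1 branch, assumed scale
```

The weight argument plays the role of `tscale_tobj` in the diff, so cells without an assigned ground-truth box contribute nothing in either branch.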
@@ -526,7 +526,7 @@ class MatrixNMS(object):
                  gaussian_sigma=2.,
                  normalized=False,
                  background_label=0):
-        super(MultiClassNMS, self).__init__()
+        super(MatrixNMS, self).__init__()
         self.score_threshold = score_threshold
         self.post_threshold = post_threshold
         self.nms_top_k = nms_top_k
......
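The one-line `MatrixNMS` fix matters because two-argument `super()` requires the instance to belong to the class that is passed in. The sketch below uses stand-in classes (hypothetical names, assuming the two real classes are unrelated subclasses of `object`, as the hunk header suggests) to show that the old call raises `TypeError` at construction time.

```python
class MultiClassNMS(object):
    def __init__(self):
        super(MultiClassNMS, self).__init__()


class BrokenMatrixNMS(object):
    def __init__(self):
        # self is not a MultiClassNMS instance, so super() rejects it.
        super(MultiClassNMS, self).__init__()


class FixedMatrixNMS(object):
    def __init__(self):
        # Pass the class that is actually being defined.
        super(FixedMatrixNMS, self).__init__()


def constructs_ok(cls):
    """Return True if cls() succeeds, False if it raises TypeError."""
    try:
        cls()
        return True
    except TypeError:
        return False


broken_ok = constructs_ok(BrokenMatrixNMS)
fixed_ok = constructs_ok(FixedMatrixNMS)
```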
@@ -196,7 +196,8 @@ def main():
             inputs_def = cfg['TestReader']['inputs_def']
             inputs_def['use_dataloader'] = False
             feed_vars, _ = model.build_inputs(**inputs_def)
-            test_fetches = model.test(feed_vars)
+            # in exclude_nms mode, the NMS postprocess is excluded from the exported model
+            test_fetches = model.test(feed_vars, exclude_nms=FLAGS.exclude_nms)
     infer_prog = infer_prog.clone(True)
     check_py_func(infer_prog)
@@ -214,6 +215,11 @@ if __name__ == '__main__':
         type=str,
         default="output",
         help="Directory for storing the output model files.")
+    parser.add_argument(
+        "--exclude_nms",
+        action='store_true',
+        default=False,
+        help="Whether prune NMS for benchmark")
     FLAGS = parser.parse_args()
     main()
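The new `--exclude_nms` option is a `store_true` flag, so it is `False` unless it appears on the command line and it takes no value. A minimal sketch of how such a flag parses, using only the standard `argparse` module; the parser here is a standalone stand-in, not the export script's real parser with its other arguments.

```python
import argparse

# Standalone parser that declares only the flag added in this diff.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--exclude_nms",
    action="store_true",
    default=False,
    help="Whether prune NMS for benchmark")

# Passing explicit argument lists keeps the example independent of sys.argv.
default_flags = parser.parse_args([])
bench_flags = parser.parse_args(["--exclude_nms"])
print(default_flags.exclude_nms, bench_flags.exclude_nms)
```

In the diff, the parsed value is forwarded as `model.test(feed_vars, exclude_nms=FLAGS.exclude_nms)`, so the default keeps NMS in the exported model.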