未验证 提交 6b76b6fc 编写于 作者: W wangxinxin08 提交者: GitHub

[Dygraph]yolo series (#2148)

* add ppyolo r18vd mbv3, yolov3 r50vd

* modify TestReader of ppyolo mbv3 r18vd

* add clip to avoid nan, modify ema to apply ema on bn mean and bn var

* fix code resulting in nan

* add yolov3 r50vd dcn configs

* finish yolo_series and fix some problems

* hide --bias flag and modify docs

* modify --bias and fix deploy/python/infer
上级 83f11ba0
......@@ -37,9 +37,18 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods:
### PP-YOLO
| Model | GPU number | images/GPU | backbone | input shape | Box AP<sup>val</sup> | Box AP<sup>test</sup> | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
|:------------------------:|:----------:|:----------:|:----------:| :----------:| :------------------: | :-------------------: | :------------: | :---------------------: | :------: | :-----: |
|:------------------------:|:-------:|:-------------:|:----------:| :-------:| :------------------: | :-------------------: | :------------: | :---------------------: | :------: | :------: |
| PP-YOLO | 8 | 24 | ResNet50vd | 608 | 44.8 | 45.2 | 72.9 | 155.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO | 8 | 24 | ResNet50vd | 512 | 43.9 | 44.4 | 89.9 | 188.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO | 8 | 24 | ResNet50vd | 416 | 42.1 | 42.5 | 109.1 | 215.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO | 8 | 24 | ResNet50vd | 320 | 38.9 | 39.3 | 132.2 | 242.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 608 | 45.3 | 45.9 | 72.9 | 155.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 512 | 44.4 | 45.0 | 89.9 | 188.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 416 | 42.7 | 43.2 | 109.1 | 215.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 320 | 39.5 | 40.1 | 132.2 | 242.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 512 | 29.3 | 29.5 | 357.1 | 657.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r18vd_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r18vd_coco.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 416 | 28.6 | 28.9 | 409.8 | 719.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r18vd_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r18vd_coco.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 320 | 26.2 | 26.4 | 480.7 | 763.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r18vd_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r18vd_coco.yml) |
**Notes:**
......@@ -49,6 +58,29 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods:
- PP-YOLO FP32 inference speed testing uses inference model exported by `tools/export_model.py` and benchmarked by running `depoly/python/infer.py` with `--run_benchmark`. All testing results do not contains the time cost of data reading and post-processing(NMS), which is same as [YOLOv4(AlexyAB)](https://github.com/AlexeyAB/darknet) in testing method.
- TensorRT FP16 inference speed testing exclude the time cost of bounding-box decoding(`yolo_box`) part comparing with FP32 testing above, which means that data reading, bounding-box decoding and post-processing(NMS) is excluded(test method same as [YOLOv4(AlexyAB)](https://github.com/AlexeyAB/darknet) too)
### PP-YOLO for mobile
| Model | GPU number | images/GPU | Model Size | input shape | Box AP<sup>val</sup> | Box AP50<sup>val</sup> | Kirin 990 1xCore(FPS) | download | config |
|:----------------------------:|:-------:|:-------------:|:----------:| :-------:| :------------------: | :--------------------: | :--------------------: | :------: | :------: |
| PP-YOLO_MobileNetV3_large | 4 | 32 | 28MB | 320 | 23.2 | 42.6 | 14.1 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_mbv3_large_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_mbv3_large_coco.yml) |
| PP-YOLO_MobileNetV3_small | 4 | 32 | 16MB | 320 | 17.2 | 33.8 | 21.5 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_mbv3_small_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_mbv3_small_coco.yml) |
**Notes:**
- PP-YOLO_MobileNetV3 is trained on COCO train2017 datast and evaluated on val2017 dataset,Box AP<sup>val</sup> is evaluation results of `mAP(IoU=0.5:0.95)`, Box AP<sup>val</sup> is evaluation results of `mAP(IoU=0.5)`.
- PP-YOLO_MobileNetV3 used 4 GPUs for training and mini-batch size as 32 on each GPU, if GPU number and mini-batch size is changed, learning rate and iteration times should be adjusted according [FAQ](../../../docs/FAQ.md).
- PP-YOLO_MobileNetV3 inference speed is tested on Kirin 990 with 1 thread.
### PP-YOLO on Pascal VOC
PP-YOLO trained on Pascal VOC dataset as follows:
| Model | GPU number | images/GPU | backbone | input shape | Box AP50<sup>val</sup> | download | config |
|:------------------:|:----------:|:----------:|:----------:| :----------:| :--------------------: | :------: | :-----: |
| PP-YOLO | 8 | 12 | ResNet50vd | 608 | 84.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_voc.yml) |
| PP-YOLO | 8 | 12 | ResNet50vd | 416 | 84.3 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_voc.yml) |
| PP-YOLO | 8 | 12 | ResNet50vd | 320 | 82.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_voc.yml) |
## Getting Start
### 1. Training
......
......@@ -39,7 +39,16 @@ PP-YOLO从如下方面优化和提升YOLOv3模型的精度和速度:
| 模型 | GPU个数 | 每GPU图片个数 | 骨干网络 | 输入尺寸 | Box AP<sup>val</sup> | Box AP<sup>test</sup> | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | 模型下载 | 配置文件 |
|:------------------------:|:-------:|:-------------:|:----------:| :-------:| :------------------: | :-------------------: | :------------: | :---------------------: | :------: | :------: |
| PP-YOLO | 8 | 24 | ResNet50vd | 608 | 44.8 | 45.2 | 72.9 | 155.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO | 8 | 24 | ResNet50vd | 512 | 43.9 | 44.4 | 89.9 | 188.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO | 8 | 24 | ResNet50vd | 416 | 42.1 | 42.5 | 109.1 | 215.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO | 8 | 24 | ResNet50vd | 320 | 38.9 | 39.3 | 132.2 | 242.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 608 | 45.3 | 45.9 | 72.9 | 155.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 512 | 44.4 | 45.0 | 89.9 | 188.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 416 | 42.7 | 43.2 | 109.1 | 215.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 320 | 39.5 | 40.1 | 132.2 | 242.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_2x_coco.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 512 | 29.3 | 29.5 | 357.1 | 657.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r18vd_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r18vd_coco.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 416 | 28.6 | 28.9 | 409.8 | 719.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r18vd_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r18vd_coco.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 320 | 26.2 | 26.4 | 480.7 | 763.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r18vd_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r18vd_coco.yml) |
**注意:**
......@@ -50,6 +59,27 @@ PP-YOLO从如下方面优化和提升YOLOv3模型的精度和速度:
- TensorRT FP16的速度测试相比于FP32去除了`yolo_box`(bbox解码)部分耗时,即不包含数据预处理,bbox解码和NMS(与[YOLOv4(AlexyAB)](https://github.com/AlexeyAB/darknet)测试方法一致)。
- PP-YOLO模型推理速度测试采用单卡V100,batch size=1进行测试,使用CUDA 10.2, CUDNN 7.5.1,TensorRT推理速度测试使用TensorRT 5.1.2.2。
### PP-YOLO 轻量级模型
| 模型 | GPU个数 | 每GPU图片个数 | 模型体积 | 输入尺寸 | Box AP<sup>val</sup> | Box AP50<sup>val</sup> | Kirin 990 1xCore (FPS) | 模型下载 | 配置文件 |
|:----------------------------:|:-------:|:-------------:|:----------:| :-------:| :------------------: | :--------------------: | :--------------------: | :------: | :------: |
| PP-YOLO_MobileNetV3_large | 4 | 32 | 28MB | 320 | 23.2 | 42.6 | 14.1 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_mbv3_large_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_mbv3_large_coco.yml) |
| PP-YOLO_MobileNetV3_small | 4 | 32 | 16MB | 320 | 17.2 | 33.8 | 21.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_mbv3_small_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_mbv3_small_coco.yml) |
- PP-YOLO_MobileNetV3 模型使用COCO数据集中train2017作为训练集,使用val2017作为测试集,Box AP<sup>val</sup>`mAP(IoU=0.5:0.95)`评估结果, Box AP50<sup>val</sup>`mAP(IoU=0.5)`评估结果。
- PP-YOLO_MobileNetV3 模型训练过程中使用4GPU,每GPU batch size为32进行训练,如训练GPU数和batch size不使用上述配置,须参考[FAQ](../../../docs/FAQ.md)调整学习率和迭代次数。
- PP-YOLO_MobileNetV3 模型推理速度测试环境配置为麒麟990芯片单线程。
### Pascal VOC数据集上的PP-YOLO
PP-YOLO在Pascal VOC数据集上训练模型如下:
| 模型 | GPU个数 | 每GPU图片个数 | 骨干网络 | 输入尺寸 | Box AP50<sup>val</sup> | 模型下载 | 配置文件 |
|:------------------:|:-------:|:-------------:|:----------:| :----------:| :--------------------: | :------: | :-----: |
| PP-YOLO | 8 | 12 | ResNet50vd | 608 | 84.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_voc.yml) |
| PP-YOLO | 8 | 12 | ResNet50vd | 416 | 84.3 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_voc.yml) |
| PP-YOLO | 8 | 12 | ResNet50vd | 320 | 82.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_r50vd_dcn_voc.yml) |
## 使用说明
### 1. 训练
......
......@@ -13,7 +13,6 @@ LearningRate:
steps: 4000
OptimizerBuilder:
clip_grad_by_norm: 35.
optimizer:
momentum: 0.9
type: Momentum
......
......@@ -13,7 +13,6 @@ LearningRate:
steps: 4000
OptimizerBuilder:
clip_grad_by_norm: 35.
optimizer:
momentum: 0.9
type: Momentum
......
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar
load_static_weights: true
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: MobileNetV3
neck: PPYOLOFPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
MobileNetV3:
model_name: large
scale: 1.
with_extra_blocks: false
extra_block_filters: []
feature_maps: [13, 16]
PPYOLOFPN:
feat_channels: [160, 368]
coord_conv: true
conv_block_num: 0
spp: true
drop_block: true
YOLOv3Head:
anchors: [[11, 18], [34, 47], [51, 126],
[115, 71], [120, 195], [254, 235]]
anchor_masks: [[3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.5
downsample: [32, 16]
label_smooth: false
scale_x_y: 1.05
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
loss_square: true
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
scale_x_y: 1.05
nms:
name: MultiClassNMS
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
score_threshold: 0.005
normalized: false
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar
load_static_weights: true
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: MobileNetV3
neck: PPYOLOFPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
MobileNetV3:
model_name: small
scale: 1.
with_extra_blocks: false
extra_block_filters: []
feature_maps: [9, 12]
PPYOLOFPN:
feat_channels: [96, 304]
coord_conv: true
conv_block_num: 0
spp: true
drop_block: true
YOLOv3Head:
anchors: [[11, 18], [34, 47], [51, 126],
[115, 71], [120, 195], [254, 235]]
anchor_masks: [[3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.5
downsample: [32, 16]
label_smooth: false
scale_x_y: 1.05
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
loss_square: true
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
scale_x_y: 1.05
nms:
name: MultiClassNMS
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
score_threshold: 0.005
normalized: false
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_vd_pretrained.tar
load_static_weights: true
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: ResNet
neck: PPYOLOFPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
ResNet:
depth: 18
variant: d
return_idx: [2, 3]
freeze_at: -1
freeze_norm: false
norm_decay: 0.
PPYOLOFPN:
feat_channels: [512, 512]
drop_block: true
block_size: 3
keep_prob: 0.9
conv_block_num: 0
YOLOv3Head:
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58],
[81, 82], [135, 169], [344, 319]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.7
downsample: [32, 16]
label_smooth: false
scale_x_y: 1.05
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
loss_square: true
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.01
downsample_ratio: 32
clip_bbox: true
scale_x_y: 1.05
nms:
name: MatrixNMS
keep_top_k: 100
score_threshold: 0.01
post_threshold: 0.01
nms_top_k: -1
normalized: false
background_label: -1
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo_r50vd_dcn/model_final
load_static_weights: true
norm_type: sync_bn
use_ema: true
......@@ -55,7 +54,7 @@ IouAwareLoss:
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.005
conf_thresh: 0.01
downsample_ratio: 32
clip_bbox: true
scale_x_y: 1.05
......@@ -66,3 +65,4 @@ BBoxPostProcess:
post_threshold: 0.01
nms_top_k: -1
normalized: false
background_label: -1
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/ppyolo_mbv3_large.yml',
'./_base_/optimizer_1x.yml',
'./_base_/ppyolo_reader.yml',
]
snapshot_epoch: 10
weights: output/ppyolo_mbv3_large_coco/model_final
TrainReader:
inputs_def:
num_max_boxes: 90
sample_transforms:
- DecodeOp: {}
- MixupOp: {alpha: 1.5, beta: 1.5}
- RandomDistortOp: {}
- RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]}
- RandomCropOp: {}
- RandomFlipOp: {}
batch_transforms:
- BatchRandomResizeOp:
target_size: [224, 256, 288, 320, 352, 384, 416, 448, 480, 512]
random_size: True
random_interp: True
keep_ratio: False
- NormalizeBoxOp: {}
- PadBoxOp: {num_max_boxes: 90}
- BboxXYXY2XYWHOp: {}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
- Gt2YoloTargetOp:
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[11, 18], [34, 47], [51, 126], [115, 71], [120, 195], [254, 235]]
downsample_ratios: [32, 16]
iou_thresh: 0.25
num_classes: 80
batch_size: 32
mixup_epoch: 200
shuffle: true
EvalReader:
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [320, 320], keep_ratio: False, interp: 2}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 320, 320]
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [320, 320], keep_ratio: False, interp: 2}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
batch_size: 1
epoch: 270
LearningRate:
base_lr: 0.005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 162
- 216
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/ppyolo_mbv3_small.yml',
'./_base_/optimizer_1x.yml',
'./_base_/ppyolo_reader.yml',
]
snapshot_epoch: 10
weights: output/ppyolo_mbv3_small_coco/model_final
TrainReader:
inputs_def:
num_max_boxes: 90
sample_transforms:
- DecodeOp: {}
- MixupOp: {alpha: 1.5, beta: 1.5}
- RandomDistortOp: {}
- RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]}
- RandomCropOp: {}
- RandomFlipOp: {}
batch_transforms:
- BatchRandomResizeOp:
target_size: [224, 256, 288, 320, 352, 384, 416, 448, 480, 512]
random_size: True
random_interp: True
keep_ratio: False
- NormalizeBoxOp: {}
- PadBoxOp: {num_max_boxes: 90}
- BboxXYXY2XYWHOp: {}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
- Gt2YoloTargetOp:
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[11, 18], [34, 47], [51, 126], [115, 71], [120, 195], [254, 235]]
downsample_ratios: [32, 16]
iou_thresh: 0.25
num_classes: 80
batch_size: 32
mixup_epoch: 200
shuffle: true
EvalReader:
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [320, 320], keep_ratio: False, interp: 2}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 320, 320]
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [320, 320], keep_ratio: False, interp: 2}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
batch_size: 1
epoch: 270
LearningRate:
base_lr: 0.005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 162
- 216
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/ppyolo_r18vd.yml',
'./_base_/optimizer_1x.yml',
'./_base_/ppyolo_reader.yml',
]
snapshot_epoch: 10
weights: output/ppyolo_r18vd_coco/model_final
TrainReader:
sample_transforms:
- DecodeOp: {}
- MixupOp: {alpha: 1.5, beta: 1.5}
- RandomDistortOp: {}
- RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]}
- RandomCropOp: {}
- RandomFlipOp: {}
batch_transforms:
- BatchRandomResizeOp:
target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_size: True
random_interp: True
keep_ratio: False
- NormalizeBoxOp: {}
- PadBoxOp: {num_max_boxes: 50}
- BboxXYXY2XYWHOp: {}
- NormalizeImageOp:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
- PermuteOp: {}
- Gt2YoloTargetOp:
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58], [81, 82], [135, 169], [344, 319]]
downsample_ratios: [32, 16]
batch_size: 32
mixup_epoch: 500
shuffle: true
EvalReader:
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [512, 512], keep_ratio: False, interp: 2}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 512, 512]
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [512, 512], keep_ratio: False, interp: 2}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
batch_size: 1
epoch: 270
LearningRate:
base_lr: 0.004
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 162
- 216
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
......@@ -7,3 +7,4 @@ _BASE_: [
]
snapshot_epoch: 16
weights: output/ppyolo_r50vd_dcn_1x_coco/model_final
......@@ -7,7 +7,8 @@ _BASE_: [
]
snapshot_epoch: 8
use_ema: false
use_ema: true
weights: output/ppyolo_r50vd_dcn_1x_minicoco/model_final
TrainReader:
batch_size: 12
......@@ -33,3 +34,11 @@ LearningRate:
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
......@@ -7,3 +7,4 @@ _BASE_: [
]
snapshot_epoch: 16
weights: output/ppyolo_r50vd_dcn_2x_coco/model_final
_BASE_: [
'../datasets/voc.yml',
'../runtime.yml',
'./_base_/ppyolo_r50vd_dcn.yml',
'./_base_/optimizer_1x.yml',
'./_base_/ppyolo_reader.yml',
]
snapshot_epoch: 83
weights: output/ppyolo_r50vd_dcn_voc/model_final
TrainReader:
batch_transforms:
- BatchRandomResizeOp: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeBoxOp: {}
- PadBoxOp: {num_max_boxes: 50}
- BboxXYXY2XYWHOp: {}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
- Gt2YoloTargetOp: {anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]], anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]], downsample_ratios: [32, 16, 8], num_classes: 20}
mixup_epoch: 350
batch_size: 12
epoch: 583
LearningRate:
base_lr: 0.00333
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 466
- 516
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
......@@ -12,6 +12,7 @@
| DarkNet53 | 608 | 8 | 270e | ---- | 39.0 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/yolov3/yolov3_darknet53_270e_coco.yml) |
| DarkNet53 | 416 | 8 | 270e | ---- | 37.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/yolov3/yolov3_darknet53_270e_coco.yml) |
| DarkNet53 | 320 | 8 | 270e | ---- | 34.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/yolov3/yolov3_darknet53_270e_coco.yml) |
| ResNet50_vd | 608 | 8 | 270e | ---- | 39.1 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_r50vd_dcn_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/yolov3/yolov3_r50vd_dcn_270e_coco.yml) |
| MobileNet-V1 | 608 | 8 | 270e | ---- | 28.8 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml) |
| MobileNet-V1 | 416 | 8 | 270e | ---- | 28.7 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml) |
| MobileNet-V1 | 320 | 8 | 270e | ---- | 26.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml) |
......
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
use_fine_grained_loss: false
load_static_weights: True
norm_type: sync_bn
......
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
load_static_weights: True
norm_type: sync_bn
YOLOv3:
backbone: ResNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
ResNet:
depth: 50
variant: d
return_idx: [1, 2, 3]
dcn_v2_stages: [3]
freeze_at: -1
freeze_norm: false
norm_decay: 0.
# YOLOv3FPN:
YOLOv3Head:
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.7
downsample: [32, 16, 8]
label_smooth: false
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
......@@ -5,4 +5,6 @@ _BASE_: [
'_base_/yolov3_darknet53.yml',
'_base_/yolov3_reader.yml',
]
weights: output/yolov3_darknet53_coco/model_final
snapshot_epoch: 5
weights: output/yolov3_darknet53_270e_coco/model_final
......@@ -5,4 +5,6 @@ _BASE_: [
'_base_/yolov3_mobilenet_v1.yml',
'_base_/yolov3_reader.yml',
]
weights: output/yolov3_mobilenet_v1_coco/model_final
snapshot_epoch: 5
weights: output/yolov3_mobilenet_v1_270e_coco/model_final
......@@ -5,7 +5,9 @@ _BASE_: [
'_base_/yolov3_mobilenet_v1.yml',
'_base_/yolov3_reader.yml',
]
weights: output/yolov3_mobilenet_v1_voc/model_final
snapshot_epoch: 5
weights: output/yolov3_mobilenet_v1_270e_voc/model_final
TrainReader:
inputs_def:
......
......@@ -5,4 +5,6 @@ _BASE_: [
'_base_/yolov3_mobilenet_v3_large.yml',
'_base_/yolov3_reader.yml',
]
weights: output/yolov3_mobilenet_v3_large_coco/model_final
snapshot_epoch: 5
weights: output/yolov3_mobilenet_v3_large_270e_coco/model_final
......@@ -5,7 +5,9 @@ _BASE_: [
'_base_/yolov3_mobilenet_v3_large.yml',
'_base_/yolov3_reader.yml',
]
weights: output/yolov3_mobilenet_v3_large_voc/model_final
snapshot_epoch: 5
weights: output/yolov3_mobilenet_v3_large_270e_voc/model_final
TrainReader:
inputs_def:
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_270e.yml',
'_base_/yolov3_r50vd_dcn.yml',
'_base_/yolov3_reader.yml',
]
snapshot_epoch: 5
weights: output/yolov3_r50vd_dcn_270e_coco/model_final
......@@ -135,6 +135,13 @@ class Detector(object):
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
score_tensor = self.predictor.get_output_handle(output_names[3])
np_score = score_tensor.copy_to_cpu()
label_tensor = self.predictor.get_output_handle(output_names[2])
np_label = label_tensor.copy_to_cpu()
np_boxes = np.concatenate(
[np_label[:, np.newaxis], np_score[:, np.newaxis], np_boxes],
axis=-1)
if self.pred_config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_handle(output_names[2])
np_masks = masks_tensor.copy_to_cpu()
......
......@@ -290,7 +290,8 @@ class Gt2YoloTargetOp(BaseOperator):
iou = jaccard_overlap(
[0., 0., gw, gh],
[0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
if iou > self.iou_thresh:
if iou > self.iou_thresh and target[idx, 5, gj,
gi] == 0.:
# x, y, w, h, scale
target[idx, 0, gj, gi] = gx * grid_w - gi
target[idx, 1, gj, gi] = gy * grid_h - gj
......
......@@ -319,7 +319,8 @@ class Gt2YoloTarget(BaseOperator):
iou = jaccard_overlap(
[0., 0., gw, gh],
[0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
if iou > self.iou_thresh:
if iou > self.iou_thresh and target[idx, 5, gj,
gi] == 0.:
# x, y, w, h, scale
target[idx, 0, gj, gi] = gx * grid_w - gi
target[idx, 1, gj, gi] = gy * grid_h - gj
......
......@@ -114,7 +114,11 @@ class Trainer(object):
self._metrics = []
return
if self.cfg.metric == 'COCO':
self._metrics = [COCOMetric(anno_file=self.dataset.get_anno())]
# TODO: bias should be unified
self._metrics = [
COCOMetric(
anno_file=self.dataset.get_anno(), bias=self.cfg.bias)
]
elif self.cfg.metric == 'VOC':
self._metrics = [
VOCMetric(
......
......@@ -24,7 +24,7 @@ from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
def get_infer_results(outs, catid):
def get_infer_results(outs, catid, bias=0):
"""
Get result at the stage of inference.
The output format is dictionary containing bbox or mask result.
......@@ -41,9 +41,14 @@ def get_infer_results(outs, catid):
infer_res = {}
if 'bbox' in outs:
infer_res['bbox'] = get_det_res(outs['bbox'], outs['score'],
outs['label'], outs['bbox_num'], im_id,
catid)
infer_res['bbox'] = get_det_res(
outs['bbox'],
outs['score'],
outs['label'],
outs['bbox_num'],
im_id,
catid,
bias=bias)
if 'mask' in outs:
# mask post process
......
......@@ -49,12 +49,13 @@ class Metric(paddle.metric.Metric):
class COCOMetric(Metric):
def __init__(self, anno_file):
def __init__(self, anno_file, **kwargs):
assert os.path.isfile(anno_file), \
"anno_file {} not a file".format(anno_file)
self.anno_file = anno_file
self.clsid2catid, self.catid2name = get_categories('COCO', anno_file)
# TODO: bias should be unified
self.bias = kwargs.get('bias', 0)
self.reset()
def reset(self):
......@@ -72,7 +73,8 @@ class COCOMetric(Metric):
outs['im_id'] = im_id.numpy() if isinstance(im_id,
paddle.Tensor) else im_id
infer_results = get_infer_results(outs, self.clsid2catid)
infer_results = get_infer_results(
outs, self.clsid2catid, bias=self.bias)
self.results['bbox'] += infer_results[
'bbox'] if 'bbox' in infer_results else []
self.results['mask'] += infer_results[
......
......@@ -39,15 +39,16 @@ class YOLOv3Head(nn.Layer):
self.yolo_outputs = []
for i in range(len(self.anchors)):
if self.iou_aware:
num_filters = self.num_outputs * (self.num_classes + 6)
num_filters = len(self.anchors[i]) * (self.num_classes + 6)
else:
num_filters = self.num_outputs * (self.num_classes + 5)
num_filters = len(self.anchors[i]) * (self.num_classes + 5)
name = 'yolo_output.{}'.format(i)
yolo_output = self.add_sublayer(
name,
nn.Conv2D(
in_channels=1024 // (2**i),
in_channels=128 * (2**self.num_outputs) // (2**i),
out_channels=num_filters,
kernel_size=1,
stride=1,
......
......@@ -249,6 +249,7 @@ class PPYOLOFPN(nn.Layer):
self.keep_prob = kwargs.get('keep_prob', 0.9)
self.spp = kwargs.get('spp', False)
self.conv_block_num = kwargs.get('conv_block_num', 2)
if self.coord_conv:
ConvLayer = CoordConv
else:
......@@ -269,25 +270,37 @@ class PPYOLOFPN(nn.Layer):
if i > 0:
ch_in += 512 // (2**i)
channel = 64 * (2**self.num_blocks) // (2**i)
base_cfg = [
# name of layer, Layer, args
['conv0', ConvLayer, [ch_in, channel, 1]],
['conv1', ConvBNLayer, [channel, channel * 2, 3]],
['conv2', ConvLayer, [channel * 2, channel, 1]],
['conv3', ConvBNLayer, [channel, channel * 2, 3]],
['route', ConvLayer, [channel * 2, channel, 1]],
['tip', ConvLayer, [channel, channel * 2, 3]]
base_cfg = []
c_in, c_out = ch_in, channel
for j in range(self.conv_block_num):
base_cfg += [
[
'conv{}'.format(2 * j), ConvLayer, [c_in, c_out, 1],
dict(
padding=0, norm_type=norm_type)
],
[
'conv{}'.format(2 * j + 1), ConvBNLayer,
[c_out, c_out * 2, 3], dict(
padding=1, norm_type=norm_type)
],
]
for conf in base_cfg:
filter_size = conf[-1][-1]
conf.append(dict(padding=filter_size // 2, norm_type=norm_type))
c_in, c_out = c_out * 2, c_out
base_cfg += [[
'route', ConvLayer, [c_in, c_out, 1], dict(
padding=0, norm_type=norm_type)
], [
'tip', ConvLayer, [c_out, c_out * 2, 3], dict(
padding=1, norm_type=norm_type)
]]
if self.conv_block_num == 2:
if i == 0:
if self.spp:
pool_size = [5, 9, 13]
spp_cfg = [[
'spp', SPP,
[channel * (len(pool_size) + 1), channel, 1], dict(
pool_size=pool_size, norm_type=norm_type)
'spp', SPP, [channel * 4, channel, 1], dict(
pool_size=[5, 9, 13], norm_type=norm_type)
]]
else:
spp_cfg = []
......@@ -295,6 +308,15 @@ class PPYOLOFPN(nn.Layer):
3:4] + dropblock_cfg + base_cfg[4:6]
else:
cfg = base_cfg[0:2] + dropblock_cfg + base_cfg[2:6]
elif self.conv_block_num == 0:
if self.spp and i == 0:
spp_cfg = [[
'spp', SPP, [c_in * 4, c_in, 1], dict(
pool_size=[5, 9, 13], norm_type=norm_type)
]]
else:
spp_cfg = []
cfg = spp_cfg + dropblock_cfg + base_cfg
name = 'yolo_block.{}'.format(i)
yolo_block = self.add_sublayer(name, PPYOLODetBlock(cfg, name))
self.yolo_blocks.append(yolo_block)
......@@ -305,7 +327,7 @@ class PPYOLOFPN(nn.Layer):
name,
ConvBNLayer(
ch_in=channel,
ch_out=channel // 2,
ch_out=256 // (2**i),
filter_size=1,
stride=1,
padding=0,
......
......@@ -28,5 +28,6 @@ class ShapeSpec(
stride:
"""
def __new__(cls, *, channels=None, height=None, width=None, stride=None):
return super().__new__(cls, channels, height, width, stride)
def __new__(cls, channels=None, height=None, width=None, stride=None):
return super(ShapeSpec, cls).__new__(cls, channels, height, width,
stride)
......@@ -106,8 +106,7 @@ def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
x2 = paddle.minimum(px2, gx2)
y2 = paddle.minimum(py2, gy2)
overlap = (x2 - x1) * (y2 - y1)
overlap = overlap.clip(0)
overlap = ((x2 - x1).clip(0)) * ((y2 - y1).clip(0))
area1 = (px2 - px1) * (py2 - py1)
area1 = area1.clip(0)
......
......@@ -243,18 +243,14 @@ class ModelEMA(object):
self._decay = decay
model_dict = model.state_dict()
for k, v in self.state_dict.items():
if '_mean' not in k and '_variance' not in k:
v = decay * v + (1 - decay) * model_dict[k]
v.stop_gradient = True
self.state_dict[k] = v
else:
self.state_dict[k] = model_dict[k]
self.step += 1
def apply(self):
state_dict = dict()
for k, v in self.state_dict.items():
if '_mean' not in k and '_variance' not in k:
v = v / (1 - self._decay**self.step)
v.stop_gradient = True
state_dict[k] = v
......
......@@ -4,8 +4,13 @@ import numpy as np
import cv2
def get_det_res(bboxes, scores, labels, bbox_nums, image_id,
label_to_cat_id_map):
def get_det_res(bboxes,
scores,
labels,
bbox_nums,
image_id,
label_to_cat_id_map,
bias=0):
det_res = []
k = 0
for i in range(len(bbox_nums)):
......@@ -19,8 +24,8 @@ def get_det_res(bboxes, scores, labels, bbox_nums, image_id,
k = k + 1
xmin, ymin, xmax, ymax = box.tolist()
category_id = label_to_cat_id_map[label]
w = xmax - xmin
h = ymax - ymin
w = xmax - xmin + bias
h = ymax - ymin + bias
bbox = [xmin, ymin, w, h]
dt_res = {
'image_id': cur_image_id,
......
......@@ -163,7 +163,7 @@ def load_pretrain_weight(model,
model.backbone.set_dict(param_state_dict)
else:
ignore_set = set()
for name, weight in model_dict:
for name, weight in model_dict.items():
if name in param_state_dict:
if weight.shape != param_state_dict[name].shape:
param_state_dict.pop(name, None)
......
......@@ -47,7 +47,10 @@ def parse_args():
help="Evaluation directory, default is current directory.")
parser.add_argument(
'--json_eval', action='store_true', default=False, help='')
'--json_eval',
action='store_true',
default=False,
help='Whether to re eval with already exists bbox.json or mask.json')
parser.add_argument(
"--slim_config",
......@@ -55,6 +58,12 @@ def parse_args():
type=str,
help="Configuration file of slim method.")
# TODO: bias should be unified
parser.add_argument(
"--bias",
action="store_true",
help="whether add bias or not while getting w and h")
args = parser.parse_args()
return args
......@@ -77,6 +86,8 @@ def main():
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
# TODO: bias should be unified
cfg['bias'] = 1 if FLAGS.bias else 0
merge_config(FLAGS.opt)
if FLAGS.slim_config:
slim_cfg = load_config(FLAGS.slim_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册