diff --git a/configs/ppyolo/README.md b/configs/ppyolo/README.md index 57b9b399bea40e3efd0543bf4b9f83535d17df1b..3d8f2af9ec805fda072d7d9a80f4babde21ddd8a 100644 --- a/configs/ppyolo/README.md +++ b/configs/ppyolo/README.md @@ -67,22 +67,26 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods: ### PP-YOLO for mobile -| Model | GPU number | images/GPU | Model Size | input shape | Box APval | Kirin 990(FPS) | download | config | -|:----------------------------:|:----------:|:----------:| :--------: | :----------:| :------------------: | :------------: | :------: | :-----: | -| PP-YOLO_MobileNetV3_large | 4 | 32 | 18MB | 320 | 22.0 | 14.1 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_large.yml) | -| PP-YOLO_MobileNetV3_small | 4 | 32 | 11MB | 320 | 16.8 | 21.5 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) | +| Model | GPU number | images/GPU | Model Size | input shape | Box APval | Box AP50val | Kirin 990 1xCore(FPS) | download | inference model download | config | +|:----------------------------:|:----------:|:----------:| :--------: | :----------:| :------------------: | :--------------------: | :-------------------: | :------: | :----------------------: | :-----: | +| PP-YOLO_MobileNetV3_large | 4 | 32 | 18MB | 320 | 23.2 | 42.6 | 15.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.pdparams) | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.tar) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_large.yml) | +| PP-YOLO_MobileNetV3_small | 4 | 32 | 11MB | 320 | 17.2 | 33.8 | 28.6 | 
[model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.tar) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) | **Notes:** -- PP-YOLO_MobileNetV3 is trained on COCO train2017 datast and evaluated on val2017 dataset,Box APval is evaluation results of `mAP(IoU=0.5:0.95)`. +- PP-YOLO_MobileNetV3 is trained on the COCO train2017 dataset and evaluated on the val2017 dataset. Box APval is the evaluation result of `mAP(IoU=0.5:0.95)`, and Box AP50val is the evaluation result of `mAP(IoU=0.5)`. - PP-YOLO_MobileNetV3 used 4 GPUs for training and mini-batch size as 32 on each GPU, if GPU number and mini-batch size is changed, learning rate and iteration times should be adjusted according [FAQ](../../docs/FAQ.md). - PP-YOLO_MobileNetV3 inference speed is tested on Kirin 990 with 1 thread. ### Slim PP-YOLO -| Model | GPU number | images/GPU | Prune Ratio | Teacher Model | Model Size | input shape | Box APval | Kirin 990(FPS) | download | config | -|:----------------------------:|:----------:|:----------:| :---------: | :-----------------------: | :--------: | :----------:| :------------------: | :------------: | :------: | :-----: | -| PP-YOLO_MobileNetV3_small | 4 | 32 | 75% | PP-YOLO_MobileNetV3_large | 4.1MB | 320 | 14.4 | 21.5 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) | +| Model | GPU number | images/GPU | Prune Ratio | Teacher Model | Model Size | input shape | Box APval | Kirin 990 1xCore(FPS) | download | inference model download | config | +|:----------------------------:|:----------:|:----------:| :---------: | :-----------------------: | :--------: | :----------:| :------------------: | :-------------------: | :------: | 
:----------------------: | :-----: | +| PP-YOLO_MobileNetV3_small | 4 | 32 | 75% | PP-YOLO_MobileNetV3_large | 4.2MB | 320 | 16.2 | 39.8 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small_prune75_distillby_mobilenet_v3_large.pdparams) | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small_prune75_distillby_mobilenet_v3_large.tar) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) | + +- Slim PP-YOLO is trained with the slim training method from [Distill pruned model](../../slim/extentions/distill_pruned_model/README.md): the pruned PP-YOLO_MobileNetV3_small model is trained by distillation, with the PP-YOLO_MobileNetV3_large model as the teacher model. +- The detection head of the PP-YOLO model is pruned with a ratio of 75%, using the arguments `--pruned_params="yolo_block.0.2.conv.weights,yolo_block.0.tip.conv.weights,yolo_block.1.2.conv.weights,yolo_block.1.tip.conv.weights" --pruned_ratios="0.75,0.75,0.75,0.75"`. +- For Slim PP-YOLO training, evaluation, inference and model exporting, please see [Distill pruned model](../../slim/extentions/distill_pruned_model/README.md). ### PP-YOLO on Pascal VOC diff --git a/configs/ppyolo/README_cn.md b/configs/ppyolo/README_cn.md index e5152420cd81482353f16b546c3173f02be18ea8..68ca89d41a8f11d64ff56583a34bd145c5eee63c 100644 --- a/configs/ppyolo/README_cn.md +++ b/configs/ppyolo/README_cn.md @@ -68,15 +68,25 @@ PP-YOLO从如下方面优化和提升YOLOv3模型的精度和速度: ### PP-YOLO 轻量级模型 -| 模型 | GPU个数 | 每GPU图片个数 | 模型体积 | 输入尺寸 | Box APval | Kirin 990 (FPS) | 模型下载 | 配置文件 | -|:----------------------------:|:-------:|:-------------:|:----------:| :-------:| :------------------: | :-------------: |------------: | :---------------------: | -| PP-YOLO_MobileNetV3_large | 4 | 32 | 18MB | 320 | 22.0 | 14.1 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.pdparams) | 
[配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_large.yml) | -| PP-YOLO_MobileNetV3_small | 4 | 32 | 11MB | 320 | 16.8 | 21.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) | +| 模型 | GPU个数 | 每GPU图片个数 | 模型体积 | 输入尺寸 | Box APval | Box AP50val | Kirin 990 1xCore (FPS) | 模型下载 | 预测模型下载 | 配置文件 | +|:----------------------------:|:-------:|:-------------:|:----------:| :-------:| :------------------: | :--------------------: | :--------------------: | :------: | :----------: | :------: | +| PP-YOLO_MobileNetV3_large | 4 | 32 | 18MB | 320 | 23.2 | 42.6 | 14.1 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.pdparams) | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.tar) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_large.yml) | +| PP-YOLO_MobileNetV3_small | 4 | 32 | 11MB | 320 | 17.2 | 33.8 | 21.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.tar) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) | -- PP-YOLO_MobileNetV3 模型使用COCO数据集中train2017作为训练集,使用val2017作为测试集,Box AP50val为`mAP(IoU=0.5:0.95)`评估结果。 +- PP-YOLO_MobileNetV3 模型使用COCO数据集中train2017作为训练集,使用val2017作为测试集,Box APval为`mAP(IoU=0.5:0.95)`评估结果, Box AP50val为`mAP(IoU=0.5)`评估结果。 - PP-YOLO_MobileNetV3 模型训练过程中使用4GPU,每GPU batch size为32进行训练,如训练GPU数和batch size不使用上述配置,须参考[FAQ](../../docs/FAQ.md)调整学习率和迭代次数。 - PP-YOLO_MobileNetV3 模型推理速度测试环境配置为麒麟990芯片单线程。 +### PP-YOLO 轻量级裁剪模型 + +| 模型 | GPU 个数 | 每GPU图片个数 | 裁剪率 | Teacher模型 | 模型体积 | 输入尺寸 | Box APval | Kirin 990 1xCore (FPS) | 模型下载 | 预测模型下载 | 
配置文件 | +|:----------------------------:|:----------:|:-------------:| :---------: | :-----------------------: | :--------: | :----------:| :------------------: | :--------------------: | :------: | :----------: | :------: | +| PP-YOLO_MobileNetV3_small | 4 | 32 | 75% | PP-YOLO_MobileNetV3_large | 4.2MB | 320 | 16.2 | 39.8 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small_prune75_distillby_mobilenet_v3_large.pdparams) | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small_prune75_distillby_mobilenet_v3_large.tar) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) | + +- PP-YOLO 轻量级裁剪模型采用[蒸馏通道剪裁模型](../../slim/extentions/distill_pruned_model/README.md) 的方式训练得到,基于 PP-YOLO_MobileNetV3_small 模型对Head部分做卷积通道剪裁后使用 PP-YOLO_MobileNetV3_large 模型进行蒸馏训练 +- 卷积通道检测对Head部分剪裁掉75%的通道数,及剪裁参数为`--pruned_params="yolo_block.0.2.conv.weights,yolo_block.0.tip.conv.weights,yolo_block.1.2.conv.weights,yolo_block.1.tip.conv.weights" --pruned_ratios="0.75,0.75,0.75,0.75"` +- PP-YOLO 轻量级裁剪模型的训练、评估、预测及模型导出方法见[蒸馏通道剪裁模型](../../slim/extentions/distill_pruned_model/README.md) + ### Pascal VOC数据集上的PP-YOLO PP-YOLO在Pascal VOC数据集上训练模型如下: diff --git a/configs/ppyolo/ppyolo_mobilenet_v3_large.yml b/configs/ppyolo/ppyolo_mobilenet_v3_large.yml index 4be1f4677f1357a5de35d8e7e57accd23b22aca3..262c6c0b94032d45e9d544f25cf674047dc08048 100755 --- a/configs/ppyolo/ppyolo_mobilenet_v3_large.yml +++ b/configs/ppyolo/ppyolo_mobilenet_v3_large.yml @@ -29,10 +29,11 @@ MobileNetV3: YOLOv3Head: anchor_masks: [[3, 4, 5], [0, 1, 2]] - anchors: [[10, 14], [23, 27], [37, 58], - [81, 82], [135, 169], [344, 319]] + anchors: [[11, 18], [34, 47], [51, 126], + [115, 71], [120, 195], [254, 235]] norm_decay: 0. 
conv_block_num: 0 + coord_conv: true scale_x_y: 1.05 yolo_loss: YOLOv3Loss spp: true @@ -42,11 +43,11 @@ YOLOv3Head: nms_threshold: 0.45 nms_top_k: 1000 normalized: false - score_threshold: 0.01 + score_threshold: 0.005 drop_block: true YOLOv3Loss: - ignore_thresh: 0.7 + ignore_thresh: 0.5 scale_x_y: 1.05 label_smooth: false use_fine_grained_loss: true @@ -54,11 +55,11 @@ YOLOv3Loss: IouLoss: loss_weight: 2.5 - max_height: 608 - max_width: 608 + max_height: 512 + max_width: 512 LearningRate: - base_lr: 0.00666 + base_lr: 0.005 schedulers: - !PiecewiseDecay gamma: 0.1 @@ -81,7 +82,7 @@ _READER_: 'ppyolo_reader.yml' TrainReader: inputs_def: fields: ['image', 'gt_bbox', 'gt_class', 'gt_score'] - num_max_boxes: 50 + num_max_boxes: 90 dataset: !COCODataSet image_dir: train2017 @@ -103,11 +104,11 @@ TrainReader: is_normalized: false - !NormalizeBox {} - !PadBox - num_max_boxes: 50 + num_max_boxes: 90 - !BboxXYXY2XYWH {} batch_transforms: - !RandomShape - sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] + sizes: [224, 256, 288, 320, 352, 384, 416, 448, 480, 512] random_inter: True - !NormalizeImage mean: [0.485, 0.456, 0.406] @@ -122,12 +123,14 @@ TrainReader: # is set as false - !Gt2YoloTarget anchor_masks: [[3, 4, 5], [0, 1, 2]] - anchors: [[10, 14], [23, 27], [37, 58], - [81, 82], [135, 169], [344, 319]] + anchors: [[11, 18], [34, 47], [51, 126], + [115, 71], [120, 195], [254, 235]] downsample_ratios: [32, 16] + iou_thresh: 0.25 + num_classes: 80 batch_size: 32 shuffle: true - mixup_epoch: 500 + mixup_epoch: 200 drop_last: true worker_num: 8 bufsize: 4 @@ -136,7 +139,7 @@ TrainReader: EvalReader: inputs_def: fields: ['image', 'im_size', 'im_id'] - num_max_boxes: 50 + num_max_boxes: 90 dataset: !COCODataSet image_dir: val2017 @@ -155,11 +158,35 @@ EvalReader: is_scale: True is_channel_first: false - !PadBox - num_max_boxes: 50 + num_max_boxes: 90 - !Permute to_bgr: false channel_first: True - batch_size: 8 + batch_size: 1 drop_empty: false worker_num: 2 
bufsize: 4 + +TestReader: + inputs_def: + image_shape: [3, 320, 320] + fields: ['image', 'im_size', 'im_id'] + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 320 + interp: 2 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 diff --git a/configs/ppyolo/ppyolo_mobilenet_v3_small.yml b/configs/ppyolo/ppyolo_mobilenet_v3_small.yml index eb70c8d2656f8746a04d2fcda51fa80f9c02c0e8..9c3e27976e285386331d8d52004f1a79ac11aec0 100755 --- a/configs/ppyolo/ppyolo_mobilenet_v3_small.yml +++ b/configs/ppyolo/ppyolo_mobilenet_v3_small.yml @@ -29,10 +29,11 @@ MobileNetV3: YOLOv3Head: anchor_masks: [[3, 4, 5], [0, 1, 2]] - anchors: [[10, 14], [23, 27], [37, 58], - [81, 82], [135, 169], [344, 319]] + anchors: [[11, 18], [34, 47], [51, 126], + [115, 71], [120, 195], [254, 235]] norm_decay: 0. 
conv_block_num: 0 + coord_conv: true scale_x_y: 1.05 yolo_loss: YOLOv3Loss spp: true @@ -42,11 +43,11 @@ YOLOv3Head: nms_threshold: 0.45 nms_top_k: 1000 normalized: false - score_threshold: 0.01 + score_threshold: 0.005 drop_block: true YOLOv3Loss: - ignore_thresh: 0.7 + ignore_thresh: 0.5 scale_x_y: 1.05 label_smooth: false use_fine_grained_loss: true @@ -54,11 +55,11 @@ YOLOv3Loss: IouLoss: loss_weight: 2.5 - max_height: 608 - max_width: 608 + max_height: 512 + max_width: 512 LearningRate: - base_lr: 0.00666 + base_lr: 0.005 schedulers: - !PiecewiseDecay gamma: 0.1 @@ -81,7 +82,7 @@ _READER_: 'ppyolo_reader.yml' TrainReader: inputs_def: fields: ['image', 'gt_bbox', 'gt_class', 'gt_score'] - num_max_boxes: 50 + num_max_boxes: 90 dataset: !COCODataSet image_dir: train2017 @@ -103,11 +104,11 @@ TrainReader: is_normalized: false - !NormalizeBox {} - !PadBox - num_max_boxes: 50 + num_max_boxes: 90 - !BboxXYXY2XYWH {} batch_transforms: - !RandomShape - sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] + sizes: [224, 256, 288, 320, 352, 384, 416, 448, 480, 512] random_inter: True - !NormalizeImage mean: [0.485, 0.456, 0.406] @@ -122,12 +123,14 @@ TrainReader: # is set as false - !Gt2YoloTarget anchor_masks: [[3, 4, 5], [0, 1, 2]] - anchors: [[10, 14], [23, 27], [37, 58], - [81, 82], [135, 169], [344, 319]] + anchors: [[11, 18], [34, 47], [51, 126], + [115, 71], [120, 195], [254, 235]] downsample_ratios: [32, 16] + iou_thresh: 0.25 + num_classes: 80 batch_size: 32 shuffle: true - mixup_epoch: 500 + mixup_epoch: 200 drop_last: true worker_num: 8 bufsize: 4 @@ -136,7 +139,7 @@ TrainReader: EvalReader: inputs_def: fields: ['image', 'im_size', 'im_id'] - num_max_boxes: 50 + num_max_boxes: 90 dataset: !COCODataSet image_dir: val2017 @@ -155,11 +158,35 @@ EvalReader: is_scale: True is_channel_first: false - !PadBox - num_max_boxes: 50 + num_max_boxes: 90 - !Permute to_bgr: false channel_first: True - batch_size: 8 + batch_size: 1 drop_empty: false worker_num: 2 
bufsize: 4 + +TestReader: + inputs_def: + image_shape: [3, 320, 320] + fields: ['image', 'im_size', 'im_id'] + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 320 + interp: 2 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index c0e9bd6e0e4af3940326d0a99264fda745292ba5..40267e1cac8751b212b3199c4b34654b7848832c 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -286,7 +286,8 @@ class Gt2YoloTarget(BaseOperator): iou = jaccard_overlap( [0., 0., gw, gh], [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]]) - if iou > self.iou_thresh: + if iou > self.iou_thresh and target[idx, 5, gj, + gi] == 0.: # x, y, w, h, scale target[idx, 0, gj, gi] = gx * grid_w - gi target[idx, 1, gj, gi] = gy * grid_h - gj diff --git a/ppdet/modeling/anchor_heads/yolo_head.py b/ppdet/modeling/anchor_heads/yolo_head.py index a0c3d2bc40f97c24b5469df72539225eef5d77a1..df8a375cf18c94f7d0e0746ece9c4a3ab48f58b7 100644 --- a/ppdet/modeling/anchor_heads/yolo_head.py +++ b/ppdet/modeling/anchor_heads/yolo_head.py @@ -241,10 +241,11 @@ class YOLOv3Head(object): padding=0, name='{}.{}.0'.format(name, j)) if self.use_spp and is_first and j == 1: + c = conv.shape[1] conv = self._spp_module(conv, name="spp") conv = self._conv_bn( conv, - 512, + c, filter_size=1, stride=1, padding=0, @@ -264,7 +265,15 @@ class YOLOv3Head(object): is_test=is_test) if self.use_spp and conv_block_num == 0 and is_first: + c = conv.shape[1] conv = self._spp_module(conv, name="spp") + conv = self._conv_bn( + conv, + c, + filter_size=1, + stride=1, + padding=0, + name='{}.spp.conv'.format(name)) if self.drop_block and 
(is_first or conv_block_num == 0): conv = DropBlock(