diff --git a/configs/picodet/README.md b/configs/picodet/README.md index b2c75a51748b7f213983c0b4f635fc82b1a975bb..9e807f6686125b920d4003e6ebd01f6a95cb3e89 100644 --- a/configs/picodet/README.md +++ b/configs/picodet/README.md @@ -5,44 +5,146 @@ We developed a series of lightweight models, which named `PicoDet`. Because of its excellent performance, it is very suitable for deployment on mobile or CPU. -- 🌟 Higher mAP: The **first** model which within 1M parameter with mAP reaching 30+. +- 🌟 Higher mAP: the **first** object detectors that surpass mAP(0.5:0.95) **30+** within 1M parameters when the input size is 416. - 🚀 Faster latency: 114FPS on mobile ARM CPU. - 😊 Deploy friendly: support PaddleLite/MNN/NCNN/OpenVINO and provide C++/Python/Android implementation. - 😍 Advanced algorithm: use the most advanced algorithms and innovate, such as ESNet, CSP-PAN, SimOTA with VFL, etc. +### Comming soon +- [ ] More series of model, such as smaller or larger model. +- [ ] Pretrained models for more scenarios. +- [ ] More features in need. ## Requirements - PaddlePaddle >= 2.1.2 - PaddleSlim >= 2.1.1 -## Comming soon -- [ ] More series of model, such as Smaller or larger model. -- [ ] Pretrained models for more scenarios. -- [ ] More features in need. - ## Model Zoo -| Model | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | config | -| :------------------------ | :-------: | :-------: | :------: | :---: | :---: | :---: | :------------: | :-------------------------------------------------: | :-----: | -| PicoDet-S | 320*320 | 300e | 27.1 | 41.4 | -- | 3.9M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_320_coco.yml) | -| PicoDet-S | 416*416 | 300e | 30.6 | 45.5 | -- | 3.9M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco.yml) | -| PicoDet-M | 320*320 | 300e | - | 41.2 | -- | 8.4M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_320_coco.yml) | -| PicoDet-M | 416*416 | 300e | 34.3 | 49.8 | -- | 8.4M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_416_coco.yml) | - +| Model | Input size | mAPval
0.5:0.95 | mAPval
0.5 | FLOPS
(G) | Params
(M) | Latency
(ms) | download | config | +| :------------------------ | :-------: | :------: | :---: | :---: | :---: | :------------: | :-------------------------------------------------: | :-----: | +| PicoDet-S | 320*320 | 27.1 | 41.4 | -- | 3.9M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_320_coco.yml) | +| PicoDet-M | 320*320 | 30.9 | 45.7 | -- | 8.4M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_320_coco.yml) | +| PicoDet-L | 320*320 | 32.6 | 47.9 | -- | 13M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_l_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_l_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_l_320_coco.yml) | +| PicoDet-S | 416*416 | 30.6 | 45.5 | -- | 3.9M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco.yml) | +| PicoDet-M | 416*416 | 34.3 | 49.8 | -- | 8.4M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_416_coco.yml) | +| PicoDet-L | 416*416 | - | - | -- | 13M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_l_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_l_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_l_416_coco.yml) | +| PicoDet-L | 640*640 | - | - | -- | 13M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_l_640_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_l_640_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_l_640_coco.yml) | **Notes:** -- PicoDet inference speed is tested on Kirin 980 with 4 threads by arm8 and with FP16. +- PicoDet inference speed is tested on Snapdragon 888(4xA78+4xA55) with 4 threads by arm8 and with FP16. - PicoDet is trained on COCO train2017 dataset and evaluated on val2017. -- PicoDet used 4 GPUs for training and mini-batch size as 128 or 96 on each GPU. +- PicoDet used 4 or 8 GPUs for training. + +## Deployment + +### Export and Convert model + +
+1. Export model + +```shell +cd PaddleDetection +python tools/export_model.py -c configs/picodet/picodet_s_320_coco.yml \ + -o weights=https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams --output_dir=inference_model +``` + +
+ +
+2. Convert to PaddleLite + +- Install Paddlelite>=2.10.rc: + +```shell +pip install paddlelite +``` + +- Convert model: + +```shell +# FP32 +paddle_lite_opt --model_dir=inference_model/picodet_s_320_coco --valid_targets=arm --optimize_out=picodet_s_320_coco_fp32 +# FP16 +paddle_lite_opt --model_dir=inference_model/picodet_s_320_coco --valid_targets=arm --optimize_out=picodet_s_320_coco_fp16 --enable_fp16=true +``` + +
+ +
+3. Convert to ONNX + +- Install [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) >= 0.7 and ONNX > 1.10.1, for details, please refer to [Tutorials of Export ONNX Model](../../deploy/EXPORT_ONNX_MODEL.md) + +```shell +pip install onnx +pip install paddle2onnx +``` + +- Convert model: + +```shell +paddle2onnx --model_dir output_inference/picodet_s_320_coco/ \ + --model_filename model.pdmodel \ + --params_filename model.pdiparams \ + --opset_version 11 \ + --save_file picodet_s_320_coco.onnx +``` + +- Simplify ONNX model: use onnx-simplifier to simplify onnx model. + + - Install onnx-simplifier >= 0.3.6: + ```shell + pip install onnx-simplifier + ``` + - simplify onnx model: + ```shell + python -m onnxsim picodet_s_320_coco.onnx picodet_s_processed.onnx + ``` + +
+ +### Deploy + +- PaddleInference demo [Python](../../deploy/python) & [C++](../../deploy/cpp) +- [PaddleLite C++ demo](../../deploy/lite) +- [NCNN C++ demo]() +- [MNN C++ demo]() +- [OpenVINO C++ demo]() +- [Android demo]() + +## Slim + +### quantization + +
+Quant aware + +Configure the quant config and start training: + +```shell +python tools/train.py -c configs/picodet/picodet_s_320_coco.yml \ + --slim_config configs/slim/quant/picodet_s_quant.yml --eval +``` + +
+ +
+Post quant + +Configure the post quant config and start calibrate model: + +```shell +python tools/posy_quant.py -c configs/picodet/picodet_s_320_coco.yml \ + --slim_config configs/slim/posy_quant/picodet_s_quant.yml +``` + +
-## Citations +## Cite PiocDet +If you use PiocDet in your research, please cite our work by using the following BibTeX entry: ``` -@article{li2020generalized, - title={Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection}, - author={Li, Xiang and Wang, Wenhai and Wu, Lijun and Chen, Shuo and Hu, Xiaolin and Li, Jun and Tang, Jinhui and Yang, Jian}, - journal={arXiv preprint arXiv:2006.04388}, - year={2020} -} +comming soon ``` diff --git a/configs/picodet/picodet_l_640_coco.yml b/configs/picodet/picodet_l_640_coco.yml index 94cab25487732b4a711625547bb6699aae9aef97..2578eedd343e07d75c8b67ac4e746d956c9bcbe0 100644 --- a/configs/picodet/picodet_l_640_coco.yml +++ b/configs/picodet/picodet_l_640_coco.yml @@ -31,7 +31,7 @@ PicoHead: feat_in_chan: 128 TrainReader: - batch_size: 48 + batch_size: 32 LearningRate: base_lr: 0.3 diff --git a/configs/slim/post_quant/picodet_s_ptq.yml b/configs/slim/post_quant/picodet_s_ptq.yml new file mode 100644 index 0000000000000000000000000000000000000000..e1cf3ca6ab23accabf91b0d7294c0ab48accf693 --- /dev/null +++ b/configs/slim/post_quant/picodet_s_ptq.yml @@ -0,0 +1,10 @@ +weights: https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams +slim: PTQ + +PTQ: + ptq_config: { + 'activation_quantizer': 'HistQuantizer', + 'upsample_bins': 127, + 'hist_percent': 0.999} + quant_batch_num: 10 + fuse: True diff --git a/configs/slim/quant/picodet_s_shufflenetv2_quant.yml b/configs/slim/quant/picodet_s_quant.yml similarity index 90% rename from configs/slim/quant/picodet_s_shufflenetv2_quant.yml rename to configs/slim/quant/picodet_s_quant.yml index 9d63519a3e5a1b424eb5faa3fd8fa895648429a9..099532ffc5c3791644ceda25db8c1f4581762d61 100644 --- a/configs/slim/quant/picodet_s_shufflenetv2_quant.yml +++ b/configs/slim/quant/picodet_s_quant.yml @@ -1,4 +1,4 @@ -pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_320_coco.pdparams +pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams slim: QAT QAT: @@ -9,15 +9,15 @@ QAT: 'quantizable_layer_type': ['Conv2D', 'Linear']} print_model: False -epoch: 80 +epoch: 50 LearningRate: base_lr: 0.001 schedulers: - !PiecewiseDecay gamma: 0.1 milestones: - - 60 - - 70 + - 30 + - 40 - !LinearWarmup start_factor: 0. steps: 100 diff --git a/deploy/EXPORT_ONNX_MODEL.md b/deploy/EXPORT_ONNX_MODEL.md index ac6b81e0c2ce81ac0fb38c4e1b46c99fdaf4edc6..cad839c9a64f40d68af5b275b57b6eaee492f932 100644 --- a/deploy/EXPORT_ONNX_MODEL.md +++ b/deploy/EXPORT_ONNX_MODEL.md @@ -11,6 +11,7 @@ PaddleDetection模型支持保存为ONNX格式,目前测试支持的列表如 | PAFNet | 11 |- | | TTFNet | 11 |-| | SSD | 11 |仅支持batch=1推理 | +| PicoDet | 11 |仅支持batch=1推理 | 保存ONNX的功能由[Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX)提供,如在转换中有相关问题反馈,可在Paddle2ONNX的Github项目中通过[ISSUE](https://github.com/PaddlePaddle/Paddle2ONNX/issues)与工程师交流。 diff --git a/ppdet/modeling/assigners/simota_assigner.py b/ppdet/modeling/assigners/simota_assigner.py index de4e89c8cae980abe0751b30af06e7b55137a43a..c0f337e46401ad5f5ff21b8d90bc81e3eb47b199 100644 --- a/ppdet/modeling/assigners/simota_assigner.py +++ b/ppdet/modeling/assigners/simota_assigner.py @@ -196,7 +196,6 @@ class SimOTAAssigner(object): num_valid = valid_decoded_bbox.shape[0] pairwise_ious = batch_bbox_overlaps(valid_decoded_bbox, gt_bboxes) - iou_cost = -paddle.log(pairwise_ious + eps) if self.use_vfl: gt_vfl_labels = gt_labels.squeeze(-1).unsqueeze(0).tile( [num_valid, 1]).reshape([-1]) @@ -216,6 +215,7 @@ class SimOTAAssigner(object): paddle.logical_not(is_in_boxes_and_center).cast('float32') * INF ) else: + iou_cost = -paddle.log(pairwise_ious + eps) gt_onehot_label = (F.one_hot( gt_labels.squeeze(-1).cast(paddle.int64), pred_scores.shape[-1]).cast('float32').unsqueeze(0).tile( diff --git a/ppdet/modeling/losses/varifocal_loss.py b/ppdet/modeling/losses/varifocal_loss.py index 220c3b07287a79a4291f83cea3fc0261e1fcb0c6..07716a016b7ee2fd050d8c443582ad7b8c7c8b0b 100644 --- a/ppdet/modeling/losses/varifocal_loss.py +++ b/ppdet/modeling/losses/varifocal_loss.py @@ -34,9 +34,9 @@ def varifocal_loss(pred, """`Varifocal Loss `_ Args: - pred (torch.Tensor): The prediction with shape (N, C), C is the + pred (Tensor): The prediction with shape (N, C), C is the number of classes - target (torch.Tensor): The learning target of the iou-aware + target (Tensor): The learning target of the iou-aware classification score with shape (N, C), C is the number of classes. alpha (float, optional): A balance factor for the negative part of Varifocal Loss, which is different from the alpha of Focal Loss. @@ -108,27 +108,18 @@ class VarifocalLoss(nn.Layer): self.reduction = reduction self.loss_weight = loss_weight - def forward(self, - pred, - target, - weight=None, - avg_factor=None, - reduction_override=None): + def forward(self, pred, target, weight=None, avg_factor=None): """Forward function. Args: - pred (torch.Tensor): The prediction. - target (torch.Tensor): The learning target of the prediction. - weight (torch.Tensor, optional): The weight of loss for each + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + weight (Tensor, optional): The weight of loss for each prediction. Defaults to None. avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. - reduction_override (str, optional): The reduction method used to - override the original reduction method of the loss. - Options are "none", "mean" and "sum". - Returns: - torch.Tensor: The calculated loss + Tensor: The calculated loss """ loss = self.loss_weight * varifocal_loss( pred,