Unverified commit 3ee7bde2, authored by Guanghua Yu, committed by GitHub

update PicoDet docs (#4333)

* update PicoDet docs
Parent 488a6838
@@ -5,44 +5,146 @@
We developed a series of lightweight models named `PicoDet`. Because of their excellent performance, they are very suitable for deployment on mobile devices or CPUs.

- 🌟 Higher mAP: the **first** object detector to surpass **30+** mAP(0.5:0.95) within 1M parameters with a 416 input size.
- 🚀 Faster speed: 114 FPS on mobile ARM CPU.
- 😊 Deploy friendly: supports PaddleLite/MNN/NCNN/OpenVINO and provides C++/Python/Android implementations.
- 😍 Advanced algorithms: adopts advanced algorithms and our own innovations, such as ESNet, CSP-PAN, and SimOTA with VFL.
## Requirements
- PaddlePaddle >= 2.1.2
- PaddleSlim >= 2.1.1
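To quickly confirm that the installed versions match the requirements above, a minimal check (an illustrative sketch, assuming both packages were installed with pip under their usual distribution names):

```python
# Minimal version check for the requirements above.
# Assumption: the pip distribution names are "paddlepaddle" and "paddleslim".
from importlib.metadata import version  # Python 3.8+

print(version("paddlepaddle"))  # expect >= 2.1.2
print(version("paddleslim"))    # expect >= 2.1.1
```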
## Coming soon
- [ ] More series of models, such as smaller or larger models.
- [ ] Pretrained models for more scenarios.
- [ ] More features as needed.
## Model Zoo

| Model | Input size | mAP<sup>val<br>0.5:0.95 | mAP<sup>val<br>0.5 | FLOPS<br><sup>(G) | Params<br><sup>(M) | Latency<br><sup>(ms) | download | config |
| :------------------------ | :-------: | :------: | :---: | :---: | :---: | :------------: | :-------------------------------------------------: | :-----: |
| PicoDet-S | 320*320 | 27.1 | 41.4 | -- | 3.9 | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams) &#124; [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_320_coco.yml) |
| PicoDet-M | 320*320 | 30.9 | 45.7 | -- | 8.4 | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams) &#124; [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_320_coco.yml) |
| PicoDet-L | 320*320 | 32.6 | 47.9 | -- | 13 | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_l_320_coco.pdparams) &#124; [log](https://paddledet.bj.bcebos.com/logs/train_picodet_l_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_l_320_coco.yml) |
| PicoDet-S | 416*416 | 30.6 | 45.5 | -- | 3.9 | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_416_coco.pdparams) &#124; [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco.yml) |
| PicoDet-M | 416*416 | 34.3 | 49.8 | -- | 8.4 | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_416_coco.pdparams) &#124; [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_416_coco.yml) |
| PicoDet-L | 416*416 | - | - | -- | 13 | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_l_416_coco.pdparams) &#124; [log](https://paddledet.bj.bcebos.com/logs/train_picodet_l_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_l_416_coco.yml) |
| PicoDet-L | 640*640 | - | - | -- | 13 | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_l_640_coco.pdparams) &#124; [log](https://paddledet.bj.bcebos.com/logs/train_picodet_l_640_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_l_640_coco.yml) |
**Notes:**

- PicoDet inference speed is tested on a Snapdragon 888 (4×A78 + 4×A55) with 4 threads, armv8, and FP16.
- PicoDet is trained on the COCO train2017 dataset and evaluated on COCO val2017.
- PicoDet is trained with 4 or 8 GPUs.
## Deployment
### Export and Convert model
<details>
<summary>1. Export model</summary>
```shell
cd PaddleDetection
python tools/export_model.py -c configs/picodet/picodet_s_320_coco.yml \
-o weights=https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams --output_dir=inference_model
```
</details>
<details>
<summary>2. Convert to PaddleLite</summary>
- Install PaddleLite >= 2.10rc:
```shell
pip install paddlelite
```
- Convert model:
```shell
# FP32
paddle_lite_opt --model_dir=inference_model/picodet_s_320_coco --valid_targets=arm --optimize_out=picodet_s_320_coco_fp32
# FP16
paddle_lite_opt --model_dir=inference_model/picodet_s_320_coco --valid_targets=arm --optimize_out=picodet_s_320_coco_fp16 --enable_fp16=true
```
</details>
<details>
<summary>3. Convert to ONNX</summary>
- Install [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) >= 0.7 and ONNX > 1.10.1. For details, please refer to [Tutorials of Export ONNX Model](../../deploy/EXPORT_ONNX_MODEL.md):
```shell
pip install onnx
pip install paddle2onnx
```
- Convert model:
```shell
paddle2onnx --model_dir inference_model/picodet_s_320_coco/ \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--opset_version 11 \
--save_file picodet_s_320_coco.onnx
```
- Simplify the ONNX model with onnx-simplifier:
- Install onnx-simplifier >= 0.3.6:
```shell
pip install onnx-simplifier
```
- Simplify the ONNX model:
```shell
python -m onnxsim picodet_s_320_coco.onnx picodet_s_processed.onnx
```
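- (Optional) Sanity-check the simplified model. This is a minimal sketch rather than part of the official workflow; it assumes onnxruntime is installed and fills any dynamic dimension with 1:
```python
# Run the simplified ONNX model once with random inputs to verify that it
# loads and executes. Input names and shapes are read from the model itself.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("picodet_s_processed.onnx",
                            providers=["CPUExecutionProvider"])
feeds = {}
for inp in sess.get_inputs():
    # Replace symbolic/dynamic dimensions with 1 (an assumption).
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    feeds[inp.name] = np.random.rand(*shape).astype("float32")
outputs = sess.run(None, feeds)
print([o.shape for o in outputs])
```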
</details>
### Deploy
- PaddleInference demo [Python](../../deploy/python) & [C++](../../deploy/cpp)
- [PaddleLite C++ demo](../../deploy/lite)
- [NCNN C++ demo]()
- [MNN C++ demo]()
- [OpenVINO C++ demo]()
- [Android demo]()
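For a rough idea of what these demos do, here is a minimal Paddle Inference sketch in Python. It is a sketch under assumptions, not the official demo: it assumes the PicoDet-S model exported in step 1 above, whose inputs are an `image` tensor and a `scale_factor` tensor and whose first output is a box array shaped `(num_boxes, 6)` as `[class_id, score, x1, y1, x2, y2]`; the real demos add the proper preprocessing from `infer_cfg.yml`:

```python
# Run the exported PicoDet model once on a dummy input with Paddle Inference.
import numpy as np
from paddle.inference import Config, create_predictor

model_dir = "inference_model/picodet_s_320_coco"  # from the export step above
config = Config(model_dir + "/model.pdmodel", model_dir + "/model.pdiparams")
config.disable_gpu()  # CPU inference
predictor = create_predictor(config)

# Dummy preprocessed input (assumption: 320x320 export, normalized image).
feeds = {
    "image": np.random.rand(1, 3, 320, 320).astype("float32"),
    "scale_factor": np.array([[1.0, 1.0]], dtype="float32"),
}
for name in predictor.get_input_names():
    predictor.get_input_handle(name).copy_from_cpu(feeds[name])

predictor.run()
out_name = predictor.get_output_names()[0]
boxes = predictor.get_output_handle(out_name).copy_to_cpu()
print(boxes.shape)  # (num_boxes, 6): class_id, score, x1, y1, x2, y2
```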
## Slim
### Quantization
<details>
<summary>Quant aware</summary>
Configure the quantization config and start training:
```shell
python tools/train.py -c configs/picodet/picodet_s_320_coco.yml \
--slim_config configs/slim/quant/picodet_s_quant.yml --eval
```
</details>
<details>
<summary>Post quant</summary>
Configure the post-quantization config and start calibrating the model:
```shell
python tools/post_quant.py -c configs/picodet/picodet_s_320_coco.yml \
          --slim_config configs/slim/post_quant/picodet_s_quant.yml
```
</details>
## Cite PicoDet

If you use PicoDet in your research, please cite our work using the following BibTeX entry:
```
coming soon
```
@@ -31,7 +31,7 @@ PicoHead:
  feat_in_chan: 128

TrainReader:
  batch_size: 32

LearningRate:
  base_lr: 0.3
...
weights: https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams
slim: PTQ
PTQ:
  ptq_config: {
    'activation_quantizer': 'HistQuantizer',
    'upsample_bins': 127,
    'hist_percent': 0.999}
  quant_batch_num: 10
  fuse: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams
slim: QAT
QAT:
@@ -9,15 +9,15 @@ QAT:
    'quantizable_layer_type': ['Conv2D', 'Linear']}
  print_model: False

epoch: 50

LearningRate:
  base_lr: 0.001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 30
    - 40
  - !LinearWarmup
    start_factor: 0.
    steps: 100
...
@@ -11,6 +11,7 @@ PaddleDetection models can be saved in ONNX format; the currently tested and supported models are listed as follows
| PAFNet | 11 | - |
| TTFNet | 11 | - |
| SSD | 11 | Only batch=1 inference is supported |
| PicoDet | 11 | Only batch=1 inference is supported |

ONNX export is provided by [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX). If you run into problems during conversion, you can reach the engineers through an [ISSUE](https://github.com/PaddlePaddle/Paddle2ONNX/issues) in the Paddle2ONNX GitHub project.
...
@@ -196,7 +196,6 @@ class SimOTAAssigner(object):
        num_valid = valid_decoded_bbox.shape[0]
        pairwise_ious = batch_bbox_overlaps(valid_decoded_bbox, gt_bboxes)
        if self.use_vfl:
            gt_vfl_labels = gt_labels.squeeze(-1).unsqueeze(0).tile(
                [num_valid, 1]).reshape([-1])
@@ -216,6 +215,7 @@ class SimOTAAssigner(object):
                paddle.logical_not(is_in_boxes_and_center).cast('float32') * INF
            )
        else:
            iou_cost = -paddle.log(pairwise_ious + eps)
            gt_onehot_label = (F.one_hot(
                gt_labels.squeeze(-1).cast(paddle.int64),
                pred_scores.shape[-1]).cast('float32').unsqueeze(0).tile(
...
@@ -34,9 +34,9 @@ def varifocal_loss(pred,
    """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_

    Args:
        pred (Tensor): The prediction with shape (N, C), C is the
            number of classes
        target (Tensor): The learning target of the iou-aware
            classification score with shape (N, C), C is the number of classes.
        alpha (float, optional): A balance factor for the negative part of
            Varifocal Loss, which is different from the alpha of Focal Loss.
@@ -108,27 +108,18 @@ class VarifocalLoss(nn.Layer):
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, avg_factor=None):
        """Forward function.

        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.

        Returns:
            Tensor: The calculated loss
        """
        loss = self.loss_weight * varifocal_loss(
            pred,
...
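For context on the simplified `forward` signature above, a minimal usage sketch (hypothetical example values; assumes `VarifocalLoss` is exported from `ppdet.modeling.losses` and that the default constructor arguments match the upstream implementation):

```python
import paddle
from ppdet.modeling.losses import VarifocalLoss  # assumed import path

loss_fn = VarifocalLoss()            # defaults: sigmoid VFL, mean reduction
pred = paddle.randn([8, 80])         # raw logits, shape (N, C)
target = paddle.rand([8, 80])        # iou-aware targets in [0, 1], shape (N, C)
loss = loss_fn(pred, target, weight=None, avg_factor=8)
print(float(loss))
```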