Unverified commit fd85f9db, authored by Guanghua Yu, committed by GitHub

fix YOLO series act demo some bug (#1359)

Parent 3d0755b1
......@@ -13,6 +13,7 @@ Distillation:
loss: soft_label
Quantization:
onnx_format: true
use_pact: true
activation_quantize_type: 'moving_average_abs_max'
quantize_op_types:
......
......@@ -13,6 +13,7 @@ Distillation:
loss: soft_label
Quantization:
onnx_format: true
use_pact: true
activation_quantize_type: 'moving_average_abs_max'
quantize_op_types:
......
......@@ -33,7 +33,7 @@
| YOLOv7 | ACT quantization-aware training | 640*640 | **50.9** | 36MB | - | - | **4.55ms** | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.onnx) |
| | | | | | | | | |
| YOLOv7-Tiny | Base model | 640*640 | 37.3 | 24MB | 5.06ms | 2.32ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx) |
| YOLOv7-Tiny | Post-training quantization | 640*640 | - | 6.1MB | - | - | 1.68ms | - | - |
| YOLOv7-Tiny | Post-training quantization | 640*640 | 35.8 | 6.1MB | - | - | 1.68ms | - | - |
| YOLOv7-Tiny | ACT quantization-aware training | 640*640 | **37.0** | 6.1MB | - | - | **1.68ms** | [config](./configs/yolov7_tiny_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.onnx) |
Notes:
......@@ -43,15 +43,15 @@
## 3. Auto-compression workflow
#### 3.1 Prepare the environment
- PaddlePaddle develop nightly build (can be downloaded and installed from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/develop/install/pip/linux-pip.html))
- PaddlePaddle >= 2.3.2 (install it from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html) using the install command for your environment)
- PaddleSlim develop version
(1) Install paddlepaddle
```shell
(1) Install paddlepaddle
```
# CPU
pip install paddlepaddle
pip install paddlepaddle==2.3.2
# GPU
pip install paddlepaddle-gpu
pip install paddlepaddle-gpu==2.3.2
```
(2) Install paddleslim:
......@@ -90,22 +90,22 @@ dataset/coco/
- YOLOv5:
You can prepare the ONNX model by following the official [export tutorial](https://github.com/ultralytics/yolov5/issues/251) from [ultralytics/yolov5](https://github.com/ultralytics/yolov5). Alternatively, download the prepared [yolov5s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx).
```shell
python export.py --weights yolov5s.pt --include onnx
```
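The model in this example was exported from the master branch of [ultralytics/yolov5](https://github.com/ultralytics/yolov5); an ONNX model exported with v6.1 or later is required. You can prepare the ONNX model by following the official [export tutorial](https://github.com/ultralytics/yolov5/issues/251), or download the prepared [yolov5s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx).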
```shell
python export.py --weights yolov5s.pt --include onnx
```
- YOLOv6:
You can prepare the ONNX model by following the official [export tutorial](https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md) from [meituan/YOLOv6](https://github.com/meituan/YOLOv6). Alternatively, download the prepared [yolov6s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx).
You can prepare the ONNX model by following the official [export tutorial](https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md) from [meituan/YOLOv6](https://github.com/meituan/YOLOv6). Alternatively, download the prepared [yolov6s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx).
- YOLOv7: You can prepare the ONNX model with the export script from [WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7), as follows:
```shell
git clone https://github.com/WongKinYiu/yolov7.git
python export.py --weights yolov7-tiny.pt --grid
```
```shell
git clone https://github.com/WongKinYiu/yolov7.git
python export.py --weights yolov7-tiny.pt --grid
```
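**Note**: ACT currently supports models without NMS; exporting with the command above is sufficient. You can also directly download our prepared [yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx).
**Note**: ACT currently supports models **without NMS**; exporting with the command above is sufficient. You can also directly download our prepared [yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx).
#### 3.4 Run auto-compression and produce the model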
......@@ -138,25 +138,11 @@ python eval.py --config_path=./configs/yolov7_tiny_qat_dis.yaml
#### Export to ONNX and deploy with TensorRT
- First, install Paddle2onnx:
```shell
pip install paddle2onnx==1.0.0rc3
```
- Then export the quantized model to ONNX:
```shell
paddle2onnx --model_dir output/ \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--opset_version 13 \
--enable_onnx_checker True \
--save_file yolov7_quant.onnx \
--deploy_backend tensorrt
```
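After auto-compression finishes, an ONNX model file named `quant_model.onnx` is generated in `save_dir` by default; it can be verified directly with the TensorRT test script.
- Run the test: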
```shell
python yolov7_onnx_trt.py --model_path=yolov7_quant.onnx --image_file=images/000000570688.jpg --precision=int8
python yolov7_onnx_trt.py --model_path=output/quant_model.onnx --image_file=images/000000570688.jpg --precision=int8
```
#### Paddle-TensorRT deployment
......
......@@ -15,7 +15,6 @@ Distillation:
Quantization:
onnx_format: true
use_pact: true
onnx_format: False
activation_quantize_type: 'moving_average_abs_max'
quantize_op_types:
- conv2d
......
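Judging by the hunk header (seven old lines, six new), this change appears to drop the second `onnx_format` key from the `Quantization` block. As a minimal sketch of why that matters: YAML loaders such as PyYAML accept a repeated key without complaint and keep the last value, so the earlier `onnx_format: true` would have been silently overridden. The checker below is a hypothetical helper, not part of PaddleSlim; it handles only simple block-style mappings, and the two-space indentation in the sample is an assumption because the diff view does not preserve whitespace.

```python
def find_duplicate_keys(yaml_text):
    """Return keys that repeat within the same mapping.

    Hypothetical helper for illustration: supports only block-style
    mappings and skips list items entirely.
    """
    duplicates = []
    stack = []  # (indent, keys seen at that indent) for each open mapping
    for raw in yaml_text.splitlines():
        line = raw.split("#", 1)[0].rstrip()  # drop trailing comments
        stripped = line.strip()
        if not stripped or stripped.startswith("- "):
            continue
        key, sep, _ = stripped.partition(":")
        if not sep:
            continue
        indent = len(line) - len(line.lstrip())
        # Close mappings that are nested deeper than the current line.
        while stack and stack[-1][0] > indent:
            stack.pop()
        # Open a new mapping when the line is nested deeper than the top.
        if not stack or stack[-1][0] < indent:
            stack.append((indent, set()))
        seen = stack[-1][1]
        if key in seen:
            duplicates.append(key)
        seen.add(key)
    return duplicates


# Pre-fix config as it appears in the hunk (indentation assumed).
OLD_CONFIG = """\
Quantization:
  onnx_format: true
  use_pact: true
  onnx_format: False
  activation_quantize_type: 'moving_average_abs_max'
  quantize_op_types:
  - conv2d
"""

duplicated = find_duplicate_keys(OLD_CONFIG)
```

Running a check like this over the config files before training would flag the repeated key instead of letting the later value win unnoticed.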
......@@ -32,6 +32,8 @@ def argsparser():
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--batch_size', type=int, default=1, help="Batch size of model input.")
parser.add_argument(
'--devices',
type=str,
......@@ -83,7 +85,8 @@ def main():
anno_path=global_config['val_anno_path'])
global anno_file
anno_file = dataset.ann_file
val_loader = paddle.io.DataLoader(dataset, batch_size=1)
val_loader = paddle.io.DataLoader(
dataset, batch_size=FLAGS.batch_size, drop_last=True)
eval()
......
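The new `--batch_size` flag is passed to `paddle.io.DataLoader` together with `drop_last=True`, which discards the final batch when it is incomplete. The sketch below uses plain Python rather than Paddle to show how many samples that skips; `batch_sizes` is a hypothetical helper, and the count of 5000 images assumes the standard COCO val2017 split.

```python
def batch_sizes(num_samples, batch_size, drop_last):
    """Return the size of each batch a sequential loader would yield."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    # The incomplete tail batch is kept only when drop_last is False.
    if rest and not drop_last:
        sizes.append(rest)
    return sizes


NUM_VAL_IMAGES = 5000  # assumption: COCO val2017

dropped = batch_sizes(NUM_VAL_IMAGES, 32, drop_last=True)
kept = batch_sizes(NUM_VAL_IMAGES, 32, drop_last=False)
skipped = NUM_VAL_IMAGES - sum(dropped)  # images never evaluated
```

With a batch size of 32, eight images would be left out of the evaluation, so the reported mAP may differ slightly from a run with the default `--batch_size=1`, where nothing is dropped.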
......@@ -159,8 +159,8 @@ class YOLOPostProcess(object):
if len(pred.shape) == 1:
pred = pred[np.newaxis, :]
pred_bboxes = pred[:, :4]
scale_factor = np.tile(scale_factor[i][::-1], (1, 2))
pred_bboxes /= scale_factor
scale = np.tile(scale_factor[i][::-1], (2))
pred_bboxes /= scale
bbox = np.concatenate(
[
pred[:, -1][:, np.newaxis], pred[:, -2][:, np.newaxis],
......
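The rename from `scale_factor` to `scale` in the hunk above is probably what makes post-processing work for batches larger than one. The sketch below reproduces both patterns with NumPy on made-up data: rebinding `scale_factor` inside the loop replaces the per-batch array with a `(1, 4)` array, so indexing it with `i = 1` fails. The function names, the sample values, and the `(scale_y, scale_x)` row order are assumptions for illustration.

```python
import numpy as np


def rescale_boxes(boxes_per_image, scale_factor):
    """Map (x1, y1, x2, y2) boxes back to original image coordinates.

    scale_factor has shape (batch, 2); each row is assumed to be
    (scale_y, scale_x), hence the [::-1] reversal before tiling.
    """
    rescaled = []
    for i, boxes in enumerate(boxes_per_image):
        # Shape (4,): (scale_x, scale_y, scale_x, scale_y). Using a new
        # name keeps the batch array intact for the next iteration.
        scale = np.tile(scale_factor[i][::-1], 2)
        rescaled.append(boxes / scale)
    return rescaled


def rescale_boxes_old(boxes_per_image, scale_factor):
    """Pre-fix pattern: scale_factor is rebound inside the loop."""
    rescaled = []
    for i, boxes in enumerate(boxes_per_image):
        # After the first iteration scale_factor has shape (1, 4), so
        # scale_factor[1] raises IndexError on the second image.
        scale_factor = np.tile(scale_factor[i][::-1], (1, 2))
        rescaled.append(boxes / scale_factor)
    return rescaled


# Hypothetical batch of two images with one box each.
scale_factor = np.array([[0.5, 0.25], [2.0, 4.0]], dtype=np.float32)
boxes = [
    np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32),
    np.array([[8.0, 8.0, 16.0, 16.0]], dtype=np.float32),
]

fixed = rescale_boxes(boxes, scale_factor)
```

With a single image per batch the old pattern never reaches the second index, which would explain why the problem only surfaces once `--batch_size` is greater than 1.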
arch: YOLOv5
model_dir: ./yolov5s.onnx
dataset_dir: /dataset/coco/
dataset_dir: dataset/coco/
train_image_dir: train2017
val_image_dir: val2017
train_anno_path: annotations/instances_train2017.json
......
......@@ -20,7 +20,7 @@ from tqdm import tqdm
import paddle
from paddleslim.common import load_config as load_slim_config
from paddleslim.common import load_inference_model
from post_process import YOLOv6PostProcess, coco_metric
from post_process import YOLOPostProcess, coco_metric
from dataset import COCOValDataset
......@@ -32,6 +32,8 @@ def argsparser():
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--batch_size', type=int, default=1, help="Batch size of model input.")
parser.add_argument(
'--devices',
type=str,
......@@ -60,8 +62,7 @@ def eval():
feed={feed_target_names[0]: data_all['image']},
fetch_list=fetch_targets,
return_numpy=False)
res = {}
postprocess = YOLOv6PostProcess(
postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
bboxes_list.append(res['bbox'])
......@@ -83,7 +84,8 @@ def main():
anno_path=config['val_anno_path'])
global anno_file
anno_file = dataset.ann_file
val_loader = paddle.io.DataLoader(dataset, batch_size=1)
val_loader = paddle.io.DataLoader(
dataset, batch_size=FLAGS.batch_size, drop_last=True)
eval()
......
......@@ -70,9 +70,9 @@ def nms(boxes, scores, iou_threshold):
return keep
class YOLOv6PostProcess(object):
class YOLOPostProcess(object):
"""
Post process of YOLOv6 network.
Post process of YOLO series network.
args:
score_threshold(float): Threshold to filter out bounding boxes with low
confidence score. If not provided, consider all boxes.
......@@ -159,8 +159,8 @@ class YOLOv6PostProcess(object):
if len(pred.shape) == 1:
pred = pred[np.newaxis, :]
pred_bboxes = pred[:, :4]
scale_factor = np.tile(scale_factor[i][::-1], (1, 2))
pred_bboxes /= scale_factor
scale = np.tile(scale_factor[i][::-1], (2))
pred_bboxes /= scale
bbox = np.concatenate(
[
pred[:, -1][:, np.newaxis], pred[:, -2][:, np.newaxis],
......