未验证 提交 b747596c 编写于 作者: W whcao 提交者: GitHub

support yolov6s_v2_qat_dis (#1419)

上级 bb7c1a70
......@@ -18,23 +18,26 @@
## 2.Benchmark
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 模型体积 | 预测时延<sup><small>FP32</small><sup><br><sup> |预测时延<sup><small>FP16</small><sup><br><sup> | 预测时延<sup><small>INT8</small><sup><br><sup> | 内存占用 | 显存占用 | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :--------: | :---------------------: | :----------------: | :----------------: |:----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv5s | Base模型 | 640*640 | 37.4 | 28.1MB | 5.95ms | 2.44ms | - | 1718MB | 705MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) |
| YOLOv5s | 离线量化 | 640*640 | 36.0 | 7.4MB | - | - | 1.87ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - |
| YOLOv5s | ACT量化训练 | 640*640 | **36.9** | 7.4MB | - | - | **1.87ms** | 736MB | 315MB | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant_onnx.tar) |
| | | | | | | | | |
| YOLOv6s | Base模型 | 640*640 | 42.4 | 65.9MB | 9.06ms | 2.90ms | - | 1208MB | 555MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) |
| YOLOv6s | KL离线量化 | 640*640 | 30.3 | 16.8MB | - | - | 1.83ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - |
| YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | 16.8MB | - | - | **1.83ms** | 736MB | 315MB | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant_onnx.tar) |
| | | | | | | | | |
| YOLOv7 | Base模型 | 640*640 | 51.1 | 141MB | 26.84ms | 7.44ms | - | 1722MB | 917MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) |
| YOLOv7 | 离线量化 | 640*640 | 50.2 | 36MB | - | - | 4.55ms | 827MB | 363MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - |
| YOLOv7 | ACT量化训练 | 640*640 | **50.9** | 36MB | - | - | **4.55ms** | 827MB | 363MB | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant_onnx.tar) |
| | | | | | | | | |
| YOLOv7-Tiny | Base模型 | 640*640 | 37.3 | 24MB | 5.06ms | 2.32ms | - | 738MB | 349MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx) |
| YOLOv7-Tiny | 离线量化 | 640*640 | 35.8 | 6.1MB | - | - | 1.68ms | 729MB | 315MB | - | - |
| YOLOv7-Tiny | ACT量化训练 | 640*640 | **37.0** | 6.1MB | - | - | **1.68ms** | 729MB | 315MB | [config](./configs/yolov7_tiny_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant_onnx.tar) |
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 模型体积 | 预测时延<sup><small>FP32</small><sup><br><sup> |预测时延<sup><small>FP16</small><sup><br><sup> | 预测时延<sup><small>INT8</small><sup><br><sup> | 内存占用 | 显存占用 | 配置文件 | Inference模型 |
|:--------------|:-------- |:--------: |:-----------------------:|:------:| :----------------: | :----------------: |:----------------: | :----------------: | :---------------: |:------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| YOLOv5s | Base模型 | 640*640 | 37.4 | 28.1MB | 5.95ms | 2.44ms | - | 1718MB | 705MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) |
| YOLOv5s | 离线量化 | 640*640 | 36.0 | 7.4MB | - | - | 1.87ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - |
| YOLOv5s | ACT量化训练 | 640*640 | **36.9** | 7.4MB | - | - | **1.87ms** | 736MB | 315MB | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant_onnx.tar) |
| | | | | | | | | |
| YOLOv6s | Base模型 | 640*640 | 42.4 | 65.9MB | 9.06ms | 2.90ms | - | 1208MB | 555MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) |
| YOLOv6s | KL离线量化 | 640*640 | 30.3 | 16.8MB | - | - | 1.83ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - |
| YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | 16.8MB | - | - | **1.83ms** | 736MB | 315MB | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant_onnx.tar) |
| | | | | | | | | |
| YOLOv6s_v2 | Base模型 | 640*640 | 43.4 | 67.4MB | 9.06ms | 2.90ms | - | 1208MB | 555MB | - | [Model](https://github.com/meituan/YOLOv6/releases/download/0.2.0/yolov6s.onnx) |
| YOLOv6s_v2 | 量化蒸馏训练 | 640*640 | **43.0** | 16.8MB | - | - | **1.83ms** | 736MB | 315MB | [config](./configs/yolov6s_v2_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_v2_0_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_v2_0_quant_onnx.tar) |
| | | | | | | | | |
| YOLOv7 | Base模型 | 640*640 | 51.1 | 141MB | 26.84ms | 7.44ms | - | 1722MB | 917MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) |
| YOLOv7 | 离线量化 | 640*640 | 50.2 | 36MB | - | - | 4.55ms | 827MB | 363MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - |
| YOLOv7 | ACT量化训练 | 640*640 | **50.9** | 36MB | - | - | **4.55ms** | 827MB | 363MB | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant_onnx.tar) |
| | | | | | | | | |
| YOLOv7-Tiny | Base模型 | 640*640 | 37.3 | 24MB | 5.06ms | 2.32ms | - | 738MB | 349MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx) |
| YOLOv7-Tiny | 离线量化 | 640*640 | 35.8 | 6.1MB | - | - | 1.68ms | 729MB | 315MB | - | - |
| YOLOv7-Tiny | ACT量化训练 | 640*640 | **37.0** | 6.1MB | - | - | **1.68ms** | 729MB | 315MB | [config](./configs/yolov7_tiny_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant_onnx.tar) |
说明:
- mAP的指标均在COCO val2017数据集中评测得到。
......
Global:
model_dir: ./yolov6s.onnx
image_path: None # If image_path is set, it will be trained directly based on unlabeled images, no need to set the COCO dataset path.
coco_dataset_dir: dataset/coco/
coco_train_image_dir: train2017
coco_train_anno_path: annotations/instances_train2017.json
coco_val_image_dir: val2017
coco_val_anno_path: annotations/instances_val2017.json
arch: YOLOv6
nms_num_top_k: 1000
Distillation:
alpha: 1.0
loss: soft_label
Quantization:
onnx_format: true
activation_quantize_type: 'moving_average_abs_max'
quantize_op_types:
- conv2d
- depthwise_conv2d
TrainConfig:
train_iter: 8000
eval_iter: 1000
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.00003
T_max: 8000
optimizer_builder:
optimizer:
type: SGD
weight_decay: 0.00004
......@@ -51,6 +51,12 @@ def eval():
val_program, feed_target_names, fetch_targets = load_inference_model(
global_config["model_dir"], exe)
postprocess = YOLOPostProcess(
score_threshold=0.001,
nms_threshold=0.65,
multi_label=True,
num_top_k=global_config.get('nms_num_top_k', 30000))
bboxes_list, bbox_nums_list, image_id_list = [], [], []
with tqdm(
total=len(val_loader),
......@@ -62,8 +68,6 @@ def eval():
feed={feed_target_names[0]: data_all['image']},
fetch_list=fetch_targets,
return_numpy=False)
postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
bboxes_list.append(res['bbox'])
bbox_nums_list.append(res['bbox_num'])
......
......@@ -80,17 +80,20 @@ class YOLOPostProcess(object):
multi_label(bool): Whether keep multi label in boxes.
keep_top_k(int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step.
num_top_k(int): Maximum number of boxes put into torchvision.ops.nums()
"""
def __init__(self,
score_threshold=0.25,
nms_threshold=0.5,
multi_label=False,
keep_top_k=300):
keep_top_k=300,
num_top_k=30000):
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.multi_label = multi_label
self.keep_top_k = keep_top_k
self.num_top_k = num_top_k
def _xywh2xyxy(self, x):
# Convert from [x, y, w, h] to [x1, y1, x2, y2]
......@@ -103,7 +106,6 @@ class YOLOPostProcess(object):
def _non_max_suppression(self, prediction):
max_wh = 4096 # (pixels) minimum and maximum box width and height
nms_top_k = 30000
cand_boxes = prediction[..., 4] > self.score_threshold # candidates
output = [np.zeros((0, 6))] * prediction.shape[0]
......@@ -137,8 +139,8 @@ class YOLOPostProcess(object):
num_box = boxes.shape[0]
if not num_box:
continue
elif num_box > nms_top_k:
boxes = boxes[boxes[:, 4].argsort()[::-1][:nms_top_k]]
elif num_box > self.nms_top_k:
boxes = boxes[boxes[:, 4].argsort()[::-1][:self.nms_top_k]]
# Batched NMS
c = boxes[:, 5:6] * max_wh
......
......@@ -56,6 +56,11 @@ def reader_wrapper(reader, input_name='x2paddle_images'):
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
bboxes_list, bbox_nums_list, image_id_list = [], [], []
postprocess = YOLOPostProcess(
score_threshold=0.001,
nms_threshold=0.65,
multi_label=True,
num_top_k=global_config.get('nms_num_top_k', 30000))
with tqdm(
total=len(val_loader),
bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
......@@ -66,9 +71,6 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
feed={test_feed_names[0]: data_all['image']},
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
bboxes_list.append(res['bbox'])
bbox_nums_list.append(res['bbox_num'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册