From 407d47b33b85b55b97ad248de23f902cd764bf97 Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Tue, 16 Aug 2022 14:38:47 +0800 Subject: [PATCH] Rearrange Some docs & Add Quant Analysis API and Demo (#1342) --- example/auto_compression/detection/eval.py | 2 +- .../auto_compression/detection/post_quant.py | 2 +- example/auto_compression/detection/run.py | 5 +- .../image_classification/eval.py | 2 +- .../image_classification/infer.py | 2 +- .../image_classification/run.py | 2 +- example/auto_compression/nlp/run.py | 2 +- .../pytorch_huggingface/run.py | 2 +- .../auto_compression/pytorch_yolov5/eval.py | 2 +- .../pytorch_yolov5/post_quant.py | 3 +- .../auto_compression/pytorch_yolov5/run.py | 5 +- .../auto_compression/pytorch_yolov6/eval.py | 2 +- .../pytorch_yolov6/post_quant.py | 3 +- .../auto_compression/pytorch_yolov6/run.py | 5 +- .../auto_compression/pytorch_yolov7/eval.py | 4 +- .../pytorch_yolov7/post_quant.py | 2 +- .../auto_compression/pytorch_yolov7/run.py | 5 +- .../semantic_segmentation/run.py | 2 +- .../tensorflow_mobilenet/eval.py | 5 +- .../tensorflow_mobilenet/run.py | 5 +- example/full_quantization/detection/eval.py | 2 +- example/full_quantization/detection/run.py | 2 +- .../pytorch_yolo_series/README.md | 138 ++++++++ .../pytorch_yolo_series/analysis.py | 115 ++++++ .../configs/yolov5s_ptq.yaml | 7 + .../configs/yolov6s_analysis.yaml | 10 + .../configs/yolov6s_analyzed_ptq.yaml | 7 + .../configs/yolov6s_ptq.yaml | 7 + .../configs/yolov7s_ptq.yaml | 6 + .../pytorch_yolo_series/dataset.py | 113 ++++++ .../pytorch_yolo_series/eval.py | 99 ++++++ .../images/sensitivity_rank.png | Bin 0 -> 25389 bytes .../pytorch_yolo_series/post_process.py | 231 ++++++++++++ .../pytorch_yolo_series/post_quant.py | 93 +++++ paddleslim/auto_compression/__init__.py | 14 +- paddleslim/auto_compression/compressor.py | 6 +- paddleslim/auto_compression/config_helpers.py | 46 +-- .../create_compressed_program.py | 10 +- paddleslim/auto_compression/utils/__init__.py | 6 - paddleslim/auto_compression/utils/fake_ptq.py | 2 +- .../auto_compression/utils/load_model.py | 90 ----- paddleslim/auto_compression/utils/predict.py | 4 +- .../auto_compression/utils/prune_model.py | 7 +- paddleslim/common/__init__.py | 8 +- paddleslim/common/config_helper.py | 60 ++++ .../utils => common}/dataloader.py | 0 .../{convert_model.py => load_model.py} | 76 +++- paddleslim/quant/analysis.py | 331 ++++++++++++++++++ 48 files changed, 1360 insertions(+), 192 deletions(-) create mode 100644 example/post_training_quantization/pytorch_yolo_series/README.md create mode 100644 example/post_training_quantization/pytorch_yolo_series/analysis.py create mode 100644 example/post_training_quantization/pytorch_yolo_series/configs/yolov5s_ptq.yaml create mode 100644 example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analysis.yaml create mode 100644 example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analyzed_ptq.yaml create mode 100644 example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_ptq.yaml create mode 100644 example/post_training_quantization/pytorch_yolo_series/configs/yolov7s_ptq.yaml create mode 100644 example/post_training_quantization/pytorch_yolo_series/dataset.py create mode 100644 example/post_training_quantization/pytorch_yolo_series/eval.py create mode 100644 example/post_training_quantization/pytorch_yolo_series/images/sensitivity_rank.png create mode 100644 example/post_training_quantization/pytorch_yolo_series/post_process.py create mode 100644 
example/post_training_quantization/pytorch_yolo_series/post_quant.py delete mode 100644 paddleslim/auto_compression/utils/load_model.py create mode 100644 paddleslim/common/config_helper.py rename paddleslim/{auto_compression/utils => common}/dataloader.py (100%) rename paddleslim/common/{convert_model.py => load_model.py} (58%) create mode 100644 paddleslim/quant/analysis.py diff --git a/example/auto_compression/detection/eval.py b/example/auto_compression/detection/eval.py index d80f3cfe..fc0c09ae 100644 --- a/example/auto_compression/detection/eval.py +++ b/example/auto_compression/detection/eval.py @@ -20,7 +20,7 @@ import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from keypoint_utils import keypoint_post_process from post_process import PPYOLOEPostProcess diff --git a/example/auto_compression/detection/post_quant.py b/example/auto_compression/detection/post_quant.py index b3a70900..edc7d2fe 100644 --- a/example/auto_compression/detection/post_quant.py +++ b/example/auto_compression/detection/post_quant.py @@ -19,7 +19,7 @@ import argparse import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.quant import quant_post_static diff --git a/example/auto_compression/detection/run.py b/example/auto_compression/detection/run.py index a3c46d47..6a4838ca 100644 --- a/example/auto_compression/detection/run.py +++ b/example/auto_compression/detection/run.py @@ -20,7 +20,7 @@ import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.auto_compression import AutoCompression from keypoint_utils import keypoint_post_process from post_process import PPYOLOEPostProcess @@ -126,7 +126,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def main(): global global_config all_config = load_slim_config(FLAGS.config_path) - assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" + assert "Global" in all_config, "Key 'Global' not found in config file. 
\n{}".format( + all_config) global_config = all_config["Global"] reader_cfg = load_config(global_config['reader_config']) diff --git a/example/auto_compression/image_classification/eval.py b/example/auto_compression/image_classification/eval.py index d0e0c3d1..9cd9b4a3 100644 --- a/example/auto_compression/image_classification/eval.py +++ b/example/auto_compression/image_classification/eval.py @@ -23,7 +23,7 @@ import paddle import paddle.nn as nn from paddle.io import DataLoader from imagenet_reader import ImageNetDataset -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config def argsparser(): diff --git a/example/auto_compression/image_classification/infer.py b/example/auto_compression/image_classification/infer.py index 5060115c..46eb7115 100644 --- a/example/auto_compression/image_classification/infer.py +++ b/example/auto_compression/image_classification/infer.py @@ -22,7 +22,7 @@ import yaml from utils import preprocess, postprocess import paddle from paddle.inference import create_predictor -from paddleslim.auto_compression.config_helpers import load_config +from paddleslim.common import load_config def argsparser(): diff --git a/example/auto_compression/image_classification/run.py b/example/auto_compression/image_classification/run.py index d8da1a9f..7d660431 100644 --- a/example/auto_compression/image_classification/run.py +++ b/example/auto_compression/image_classification/run.py @@ -24,7 +24,7 @@ import paddle import paddle.nn as nn from paddle.io import DataLoader from imagenet_reader import ImageNetDataset -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.auto_compression import AutoCompression diff --git a/example/auto_compression/nlp/run.py b/example/auto_compression/nlp/run.py index 769f58ef..c70a3a34 100644 --- a/example/auto_compression/nlp/run.py +++ b/example/auto_compression/nlp/run.py @@ -15,7 +15,7 @@ from paddlenlp.datasets import load_dataset from paddlenlp.data import Stack, Tuple, Pad from paddlenlp.data.sampler import SamplerHelper from paddlenlp.metrics import Mcc, PearsonAndSpearman -from paddleslim.auto_compression.config_helpers import load_config +from paddleslim.common import load_config from paddleslim.auto_compression.compressor import AutoCompression diff --git a/example/auto_compression/pytorch_huggingface/run.py b/example/auto_compression/pytorch_huggingface/run.py index 4da4e703..0c730dff 100644 --- a/example/auto_compression/pytorch_huggingface/run.py +++ b/example/auto_compression/pytorch_huggingface/run.py @@ -27,7 +27,7 @@ from paddlenlp.transformers import AutoModelForTokenClassification, AutoTokenize from paddlenlp.datasets import load_dataset from paddlenlp.data import Stack, Tuple, Pad from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.auto_compression.compressor import AutoCompression diff --git a/example/auto_compression/pytorch_yolov5/eval.py b/example/auto_compression/pytorch_yolov5/eval.py index 42f2e121..68461c99 100644 --- a/example/auto_compression/pytorch_yolov5/eval.py +++ b/example/auto_compression/pytorch_yolov5/eval.py @@ -18,7 +18,7 @@ import numpy as np import argparse from tqdm import tqdm import paddle -from 
paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.common import load_onnx_model from post_process import YOLOv5PostProcess, coco_metric from dataset import COCOValDataset diff --git a/example/auto_compression/pytorch_yolov5/post_quant.py b/example/auto_compression/pytorch_yolov5/post_quant.py index 84db4f98..97f46741 100644 --- a/example/auto_compression/pytorch_yolov5/post_quant.py +++ b/example/auto_compression/pytorch_yolov5/post_quant.py @@ -17,11 +17,12 @@ import sys import numpy as np import argparse import paddle -from paddleslim.auto_compression.config_helpers import load_config +from paddleslim.common import load_config from paddleslim.common import load_onnx_model from paddleslim.quant import quant_post_static from dataset import COCOTrainDataset + def argsparser(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( diff --git a/example/auto_compression/pytorch_yolov5/run.py b/example/auto_compression/pytorch_yolov5/run.py index 9f505535..b1ca6bce 100644 --- a/example/auto_compression/pytorch_yolov5/run.py +++ b/example/auto_compression/pytorch_yolov5/run.py @@ -18,7 +18,7 @@ import numpy as np import argparse from tqdm import tqdm import paddle -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.auto_compression import AutoCompression from dataset import COCOValDataset, COCOTrainDataset from post_process import YOLOv5PostProcess, coco_metric @@ -75,7 +75,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def main(): global global_config all_config = load_slim_config(FLAGS.config_path) - assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" + assert "Global" in all_config, "Key 'Global' not found in config file. 
\n{}".format( + all_config) global_config = all_config["Global"] dataset = COCOTrainDataset( diff --git a/example/auto_compression/pytorch_yolov6/eval.py b/example/auto_compression/pytorch_yolov6/eval.py index 1d28466a..038a1f8b 100644 --- a/example/auto_compression/pytorch_yolov6/eval.py +++ b/example/auto_compression/pytorch_yolov6/eval.py @@ -18,7 +18,7 @@ import numpy as np import argparse from tqdm import tqdm import paddle -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.common import load_onnx_model from post_process import YOLOv6PostProcess, coco_metric from dataset import COCOValDataset diff --git a/example/auto_compression/pytorch_yolov6/post_quant.py b/example/auto_compression/pytorch_yolov6/post_quant.py index 84db4f98..97f46741 100644 --- a/example/auto_compression/pytorch_yolov6/post_quant.py +++ b/example/auto_compression/pytorch_yolov6/post_quant.py @@ -17,11 +17,12 @@ import sys import numpy as np import argparse import paddle -from paddleslim.auto_compression.config_helpers import load_config +from paddleslim.common import load_config from paddleslim.common import load_onnx_model from paddleslim.quant import quant_post_static from dataset import COCOTrainDataset + def argsparser(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( diff --git a/example/auto_compression/pytorch_yolov6/run.py b/example/auto_compression/pytorch_yolov6/run.py index 7e28e1f6..8676e7b3 100644 --- a/example/auto_compression/pytorch_yolov6/run.py +++ b/example/auto_compression/pytorch_yolov6/run.py @@ -18,7 +18,7 @@ import numpy as np import argparse from tqdm import tqdm import paddle -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.auto_compression import AutoCompression from dataset import COCOValDataset, COCOTrainDataset from post_process import YOLOv6PostProcess, coco_metric @@ -75,7 +75,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def main(): global global_config all_config = load_slim_config(FLAGS.config_path) - assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" + assert "Global" in all_config, "Key 'Global' not found in config file. 
\n{}".format( + all_config) global_config = all_config["Global"] dataset = COCOTrainDataset( diff --git a/example/auto_compression/pytorch_yolov7/eval.py b/example/auto_compression/pytorch_yolov7/eval.py index 451301f1..f758f8f9 100644 --- a/example/auto_compression/pytorch_yolov7/eval.py +++ b/example/auto_compression/pytorch_yolov7/eval.py @@ -18,8 +18,8 @@ import numpy as np import argparse from tqdm import tqdm import paddle -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config -from paddleslim.auto_compression.utils import load_inference_model +from paddleslim.common import load_config as load_slim_config +from paddleslim.common import load_inference_model from post_process import YOLOv7PostProcess, coco_metric from dataset import COCOValDataset diff --git a/example/auto_compression/pytorch_yolov7/post_quant.py b/example/auto_compression/pytorch_yolov7/post_quant.py index a253e671..97f46741 100644 --- a/example/auto_compression/pytorch_yolov7/post_quant.py +++ b/example/auto_compression/pytorch_yolov7/post_quant.py @@ -17,7 +17,7 @@ import sys import numpy as np import argparse import paddle -from paddleslim.auto_compression.config_helpers import load_config +from paddleslim.common import load_config from paddleslim.common import load_onnx_model from paddleslim.quant import quant_post_static from dataset import COCOTrainDataset diff --git a/example/auto_compression/pytorch_yolov7/run.py b/example/auto_compression/pytorch_yolov7/run.py index b3df9639..f6ab7533 100644 --- a/example/auto_compression/pytorch_yolov7/run.py +++ b/example/auto_compression/pytorch_yolov7/run.py @@ -18,7 +18,7 @@ import numpy as np import argparse from tqdm import tqdm import paddle -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.auto_compression import AutoCompression from dataset import COCOValDataset, COCOTrainDataset from post_process import YOLOv7PostProcess, coco_metric @@ -75,7 +75,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def main(): global global_config all_config = load_slim_config(FLAGS.config_path) - assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" + assert "Global" in all_config, "Key 'Global' not found in config file. 
\n{}".format( + all_config) global_config = all_config["Global"] dataset = COCOTrainDataset( diff --git a/example/auto_compression/semantic_segmentation/run.py b/example/auto_compression/semantic_segmentation/run.py index 4f4d4c56..6bc7d752 100644 --- a/example/auto_compression/semantic_segmentation/run.py +++ b/example/auto_compression/semantic_segmentation/run.py @@ -21,7 +21,7 @@ from paddleseg.cvlibs import Config as PaddleSegDataConfig from paddleseg.utils import worker_init_fn from paddleslim.auto_compression import AutoCompression -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleseg.core.infer import reverse_transform from paddleseg.utils import metrics diff --git a/example/auto_compression/tensorflow_mobilenet/eval.py b/example/auto_compression/tensorflow_mobilenet/eval.py index 85e5fdaf..bf0987e3 100644 --- a/example/auto_compression/tensorflow_mobilenet/eval.py +++ b/example/auto_compression/tensorflow_mobilenet/eval.py @@ -23,7 +23,7 @@ import paddle import paddle.nn as nn from paddle.io import DataLoader from imagenet_reader import ImageNetDataset -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config def argsparser(): @@ -93,7 +93,8 @@ def eval(): def main(): global global_config all_config = load_slim_config(args.config_path) - assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" + assert "Global" in all_config, "Key 'Global' not found in config file. \n{}".format( + all_config) global_config = all_config["Global"] global data_dir data_dir = global_config['data_dir'] diff --git a/example/auto_compression/tensorflow_mobilenet/run.py b/example/auto_compression/tensorflow_mobilenet/run.py index 86345ec2..aefd2941 100644 --- a/example/auto_compression/tensorflow_mobilenet/run.py +++ b/example/auto_compression/tensorflow_mobilenet/run.py @@ -23,7 +23,7 @@ import paddle import paddle.nn as nn from paddle.io import DataLoader from imagenet_reader import ImageNetDataset -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.auto_compression import AutoCompression @@ -107,7 +107,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def main(): global global_config all_config = load_slim_config(args.config_path) - assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" + assert "Global" in all_config, "Key 'Global' not found in config file. 
\n{}".format( + all_config) global_config = all_config["Global"] global data_dir data_dir = global_config['data_dir'] diff --git a/example/full_quantization/detection/eval.py b/example/full_quantization/detection/eval.py index 81a169d1..d6c7d49d 100644 --- a/example/full_quantization/detection/eval.py +++ b/example/full_quantization/detection/eval.py @@ -20,7 +20,7 @@ import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config def argsparser(): diff --git a/example/full_quantization/detection/run.py b/example/full_quantization/detection/run.py index aca12b26..fb0b9ad0 100644 --- a/example/full_quantization/detection/run.py +++ b/example/full_quantization/detection/run.py @@ -20,7 +20,7 @@ import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval -from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.common import load_config as load_slim_config from paddleslim.auto_compression import AutoCompression diff --git a/example/post_training_quantization/pytorch_yolo_series/README.md b/example/post_training_quantization/pytorch_yolo_series/README.md new file mode 100644 index 00000000..c83bf76a --- /dev/null +++ b/example/post_training_quantization/pytorch_yolo_series/README.md @@ -0,0 +1,138 @@ +# YOLO系列离线量化示例 + +目录: +- [1.简介](#1简介) +- [2.Benchmark](#2Benchmark) +- [3.开始自动压缩](#离线量化流程) + - [3.1 准备环境](#31-准备环境) + - [3.2 准备数据集](#32-准备数据集) + - [3.3 准备预测模型](#33-准备预测模型) + - [3.4 离线量化并产出模型](#34-离线量化并产出模型) + - [3.5 测试模型精度](#35-测试模型精度) + - [3.6 提高离线量化精度](#36-提高离线量化精度) +- [4.预测部署](#4预测部署) +- [5.FAQ](5FAQ) + + +本示例将以[ultralytics/yolov5](https://github.com/ultralytics/yolov5),[meituan/YOLOv6](https://github.com/meituan/YOLOv6) 和 [WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7) YOLO系列目标检测模型为例,将PyTorch框架产出的推理模型转换为Paddle推理模型,使用离线量化功能进行压缩,并使用敏感度分析功能提升离线量化精度。离线量化产出的模型可以用PaddleInference部署,也可以导出为ONNX格式模型文件,并用TensorRT部署。 + + +## 2.Benchmark +| 模型 | 策略 | 输入尺寸 | mAPval
0.5:0.95 | 预测时延FP32
(ms) |预测时延FP16
(ms) | 预测时延INT8
(ms) | 配置文件 | Inference模型 | +| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | +| YOLOv5s | Base模型 | 640*640 | 37.4 | 5.95ms | 2.44ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) | +| YOLOv5s | KL离线量化 | 640*640 | 36.0 | - | - | 1.87ms | - | - | +| | | | | | | | | | +| YOLOv6s | Base模型 | 640*640 | 42.4 | 9.06ms | 2.90ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) | +| YOLOv6s | KL离线量化(量化分析前) | 640*640 | 30.3 | - | - | 1.83ms | - | - | +| YOLOv6s | KL离线量化(量化分析后) | 640*640 | 39.7 | - | - | - | - | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_analyzed_ptq.tar) | +| | | | | | | | | | +| YOLOv7 | Base模型 | 640*640 | 51.1 | 26.84ms | 7.44ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) | +| YOLOv7 | KL离线量化 | 640*640 | 50.2 | - | - | 4.55ms | - | - | + +说明: +- mAP的指标均在COCO val2017数据集中评测得到。 + +## 3. 离线量化流程 + +#### 3.1 准备环境 +- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) +- PaddleSlim > 2.3版本 +- opencv-python + +(1)安装paddlepaddle: +```shell +# CPU +pip install paddlepaddle +# GPU +pip install paddlepaddle-gpu +``` + +(2)安装paddleslim: +```shell +pip install paddleslim +``` + +#### 3.2 准备数据集 +本示例默认以COCO数据进行自动压缩实验,可以从[MS COCO官网](https://cocodataset.org)下载[Train](http://images.cocodataset.org/zips/train2017.zip)、[Val](http://images.cocodataset.org/zips/val2017.zip)、[annotation](http://images.cocodataset.org/annotations/annotations_trainval2017.zip)。 + +目录格式如下: +``` +dataset/coco/ +├── annotations +│ ├── instances_train2017.json +│ ├── instances_val2017.json +│ | ... +├── train2017 +│ ├── 000000000009.jpg +│ ├── 000000580008.jpg +│ | ... 
+├── val2017 +│ ├── 000000000139.jpg +│ ├── 000000000285.jpg +``` + +#### 3.3 准备预测模型 +(1)准备ONNX模型: + +**yolov5**:可通过[ultralytics/yolov5](https://github.com/ultralytics/yolov5) 官方的[导出教程](https://github.com/ultralytics/yolov5/issues/251)来准备ONNX模型。也可以下载准备好的[yolov5s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx)。 + +**yolov6**:可通过[WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7)的导出脚本来准备ONNX模型。也可以直接下载我们已经准备好的[yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx)。 + +**yolov7**:可通过[meituan/YOLOv6](https://github.com/meituan/YOLOv6)官方的[导出教程](https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md)来准备ONNX模型。也可以下载已经准备好的[yolov6s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx)。 + + +#### 3.4 离线量化并产出模型 +离线量化示例通过post_quant.py脚本启动,会使用接口```paddleslim.quant.quant_post_static```对模型进行量化。配置config文件中模型路径、数据路径和量化相关的参数,配置完成后便可对模型进行离线量化。具体运行命令为: +- yolov5 + +```shell +python post_quant.py --config_path=./configs/yolov5s_ptq.yaml --save_dir=./yolov5s_ptq_out +``` + +- yolov6 + +```shell +python post_quant.py --config_path=./configs/yolov6s_ptq.yaml --save_dir=./yolov6s_ptq_out +``` + +- yolov7 + +```shell +python post_quant.py --config_path=./configs/yolov7s_ptq.yaml --save_dir=./yolov7s_ptq_out +``` + + +#### 3.5 测试模型精度 + +修改[yolov5s_ptq.yaml](./configs/yolov5s_ptq.yaml)中`model_dir`字段为模型存储路径,然后使用eval.py脚本得到模型的mAP: +```shell +export CUDA_VISIBLE_DEVICES=0 +python eval.py --config_path=./configs/yolov5s_ptq.yaml +``` + + +#### 3.6 提高离线量化精度 +本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisQuant```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。由于yolov6离线量化效果较差,以yolov6为例,量化分析工具具体使用方法如下: + +```shell +python analysis.py --config_path=./configs/yolov6s_analysis.yaml +``` + +经过分析之后,会产出模型每一层量化后的精度,和较差层的weight和activation的分布图。在进行离线量化时,可以跳过这些导致精度下降较多的层,如yolov6中,经过分析后,可跳过`conv2d_2.w_0`, `conv2d_11.w_0`,`conv2d_15.w_0`, `conv2d_46.w_0`, `conv2d_49.w_0`,可使用[yolov6s_analyzed_ptq.yaml](./configs/yolov6s_analyzed_ptq.yaml),然后再次进行离线量化。跳过这五层后,离线量化精度上升9.4个点。 + +```shell +python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_dir=./yolov6s_analyzed_ptq_out +``` + +注: +- 分析后,每层量化的精度会默认保存在`./analysis_results/analysis.txt`,直方分布图会默认保存在`./analysis_results/act_hist_result.pdf`和 `./analysis_results/weight_hist_result.pdf`中。 + +
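+For reference, the snippet below condenses what `analysis.py` in this example wires together: the ONNX model is converted to a Paddle inference model, and an mAP evaluation function plus a calibration `DataLoader` are handed to `paddleslim.quant.AnalysisQuant`. It is a sketch rather than a standalone script: `eval_function` stands for the COCO mAP evaluator defined in analysis.py.
+
+```python
+import os
+import paddle
+from paddleslim.common import load_config, load_onnx_model
+from paddleslim.quant.analysis import AnalysisQuant
+from dataset import COCOTrainDataset
+
+paddle.enable_static()
+config = load_config('./configs/yolov6s_analysis.yaml')
+
+# Convert the ONNX model; load_onnx_model saves a Paddle inference
+# model under the '<model_name>_infer' directory, as analysis.py does.
+load_onnx_model(config['model_dir'])
+infer_model_dir = os.path.splitext(config['model_dir'])[0] + '_infer'
+
+# Calibration data: the COCO val images, wrapped exactly as in analysis.py.
+dataset = COCOTrainDataset(
+    dataset_dir=config['dataset_dir'],
+    image_dir=config['val_image_dir'],
+    anno_path=config['val_anno_path'])
+data_loader = paddle.io.DataLoader(dataset, batch_size=1, shuffle=True)
+
+analyzer = AnalysisQuant(
+    model_dir=infer_model_dir,
+    model_filename='model.pdmodel',
+    params_filename='model.pdiparams',
+    eval_function=eval_function,  # COCO mAP evaluator, defined as in analysis.py
+    data_loader=data_loader,
+    quantizable_op_type=config['quantizable_op_type'],
+    weight_quantize_type=config['weight_quantize_type'],
+    activation_quantize_type=config['activation_quantize_type'],
+    is_full_quantize=config['is_full_quantize'],
+    batch_size=config['batch_size'],
+    save_dir=config['save_dir'])
+analyzer.analysis()  # writes analysis.txt and the histogram PDFs
+```
+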

+<p align="center">
+  <img src="./images/sensitivity_rank.png" />
+</p>
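+
+To fold the analysis result back into quantization programmatically, the skip list can be passed to ```paddleslim.quant.quant_post_static```, mirroring what post_quant.py does. This sketch assumes `quant_post_static` accepts a `skip_tensor_list` argument corresponding to the `skip_tensor_list` key in [yolov6s_analyzed_ptq.yaml](./configs/yolov6s_analyzed_ptq.yaml); check the installed PaddleSlim version before relying on it. `calib_loader` stands for a calibration `DataLoader` built as in post_quant.py.
+
+```python
+import os
+import paddle
+from paddleslim.common import load_config, load_onnx_model
+from paddleslim.quant import quant_post_static
+
+paddle.enable_static()
+config = load_config('./configs/yolov6s_analyzed_ptq.yaml')
+
+# Convert the ONNX model to a Paddle inference model first.
+load_onnx_model(config['model_dir'])
+
+exe = paddle.static.Executor(paddle.CUDAPlace(0))
+quant_post_static(
+    executor=exe,
+    model_dir=os.path.splitext(config['model_dir'])[0] + '_infer',
+    quantize_model_path='./yolov6s_analyzed_ptq_out',
+    data_loader=calib_loader,  # calibration DataLoader, built as in post_quant.py
+    model_filename='model.pdmodel',
+    params_filename='model.pdiparams',
+    skip_tensor_list=config['skip_tensor_list'],  # assumed: layers flagged by the analysis
+    algo='KL',
+    onnx_format=True)
+```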

+
+
+## 4. Deployment
+
+## 5. FAQ
diff --git a/example/post_training_quantization/pytorch_yolo_series/analysis.py b/example/post_training_quantization/pytorch_yolo_series/analysis.py
new file mode 100644
index 00000000..118a7227
--- /dev/null
+++ b/example/post_training_quantization/pytorch_yolo_series/analysis.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import numpy as np
+import argparse
+import paddle
+from tqdm import tqdm
+from post_process import YOLOv6PostProcess, coco_metric
+from dataset import COCOValDataset, COCOTrainDataset
+from paddleslim.common import load_config, load_onnx_model
+from paddleslim.quant.analysis import AnalysisQuant
+
+
+def argsparser():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        default=None,
+        help="path of analysis config.",
+        required=True)
+    parser.add_argument(
+        '--devices',
+        type=str,
+        default='gpu',
+        help="which device used to compress.")
+    return parser
+
+
+def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
+    bboxes_list, bbox_nums_list, image_id_list = [], [], []
+    with tqdm(
+            total=len(val_loader),
+            bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
+            ncols=80) as t:
+        for data in val_loader:
+            data_all = {k: np.array(v) for k, v in data.items()}
+            outs = exe.run(compiled_test_program,
+                           feed={test_feed_names[0]: data_all['image']},
+                           fetch_list=test_fetch_list,
+                           return_numpy=False)
+            res = {}
+            postprocess = YOLOv6PostProcess(
+                score_threshold=0.001, nms_threshold=0.65, multi_label=True)
+            res = postprocess(np.array(outs[0]), data_all['scale_factor'])
+            bboxes_list.append(res['bbox'])
+            bbox_nums_list.append(res['bbox_num'])
+            image_id_list.append(np.array(data_all['im_id']))
+            t.update()
+    map_res = coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list)
+    return map_res[0]
+
+
+def main():
+
+    global config
+    config = load_config(FLAGS.config_path)
+
+    dataset = COCOTrainDataset(
+        dataset_dir=config['dataset_dir'],
+        image_dir=config['val_image_dir'],
+        anno_path=config['val_anno_path'])
+    data_loader = paddle.io.DataLoader(
+        dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
+
+    global val_loader
+    dataset = COCOValDataset(
+        dataset_dir=config['dataset_dir'],
+        image_dir=config['val_image_dir'],
+        anno_path=config['val_anno_path'])
+    global anno_file
+    anno_file = dataset.ann_file
+    val_loader = paddle.io.DataLoader(
+        dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=0)
+
+    load_onnx_model(config["model_dir"])
+    inference_model_path = config["model_dir"].rstrip().rstrip(
+        '.onnx') + '_infer'
+    analyzer = AnalysisQuant(
+        model_dir=inference_model_path,
+        model_filename='model.pdmodel',
+        params_filename='model.pdiparams',
+        eval_function=eval_function,
+        quantizable_op_type=config['quantizable_op_type'],
+        weight_quantize_type=config['weight_quantize_type'],
+        activation_quantize_type=config['activation_quantize_type'],
+        is_full_quantize=config['is_full_quantize'],
+        data_loader=data_loader,
+        batch_size=config['batch_size'],
+        save_dir=config['save_dir'], )
+    analyzer.analysis()
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    parser = argsparser()
+    FLAGS = parser.parse_args()
+
+    assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
+    paddle.set_device(FLAGS.devices)
+
+    main()
diff --git a/example/post_training_quantization/pytorch_yolo_series/configs/yolov5s_ptq.yaml b/example/post_training_quantization/pytorch_yolo_series/configs/yolov5s_ptq.yaml
new file mode 100644
index 00000000..8fb861ec
--- /dev/null
+++ b/example/post_training_quantization/pytorch_yolo_series/configs/yolov5s_ptq.yaml
@@ -0,0 +1,7 @@
+model_dir: ./yolov5s.onnx
+dataset_dir: /dataset/coco/
+train_image_dir: train2017
+val_image_dir: val2017
+train_anno_path: annotations/instances_train2017.json
+val_anno_path: annotations/instances_val2017.json
+skip_tensor_list: None # you can set it after analysis
diff --git a/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analysis.yaml b/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analysis.yaml
new file mode 100644
index 00000000..6d1e726b
--- /dev/null
+++ b/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analysis.yaml
@@ -0,0 +1,10 @@
+model_dir: ./yolov6s.onnx
+save_dir: ./analysis_results
+quantizable_op_type: ["conv2d", "depthwise_conv2d"]
+weight_quantize_type: 'channel_wise_abs_max'
+activation_quantize_type: 'moving_average_abs_max'
+is_full_quantize: False
+dataset_dir: /dataset/coco/
+val_image_dir: val2017
+val_anno_path: annotations/instances_val2017.json
+batch_size: 10
diff --git a/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analyzed_ptq.yaml b/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analyzed_ptq.yaml
new file mode 100644
index 00000000..c8570328
--- /dev/null
+++ b/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analyzed_ptq.yaml
@@ -0,0 +1,7 @@
+model_dir: ./yolov6s.onnx
+dataset_dir: /dataset/coco/
+train_image_dir: train2017
+val_image_dir: val2017
+train_anno_path: annotations/instances_train2017.json
+val_anno_path: annotations/instances_val2017.json
+skip_tensor_list: ['conv2d_2.w_0', 'conv2d_15.w_0', 'conv2d_46.w_0', 'conv2d_11.w_0', 'conv2d_49.w_0']
diff --git a/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_ptq.yaml b/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_ptq.yaml
new file mode 100644
index 00000000..ab67a9df
--- /dev/null
+++ b/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_ptq.yaml
@@ -0,0 +1,7 @@
+model_dir: ./yolov6s.onnx
+dataset_dir: /dataset/coco/
+train_image_dir: train2017
+val_image_dir: val2017
+train_anno_path: annotations/instances_train2017.json
+val_anno_path: annotations/instances_val2017.json
+skip_tensor_list: None
diff --git a/example/post_training_quantization/pytorch_yolo_series/configs/yolov7s_ptq.yaml b/example/post_training_quantization/pytorch_yolo_series/configs/yolov7s_ptq.yaml
new file mode 100644
index 00000000..0ad89a20
--- /dev/null
+++ b/example/post_training_quantization/pytorch_yolo_series/configs/yolov7s_ptq.yaml
@@ -0,0 +1,6 @@
+model_dir: ./yolov7s.onnx
+dataset_dir: /dataset/coco/
+train_image_dir: train2017
+val_image_dir: val2017
+train_anno_path: annotations/instances_train2017.json
+val_anno_path:
annotations/instances_val2017.json diff --git a/example/post_training_quantization/pytorch_yolo_series/dataset.py b/example/post_training_quantization/pytorch_yolo_series/dataset.py new file mode 100644 index 00000000..7ddec29d --- /dev/null +++ b/example/post_training_quantization/pytorch_yolo_series/dataset.py @@ -0,0 +1,113 @@ +from pycocotools.coco import COCO +import cv2 +import os +import numpy as np +import paddle + + +class COCOValDataset(paddle.io.Dataset): + def __init__(self, + dataset_dir=None, + image_dir=None, + anno_path=None, + img_size=[640, 640]): + self.dataset_dir = dataset_dir + self.image_dir = image_dir + self.img_size = img_size + self.ann_file = os.path.join(dataset_dir, anno_path) + self.coco = COCO(self.ann_file) + ori_ids = list(sorted(self.coco.imgs.keys())) + # check gt bbox + clean_ids = [] + for idx in ori_ids: + ins_anno_ids = self.coco.getAnnIds(imgIds=[idx], iscrowd=False) + instances = self.coco.loadAnns(ins_anno_ids) + num_bbox = 0 + for inst in instances: + if inst.get('ignore', False): + continue + if 'bbox' not in inst.keys(): + continue + elif not any(np.array(inst['bbox'])): + continue + else: + num_bbox += 1 + if num_bbox > 0: + clean_ids.append(idx) + self.ids = clean_ids + + def __getitem__(self, idx): + img_id = self.ids[idx] + img = self._get_img_data_from_img_id(img_id) + img, scale_factor = self.image_preprocess(img, self.img_size) + return { + 'image': img, + 'im_id': np.array([img_id]), + 'scale_factor': scale_factor + } + + def __len__(self): + return len(self.ids) + + def _get_img_data_from_img_id(self, img_id): + img_info = self.coco.loadImgs(img_id)[0] + img_path = os.path.join(self.dataset_dir, self.image_dir, + img_info['file_name']) + img = cv2.imread(img_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + def _generate_scale(self, im, target_shape, keep_ratio=True): + """ + Args: + im (np.ndarray): image (np.ndarray) + Returns: + im_scale_x: the resize ratio of X + im_scale_y: the resize ratio of Y + """ + origin_shape = im.shape[:2] + if keep_ratio: + im_size_min = np.min(origin_shape) + im_size_max = np.max(origin_shape) + target_size_min = np.min(target_shape) + target_size_max = np.max(target_shape) + im_scale = float(target_size_min) / float(im_size_min) + if np.round(im_scale * im_size_max) > target_size_max: + im_scale = float(target_size_max) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + else: + resize_h, resize_w = target_shape + im_scale_y = resize_h / float(origin_shape[0]) + im_scale_x = resize_w / float(origin_shape[1]) + return im_scale_y, im_scale_x + + def image_preprocess(self, img, target_shape): + # Resize image + im_scale_y, im_scale_x = self._generate_scale(img, target_shape) + img = cv2.resize( + img, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=cv2.INTER_LINEAR) + # Pad + im_h, im_w = img.shape[:2] + h, w = target_shape[:] + if h != im_h or w != im_w: + canvas = np.ones((h, w, 3), dtype=np.float32) + canvas *= np.array([114.0, 114.0, 114.0], dtype=np.float32) + canvas[0:im_h, 0:im_w, :] = img.astype(np.float32) + img = canvas + img = np.transpose(img / 255, [2, 0, 1]) + scale_factor = np.array([im_scale_y, im_scale_x]) + return img.astype(np.float32), scale_factor + + +class COCOTrainDataset(COCOValDataset): + def __getitem__(self, idx): + img_id = self.ids[idx] + img = self._get_img_data_from_img_id(img_id) + img, scale_factor = self.image_preprocess(img, self.img_size) + return {'x2paddle_image_arrays': img} diff --git 
a/example/post_training_quantization/pytorch_yolo_series/eval.py b/example/post_training_quantization/pytorch_yolo_series/eval.py new file mode 100644 index 00000000..6705104b --- /dev/null +++ b/example/post_training_quantization/pytorch_yolo_series/eval.py @@ -0,0 +1,99 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import numpy as np +import argparse +from tqdm import tqdm +import paddle +from paddleslim.common import load_config as load_slim_config +from paddleslim.common import load_inference_model +from post_process import YOLOv6PostProcess, coco_metric +from dataset import COCOValDataset + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--devices', + type=str, + default='gpu', + help="which device used to compress.") + + return parser + + +def eval(): + + place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() + exe = paddle.static.Executor(place) + + val_program, feed_target_names, fetch_targets = load_inference_model( + config["model_dir"], exe, "model.pdmodel", "model.pdiparams") + + bboxes_list, bbox_nums_list, image_id_list = [], [], [] + with tqdm( + total=len(val_loader), + bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for data in val_loader: + data_all = {k: np.array(v) for k, v in data.items()} + outs = exe.run(val_program, + feed={feed_target_names[0]: data_all['image']}, + fetch_list=fetch_targets, + return_numpy=False) + res = {} + postprocess = YOLOv6PostProcess( + score_threshold=0.001, nms_threshold=0.65, multi_label=True) + res = postprocess(np.array(outs[0]), data_all['scale_factor']) + bboxes_list.append(res['bbox']) + bbox_nums_list.append(res['bbox_num']) + image_id_list.append(np.array(data_all['im_id'])) + t.update() + + coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list) + + +def main(): + global config + config = load_slim_config(FLAGS.config_path) + + global val_loader + dataset = COCOValDataset( + dataset_dir=config['dataset_dir'], + image_dir=config['val_image_dir'], + anno_path=config['val_anno_path']) + global anno_file + anno_file = dataset.ann_file + val_loader = paddle.io.DataLoader(dataset, batch_size=1) + + eval() + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + + assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu'] + paddle.set_device(FLAGS.devices) + + main() diff --git a/example/post_training_quantization/pytorch_yolo_series/images/sensitivity_rank.png b/example/post_training_quantization/pytorch_yolo_series/images/sensitivity_rank.png new file mode 100644 index 0000000000000000000000000000000000000000..1eab297a1118f48e56d4ef496fcf5d18948016eb GIT binary patch literal 25389 
[... base85 binary patch data for images/sensitivity_rank.png omitted ...]

diff --git a/example/post_training_quantization/pytorch_yolo_series/post_process.py b/example/post_training_quantization/pytorch_yolo_series/post_process.py
new file mode 100644
--- /dev/null
+++ b/example/post_training_quantization/pytorch_yolo_series/post_process.py
@@ -0,0 +1,231 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import json
+import sys
+
+
+def box_area(boxes):
+    # Area of each box in an Nx4 array of [x1, y1, x2, y2] boxes.
+    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+
+def box_iou(box1, box2):
+    # Pairwise IoU between two sets of [x1, y1, x2, y2] boxes.
+    area1 = box_area(box1)
+    area2 = box_area(box2)
+    lt = np.maximum(box1[:, np.newaxis, :2], box2[:, :2])
+    rb = np.minimum(box1[:, np.newaxis, 2:], box2[:, 2:])
+    wh = np.maximum(rb - lt, 0)
+    inter = wh[:, :, 0] * wh[:, :, 1]
+    return inter / (area1[:, np.newaxis] + area2 - inter)
+
+
+def nms(boxes, scores, iou_threshold):
+    # Greedy hard NMS: keep the highest-scoring box, drop boxes whose IoU
+    # with it exceeds iou_threshold, then repeat on the remainder.
+    idxs = scores.argsort()
+    keep = []
+    while idxs.size > 0:
+        max_score_index = idxs[-1]
+        max_score_box = boxes[max_score_index][None, :]
+        keep.append(max_score_index)
+        if idxs.size == 1:
+            break
+        idxs = idxs[:-1]
+        other_boxes = boxes[idxs]
+        ious = box_iou(max_score_box, other_boxes)
+        idxs = idxs[ious[0] <= iou_threshold]
+
+    keep = np.array(keep)
+    return keep
+
+
+class YOLOv6PostProcess(object):
+    """
+    Post process of YOLOv6 network.
+    args:
+        score_threshold(float): Threshold to filter out bounding boxes with low
+            confidence score. If not provided, consider all boxes.
+        nms_threshold(float): The threshold to be used in NMS.
+        multi_label(bool): Whether keep multi label in boxes.
+        keep_top_k(int): Number of total bboxes to be kept per image after NMS
+            step. -1 means keeping all bboxes after NMS step.
+    """
+
+    def __init__(self,
+                 score_threshold=0.25,
+                 nms_threshold=0.5,
+                 multi_label=False,
+                 keep_top_k=300):
+        self.score_threshold = score_threshold
+        self.nms_threshold = nms_threshold
+        self.multi_label = multi_label
+        self.keep_top_k = keep_top_k
+
+    def _xywh2xyxy(self, x):
+        # Convert from [x, y, w, h] to [x1, y1, x2, y2]
+        y = np.copy(x)
+        y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+        y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+        y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+        y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+        return y
+
+    def _non_max_suppression(self, prediction):
+        max_wh = 4096  # (pixels) minimum and maximum box width and height
+        nms_top_k = 30000
+
+        cand_boxes = prediction[..., 4] > self.score_threshold  # candidates
+        output = [np.zeros((0, 6))] * prediction.shape[0]
+
+        for batch_id, boxes in enumerate(prediction):
+            # Apply constraints
+            boxes = boxes[cand_boxes[batch_id]]
+            if not boxes.shape[0]:
+                continue
+            # Compute conf (conf = obj_conf * cls_conf)
+            boxes[:, 5:] *= boxes[:, 4:5]
+
+            # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+            convert_box = self._xywh2xyxy(boxes[:, :4])
+
+            # Detections matrix nx6 (xyxy, conf, cls)
+            if self.multi_label:
+                i, j = (boxes[:, 5:] > self.score_threshold).nonzero()
+                boxes = np.concatenate(
+                    (convert_box[i], boxes[i, j + 5, None],
+                     j[:, None].astype(np.float32)),
+                    axis=1)
+            else:
+                conf = np.max(boxes[:, 5:], axis=1)
+                j = np.argmax(boxes[:, 5:], axis=1)
+                re = np.array(conf.reshape(-1) > self.score_threshold)
+                conf = conf.reshape(-1, 1)
+                j = j.reshape(-1, 1)
+                boxes = np.concatenate((convert_box, conf, j), axis=1)[re]
+
+            num_box = boxes.shape[0]
+            if not num_box:
+                continue
+            elif num_box > nms_top_k:
+                boxes = boxes[boxes[:, 4].argsort()[::-1][:nms_top_k]]
+
+            # Batched NMS
+            c = boxes[:, 5:6] * max_wh
+            clean_boxes, scores = boxes[:, :4] + c, boxes[:, 4]
+            keep = nms(clean_boxes, scores, self.nms_threshold)
+            # limit detection box num
+            if keep.shape[0] > self.keep_top_k:
+                keep = keep[:self.keep_top_k]
+            output[batch_id] = boxes[keep]
+        return output
+
+    def __call__(self, outs, scale_factor):
+        preds = self._non_max_suppression(outs)
+        bboxs, box_nums = [], []
+        for i, pred in enumerate(preds):
+            if len(pred.shape) > 2:
+                pred = np.squeeze(pred)
+            
+            if len(pred.shape) == 1:
+                pred = pred[np.newaxis, :]
+            pred_bboxes = pred[:, :4]
+            # Use a local name here: assigning back to `scale_factor` would
+            # overwrite the argument and break later loop iterations.
+            scale = np.tile(scale_factor[i][::-1], (1, 2))
+            pred_bboxes /= scale
+            bbox = np.concatenate(
+                [
+                    pred[:, -1][:, np.newaxis], pred[:, -2][:, np.newaxis],
+                    pred_bboxes
+                ],
+                axis=-1)
+            bboxs.append(bbox)
+            box_num = bbox.shape[0]
+            box_nums.append(box_num)
+        bboxs = np.concatenate(bboxs, axis=0)
+        box_nums = np.array(box_nums)
+        return {'bbox': bboxs, 'bbox_num': box_nums}
+
+
+def coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list):
+    try:
+        from pycocotools.coco import COCO
+        from pycocotools.cocoeval import COCOeval
+    except ImportError:
+        print(
+            "[ERROR] pycocotools not found, please install it via `pip install pycocotools`"
+        )
+        sys.exit(1)
+
+    coco_gt = COCO(anno_file)
+    cats = coco_gt.loadCats(coco_gt.getCatIds())
+    clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
+    results = []
+    for bboxes, bbox_nums, image_id in zip(bboxes_list, bbox_nums_list,
+                                           image_id_list):
+        results += _get_det_res(bboxes, bbox_nums, image_id, clsid2catid)
+
+    output = "bbox.json"
+    with open(output, 'w') as f:
+        json.dump(results, f)
+
+    coco_dt = coco_gt.loadRes(output)
+    coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
+    coco_eval.evaluate()
+    coco_eval.accumulate()
+    coco_eval.summarize()
+    return coco_eval.stats
+
+
+def _get_det_res(bboxes, bbox_nums, image_id, label_to_cat_id_map):
+    det_res = []
+    k = 0
+    for i in range(len(bbox_nums)):
+        cur_image_id = int(image_id[i][0])
+        det_nums = bbox_nums[i]
+        for j in range(det_nums):
+            dt = bboxes[k]
+            k = k + 1
+            num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
+            if int(num_id) < 0:
+                continue
+            category_id = label_to_cat_id_map[int(num_id)]
+            w = xmax - xmin
+            h = ymax - ymin
+            bbox = [xmin, ymin, w, h]
+            dt_res = {
+                'image_id': cur_image_id,
+                'category_id': category_id,
+                'bbox': bbox,
+                'score': score
+            }
+            det_res.append(dt_res)
+    return det_res
diff --git a/example/post_training_quantization/pytorch_yolo_series/post_quant.py b/example/post_training_quantization/pytorch_yolo_series/post_quant.py
new file mode 100644
index 00000000..cac752a9
--- /dev/null
+++ b/example/post_training_quantization/pytorch_yolo_series/post_quant.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
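+
+# A minimal usage sketch for this script (the flag values below are
+# illustrative assumptions, not mandated by this patch):
+#
+#   python post_quant.py \
+#       --config_path=./configs/yolov6s_ptq.yaml \
+#       --save_dir=ptq_out \
+#       --devices=gpu \
+#       --algo=avg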
+
+import os
+import sys
+import numpy as np
+import argparse
+import paddle
+from paddleslim.common import load_config, load_onnx_model
+from paddleslim.quant import quant_post_static
+from dataset import COCOTrainDataset
+
+
+def argsparser():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        default=None,
+        help="path of post training quantization config.",
+        required=True)
+    parser.add_argument(
+        '--save_dir',
+        type=str,
+        default='ptq_out',
+        help="directory to save compressed model.")
+    parser.add_argument(
+        '--devices',
+        type=str,
+        default='gpu',
+        help="which device used to compress.")
+    parser.add_argument(
+        '--algo', type=str, default='KL', help="post quant algo.")
+
+    return parser
+
+
+def main():
+    global config
+    config = load_config(FLAGS.config_path)
+
+    dataset = COCOTrainDataset(
+        dataset_dir=config['dataset_dir'],
+        image_dir=config['val_image_dir'],
+        anno_path=config['val_anno_path'])
+    train_loader = paddle.io.DataLoader(
+        dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
+
+    place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
+    exe = paddle.static.Executor(place)
+
+    # since the model converted from pytorch is in ONNX format,
+    # call load_onnx_model first and then rename the model_dir
+    load_onnx_model(config["model_dir"])
+    inference_model_path = config["model_dir"].rstrip().rstrip(
+        '.onnx') + '_infer'
+
+    quant_post_static(
+        executor=exe,
+        model_dir=inference_model_path,
+        quantize_model_path=FLAGS.save_dir,
+        data_loader=train_loader,
+        model_filename='model.pdmodel',
+        params_filename='model.pdiparams',
+        batch_size=32,
+        batch_nums=10,
+        algo=FLAGS.algo,
+        hist_percent=0.999,
+        is_full_quantize=False,
+        bias_correction=False,
+        onnx_format=True)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    parser = argsparser()
+    FLAGS = parser.parse_args()
+
+    assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
+    paddle.set_device(FLAGS.devices)
+
+    main()
diff --git a/paddleslim/auto_compression/__init__.py b/paddleslim/auto_compression/__init__.py
index 990ad37b..cfc26259 100644
--- a/paddleslim/auto_compression/__init__.py
+++ b/paddleslim/auto_compression/__init__.py
@@ -19,8 +19,14 @@ from .config_helpers import *
 from .utils import *
 
 __all__ = [
-    "AutoCompression", "Quantization", "Distillation",
-    "MultiTeacherDistillation", "HyperParameterOptimization", "Prune",
-    "UnstructurePrune", "ProgramInfo", "TrainConfig", "save_config",
-    "load_config", "predict_compressed_model"
+    "AutoCompression",
+    "Quantization",
+    "Distillation",
+    "MultiTeacherDistillation",
+    "HyperParameterOptimization",
+    "Prune",
+    "UnstructurePrune",
+    "ProgramInfo",
+    "TrainConfig",
+    "predict_compressed_model",
 ]
diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py
index b64152b9..a907e0c9 100644
--- a/paddleslim/auto_compression/compressor.py
+++ b/paddleslim/auto_compression/compressor.py
@@ -29,13 +29,15 @@ from ..quant.quanter import convert, quant_post
 from ..common.recover_program import recover_inference_program
 from ..common import get_logger
 from ..common.patterns import get_patterns
+from ..common.load_model import load_inference_model, get_model_dir
+from ..common.dataloader import wrap_dataloader, get_feed_vars
+from ..common.config_helper import load_config
 from ..analysis import TableLatencyPredictor
 from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program, 
remove_unused_var_nodes from .strategy_config import TrainConfig, ProgramInfo, merge_config from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config, create_train_config -from .config_helpers import load_config, extract_strategy_config, extract_train_config +from .config_helpers import extract_strategy_config, extract_train_config from .utils.predict import with_variable_shape -from .utils import get_feed_vars, wrap_dataloader, load_inference_model, get_model_dir _logger = get_logger(__name__, level=logging.INFO) diff --git a/paddleslim/auto_compression/config_helpers.py b/paddleslim/auto_compression/config_helpers.py index ebc5b45c..b1e426cc 100644 --- a/paddleslim/auto_compression/config_helpers.py +++ b/paddleslim/auto_compression/config_helpers.py @@ -14,42 +14,7 @@ import yaml import os from paddleslim.auto_compression.strategy_config import * - -__all__ = ['save_config', 'load_config'] - - -def print_arguments(args, level=0): - if level == 0: - print('----------- Running Arguments -----------') - for arg, value in sorted(args.items()): - if isinstance(value, dict): - print('\t' * level, '%s:' % arg) - print_arguments(value, level + 1) - else: - print('\t' * level, '%s: %s' % (arg, value)) - if level == 0: - print('------------------------------------------') - - -def load_config(config): - """Load configurations from yaml file into dict. - Fields validation is skipped for loading some custom information. - Args: - config(str): The path of configuration file. - Returns: - dict: A dict storing configuration information. - """ - if config is None: - return None - assert isinstance( - config, - str), f"config should be str but got type(config)={type(config)}" - assert os.path.exists(config) and os.path.isfile( - config), f"{config} not found or it is not a file." - with open(config) as f: - cfg = yaml.load(f, Loader=yaml.FullLoader) - print_arguments(cfg) - return cfg +from ..common.config_helper import load_config def extract_strategy_config(config): @@ -101,12 +66,3 @@ def extract_train_config(config): **value) if value is not None else TrainConfig() # return default training config when it is not set return TrainConfig() - - -def save_config(config, config_path): - """ - convert dict config to yaml. 
- """ - f = open(config_path, "w") - yaml.dump(config, f) - f.close() diff --git a/paddleslim/auto_compression/create_compressed_program.py b/paddleslim/auto_compression/create_compressed_program.py index 30276bbf..8a6c7db2 100644 --- a/paddleslim/auto_compression/create_compressed_program.py +++ b/paddleslim/auto_compression/create_compressed_program.py @@ -23,7 +23,7 @@ from ..dist import * from ..common.recover_program import recover_inference_program, _remove_fetch_node from ..common import get_logger from .strategy_config import ProgramInfo -from .utils import load_inference_model +from ..common.load_model import load_inference_model _logger = get_logger(__name__, level=logging.INFO) __all__ = [ @@ -52,7 +52,8 @@ def _create_optimizer(train_config): optimizer_builder = train_config['optimizer_builder'] assert isinstance( optimizer_builder, dict - ), f"Value of 'optimizer_builder' in train_config should be dict but got {type(optimizer_builder)}" + ), "Value of 'optimizer_builder' in train_config should be dict but got {}".format( + type(optimizer_builder)) if 'grad_clip' in optimizer_builder: g_clip_params = optimizer_builder['grad_clip'] g_clip_type = g_clip_params.pop('type') @@ -444,9 +445,8 @@ def build_prune_program(executor, "####################channel pruning##########################") for param in pruned_program.global_block().all_parameters(): if param.name in original_shapes: - _logger.info( - f"{param.name}, from {original_shapes[param.name]} to {param.shape}" - ) + _logger.info("{}, from {} to {}".format( + param.name, original_shapes[param.name], param.shape)) _logger.info( "####################channel pruning end##########################") train_program_info.program = pruned_program diff --git a/paddleslim/auto_compression/utils/__init__.py b/paddleslim/auto_compression/utils/__init__.py index aa4f3ec0..e3c3a49d 100644 --- a/paddleslim/auto_compression/utils/__init__.py +++ b/paddleslim/auto_compression/utils/__init__.py @@ -14,11 +14,5 @@ from __future__ import absolute_import from .predict import predict_compressed_model -from .dataloader import * -from . import dataloader -from .load_model import * -from . import load_model __all__ = ["predict_compressed_model"] -__all__ += dataloader.__all__ -__all__ += load_model.__all__ diff --git a/paddleslim/auto_compression/utils/fake_ptq.py b/paddleslim/auto_compression/utils/fake_ptq.py index fbecc224..e86dd848 100644 --- a/paddleslim/auto_compression/utils/fake_ptq.py +++ b/paddleslim/auto_compression/utils/fake_ptq.py @@ -12,7 +12,7 @@ except: TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type -from .load_model import load_inference_model +from ...common.load_model import load_inference_model def post_quant_fake(executor, diff --git a/paddleslim/auto_compression/utils/load_model.py b/paddleslim/auto_compression/utils/load_model.py deleted file mode 100644 index 637e808a..00000000 --- a/paddleslim/auto_compression/utils/load_model.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import paddle -from ...common import load_onnx_model - -__all__ = ['load_inference_model', 'get_model_dir'] - - -def load_inference_model(path_prefix, - executor, - model_filename=None, - params_filename=None): - # Load onnx model to Inference model. - if path_prefix.endswith('.onnx'): - inference_program, feed_target_names, fetch_targets = load_onnx_model( - path_prefix) - return [inference_program, feed_target_names, fetch_targets] - # Load Inference model. - # TODO: clean code - if model_filename is not None and model_filename.endswith('.pdmodel'): - model_name = '.'.join(model_filename.split('.')[:-1]) - assert os.path.exists( - os.path.join(path_prefix, model_name + '.pdmodel') - ), 'Please check {}, or fix model_filename parameter.'.format( - os.path.join(path_prefix, model_name + '.pdmodel')) - assert os.path.exists( - os.path.join(path_prefix, model_name + '.pdiparams') - ), 'Please check {}, or fix params_filename parameter.'.format( - os.path.join(path_prefix, model_name + '.pdiparams')) - model_path_prefix = os.path.join(path_prefix, model_name) - [inference_program, feed_target_names, fetch_targets] = ( - paddle.static.load_inference_model( - path_prefix=model_path_prefix, executor=executor)) - elif model_filename is not None and params_filename is not None: - [inference_program, feed_target_names, fetch_targets] = ( - paddle.static.load_inference_model( - path_prefix=path_prefix, - executor=executor, - model_filename=model_filename, - params_filename=params_filename)) - else: - model_name = '.'.join(model_filename.split('.') - [:-1]) if model_filename is not None else 'model' - if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')): - model_path_prefix = os.path.join(path_prefix, model_name) - [inference_program, feed_target_names, fetch_targets] = ( - paddle.static.load_inference_model( - path_prefix=model_path_prefix, executor=executor)) - else: - [inference_program, feed_target_names, fetch_targets] = ( - paddle.static.load_inference_model( - path_prefix=path_prefix, executor=executor)) - - return [inference_program, feed_target_names, fetch_targets] - - -def get_model_dir(model_dir, model_filename, params_filename): - if model_dir.endswith('.onnx'): - updated_model_dir = model_dir.rstrip().rstrip('.onnx') + '_infer' - else: - updated_model_dir = model_dir.rstrip('/') - - if model_filename == None: - updated_model_filename = 'model.pdmodel' - else: - updated_model_filename = model_filename - - if params_filename == None: - updated_params_filename = 'model.pdiparams' - else: - updated_params_filename = params_filename - - if params_filename is None and model_filename is not None: - raise NotImplementedError( - "NOT SUPPORT parameters saved in separate files. Please convert it to single binary file first." 
- ) - return updated_model_dir, updated_model_filename, updated_params_filename diff --git a/paddleslim/auto_compression/utils/predict.py b/paddleslim/auto_compression/utils/predict.py index 01ef6a90..5b8c6adb 100644 --- a/paddleslim/auto_compression/utils/predict.py +++ b/paddleslim/auto_compression/utils/predict.py @@ -4,7 +4,7 @@ import paddle from ...analysis import TableLatencyPredictor from .prune_model import get_sparse_model, get_prune_model from .fake_ptq import post_quant_fake -from .load_model import load_inference_model +from ...common.load_model import load_inference_model def with_variable_shape(model_dir, model_filename=None, params_filename=None): @@ -53,7 +53,7 @@ def predict_compressed_model(executor, latency_dict(dict): The latency latency of the model under various compression strategies. """ local_rank = paddle.distributed.get_rank() - quant_model_path = f'quant_model_rank_{local_rank}_tmp' + quant_model_path = 'quant_model_rank_{}_tmp'.format(local_rank) prune_model_path = f'prune_model_rank_{local_rank}_tmp' sparse_model_path = f'sparse_model_rank_{local_rank}_tmp' diff --git a/paddleslim/auto_compression/utils/prune_model.py b/paddleslim/auto_compression/utils/prune_model.py index 426a1859..c0da14ca 100644 --- a/paddleslim/auto_compression/utils/prune_model.py +++ b/paddleslim/auto_compression/utils/prune_model.py @@ -5,7 +5,7 @@ import paddle import paddle.static as static from ...prune import Pruner from ...core import GraphWrapper -from .load_model import load_inference_model +from ...common.load_model import load_inference_model __all__ = ["get_sparse_model", "get_prune_model"] @@ -19,9 +19,10 @@ def get_sparse_model(executor, places, model_file, param_file, ratio, ratio(float): The ratio to prune the model. save_path(str): The save path of pruned model. """ - assert os.path.exists(model_file), f'{model_file} does not exist.' + assert os.path.exists(model_file), '{} does not exist.'.format(model_file) assert os.path.exists( - param_file) or param_file is None, f'{param_file} does not exist.' + param_file) or param_file is None, '{} does not exist.'.format( + param_file) paddle.enable_static() SKIP = ['image', 'feed', 'pool2d_0.tmp_0'] diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py index e866790d..03825c8a 100644 --- a/paddleslim/common/__init__.py +++ b/paddleslim/common/__init__.py @@ -25,12 +25,16 @@ from .analyze_helper import VarCollector from . import wrapper_function from . import recover_program from . import patterns -from .convert_model import load_onnx_model +from .load_model import load_inference_model, get_model_dir, load_onnx_model +from .dataloader import wrap_dataloader, get_feed_vars +from .config_helper import load_config, save_config __all__ = [ 'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer', 'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter', - 'Server', 'Client', 'RLBaseController', 'VarCollector', 'load_onnx_model' + 'Server', 'Client', 'RLBaseController', 'VarCollector', 'load_onnx_model', + 'load_inference_model', 'get_model_dir', 'wrap_dataloader', 'get_feed_vars', + 'load_config', 'save_config' ] __all__ += wrapper_function.__all__ diff --git a/paddleslim/common/config_helper.py b/paddleslim/common/config_helper.py new file mode 100644 index 00000000..486fa9b4 --- /dev/null +++ b/paddleslim/common/config_helper.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import yaml
+import os
+
+__all__ = ['load_config', 'save_config']
+
+
+def print_arguments(args, level=0):
+    if level == 0:
+        print('----------- Running Arguments -----------')
+    for arg, value in sorted(args.items()):
+        if isinstance(value, dict):
+            print('\t' * level, '%s:' % arg)
+            print_arguments(value, level + 1)
+        else:
+            print('\t' * level, '%s: %s' % (arg, value))
+    if level == 0:
+        print('------------------------------------------')
+
+
+def load_config(config):
+    """Load configurations from yaml file into dict.
+    Fields validation is skipped for loading some custom information.
+    Args:
+        config(str): The path of configuration file.
+    Returns:
+        dict: A dict storing configuration information.
+    """
+    if config is None:
+        return None
+    assert isinstance(
+        config,
+        str), f"config should be str but got type(config)={type(config)}"
+    assert os.path.exists(config) and os.path.isfile(
+        config), f"{config} not found or it is not a file."
+    with open(config) as f:
+        cfg = yaml.load(f, Loader=yaml.FullLoader)
+        print_arguments(cfg)
+    return cfg
+
+
+def save_config(config, config_path):
+    """
+    Dump a dict config to a yaml file.
+    """
+    with open(config_path, "w") as f:
+        yaml.dump(config, f)
diff --git a/paddleslim/auto_compression/utils/dataloader.py b/paddleslim/common/dataloader.py
similarity index 100%
rename from paddleslim/auto_compression/utils/dataloader.py
rename to paddleslim/common/dataloader.py
diff --git a/paddleslim/common/convert_model.py b/paddleslim/common/load_model.py
similarity index 58%
rename from paddleslim/common/convert_model.py
rename to paddleslim/common/load_model.py
index 00e01820..81d50f73 100644
--- a/paddleslim/common/convert_model.py
+++ b/paddleslim/common/load_model.py
@@ -17,7 +17,6 @@ import logging
 import os
 import shutil
 import sys
-
 import paddle
 from x2paddle.decoder.onnx_decoder import ONNXDecoder
 from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
@@ -27,7 +26,78 @@ from x2paddle.utils import ConverterCheck
 from . import get_logger
 
 _logger = get_logger(__name__, level=logging.INFO)
 
-__all__ = ['load_onnx_model']
+__all__ = ['load_inference_model', 'get_model_dir', 'load_onnx_model']
+
+
+def load_inference_model(path_prefix,
+                         executor,
+                         model_filename=None,
+                         params_filename=None):
+    # Load onnx model to Inference model.
+    if path_prefix.endswith('.onnx'):
+        inference_program, feed_target_names, fetch_targets = load_onnx_model(
+            path_prefix)
+        return [inference_program, feed_target_names, fetch_targets]
+    # Load Inference model.
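+    # Dispatch summary (descriptive comments; behavior unchanged):
+    #   1. model_filename ends with '.pdmodel': derive the path prefix from it,
+    #      after checking that the matching .pdmodel/.pdiparams files exist.
+    #   2. model_filename and params_filename are both given: forward them to
+    #      paddle.static.load_inference_model unchanged.
+    #   3. otherwise: derive the prefix from model_filename (or default to
+    #      'model'), preferring path_prefix/<name>.pdmodel and falling back
+    #      to path_prefix itself.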
+    # TODO: clean code
+    if model_filename is not None and model_filename.endswith('.pdmodel'):
+        model_name = '.'.join(model_filename.split('.')[:-1])
+        assert os.path.exists(
+            os.path.join(path_prefix, model_name + '.pdmodel')
+        ), 'Please check {}, or fix model_filename parameter.'.format(
+            os.path.join(path_prefix, model_name + '.pdmodel'))
+        assert os.path.exists(
+            os.path.join(path_prefix, model_name + '.pdiparams')
+        ), 'Please check {}, or fix params_filename parameter.'.format(
+            os.path.join(path_prefix, model_name + '.pdiparams'))
+        model_path_prefix = os.path.join(path_prefix, model_name)
+        [inference_program, feed_target_names, fetch_targets] = (
+            paddle.static.load_inference_model(
+                path_prefix=model_path_prefix, executor=executor))
+    elif model_filename is not None and params_filename is not None:
+        [inference_program, feed_target_names, fetch_targets] = (
+            paddle.static.load_inference_model(
+                path_prefix=path_prefix,
+                executor=executor,
+                model_filename=model_filename,
+                params_filename=params_filename))
+    else:
+        model_name = '.'.join(model_filename.split('.')
+                              [:-1]) if model_filename is not None else 'model'
+        if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')):
+            model_path_prefix = os.path.join(path_prefix, model_name)
+            [inference_program, feed_target_names, fetch_targets] = (
+                paddle.static.load_inference_model(
+                    path_prefix=model_path_prefix, executor=executor))
+        else:
+            [inference_program, feed_target_names, fetch_targets] = (
+                paddle.static.load_inference_model(
+                    path_prefix=path_prefix, executor=executor))
+
+    return [inference_program, feed_target_names, fetch_targets]
+
+
+def get_model_dir(model_dir, model_filename, params_filename):
+    if model_dir.endswith('.onnx'):
+        updated_model_dir = model_dir.rstrip().rstrip('.onnx') + '_infer'
+    else:
+        updated_model_dir = model_dir.rstrip('/')
+
+    if model_filename is None:
+        updated_model_filename = 'model.pdmodel'
+    else:
+        updated_model_filename = model_filename
+
+    if params_filename is None:
+        updated_params_filename = 'model.pdiparams'
+    else:
+        updated_params_filename = params_filename
+
+    if params_filename is None and model_filename is not None:
+        raise NotImplementedError(
+            "Parameters saved in separate files are not supported. Please convert them to a single binary file first."
+        )
+    return updated_model_dir, updated_model_filename, updated_params_filename
 
 
 def load_onnx_model(model_path, disable_feedback=False):
@@ -112,4 +182,4 @@
         shutil.rmtree(
             os.path.join(inference_model_path, 'onnx2paddle_{}'.format(
                 model_idx)))
-    return val_program, feed_target_names, fetch_targets
+    return val_program, feed_target_names, fetch_targets
\ No newline at end of file
diff --git a/paddleslim/quant/analysis.py b/paddleslim/quant/analysis.py
new file mode 100644
index 00000000..1013c392
--- /dev/null
+++ b/paddleslim/quant/analysis.py
@@ -0,0 +1,331 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import pickle
+import copy
+import logging
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_pdf import PdfPages
+import numpy as np
+
+import paddle
+from paddle.fluid import core
+from paddle.fluid import framework
+from paddle.fluid.framework import IrGraph
+from paddle.fluid.executor import global_scope
+from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
+from paddle.fluid.contrib.slim.quantization.utils import _get_op_input_var_names, load_variable_data
+from .quanter import quant_post
+from ..core import GraphWrapper
+from ..common import get_logger
+from ..common import get_feed_vars, wrap_dataloader, load_inference_model, get_model_dir
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+__all__ = ["AnalysisQuant"]
+
+
+class AnalysisQuant(object):
+    def __init__(
+            self,
+            model_dir,
+            model_filename=None,
+            params_filename=None,
+            eval_function=None,
+            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+            weight_quantize_type='abs_max',
+            activation_quantize_type='moving_average_abs_max',
+            is_full_quantize=False,
+            batch_size=10,
+            batch_nums=10,
+            data_loader=None,
+            save_dir='analysis_results',
+            checkpoint_name='analysis_checkpoint.pkl',
+            num_histogram_plots=10, ):
+        """
+        AnalysisQuant analyzes the quantization sensitivity of each op in the model.
+
+        Args:
+            model_dir(str): the path of the fp32 model that will be quantized
+            model_filename(str): the model file name of the fp32 model
+            params_filename(str): the parameter file name of the fp32 model
+            eval_function(function): a user-defined function that returns the metric of the inference program; used to judge the metric of the quantized model (TODO: optional)
+            quantizable_op_type(list, optional): op types that can be quantized
+            batch_size(int, optional): the batch size of the DataLoader, default is 10
+            data_loader(Python Generator, Paddle.io.DataLoader, optional): the
+                Generator or DataLoader that provides calibration data; it should
+                return a batch every time it is called
+            save_dir(str, optional): the output directory that stores the analysis information
+            checkpoint_name(str, optional): the name of the checkpoint file that saves the analysis information, so an interrupted analysis can resume instead of restarting
+            num_histogram_plots(int, optional): the number of histogram plots to visualize; the plots are saved as PDF files under save_dir
+        """
+        self.model_dir = model_dir
+        self.model_filename = model_filename
+        self.params_filename = params_filename
+        self.quantizable_op_type = quantizable_op_type
+        self.weight_quantize_type = weight_quantize_type
+        self.activation_quantize_type = activation_quantize_type
+        self.is_full_quantize = is_full_quantize
+        self.histogram_bins = 1000
+        self.save_dir = save_dir
+        self.eval_function = eval_function
+        self.quant_layer_names = []
+        self.checkpoint_name = os.path.join(save_dir, checkpoint_name)
+        self.quant_layer_metrics = {}
+        self.batch_size = batch_size
+        self.batch_nums = batch_nums
+        self.num_histogram_plots = num_histogram_plots
+
+        if not os.path.exists(self.save_dir):
+            os.mkdir(self.save_dir)
+
+        devices = paddle.device.get_device().split(':')[0]
+        self.places = paddle.device._convert_to_place(devices)
+        executor = paddle.static.Executor(self.places)
+
+        # load model
+        [program, self.feed_list, self.fetch_list] = load_inference_model(
+            model_dir,
+            executor=executor,
+            model_filename=model_filename,
+            params_filename=params_filename)
+
+        # create data_loader
+        self.data_loader = wrap_dataloader(data_loader, self.feed_list)
+
+        # evaluate before quant
+        # TODO: self.eval_function can be None
+        if self.eval_function is not None:
+            self.base_metric = self.eval_function(
+                executor, program, self.feed_list, self.fetch_list)
+            _logger.info('Before quantization, the accuracy of the model is: {}'.
+                         format(self.base_metric))
+
+        # quant and evaluate after quant (skip_list = None)
+        post_training_quantization = PostTrainingQuantization(
+            executor=executor,
+            data_loader=self.data_loader,
+            model_dir=self.model_dir,
+            model_filename=self.model_filename,
+            params_filename=self.params_filename,
+            batch_size=self.batch_size,
+            batch_nums=self.batch_nums,
+            algo='avg',  # fastest
+            quantizable_op_type=self.quantizable_op_type,
+            weight_quantize_type=self.weight_quantize_type,
+            activation_quantize_type=self.activation_quantize_type,
+            is_full_quantize=self.is_full_quantize,
+            skip_tensor_list=None, )
+        program = post_training_quantization.quantize()
+        self.quant_metric = self.eval_function(executor, program,
+                                               self.feed_list, self.fetch_list)
+        _logger.info('After quantization, the accuracy of the model is: {}'.format(
+            self.quant_metric))
+
+        # get quantized weight and act var name
+        self.quantized_weight_var_name = post_training_quantization._quantized_weight_var_name
+        self.quantized_act_var_name = post_training_quantization._quantized_act_var_name
+        executor.close()
+
+        # load tobe_analyized_layer from checkpoint
+        self.load_checkpoint()
+        self.tobe_analyized_layer = self.quantized_weight_var_name - set(
+            list(self.quant_layer_metrics.keys()))
+        self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer))
+
+    def analysis(self):
+        self.compute_quant_sensitivity()
+        self.sensitivity_ranklist = sorted(
+            self.quant_layer_metrics,
+            key=self.quant_layer_metrics.get,
+            reverse=False)
+
+        _logger.info('Finished computing the sensitivity of the model.')
+        for name in self.sensitivity_ranklist:
+            _logger.info("quant layer name: {}, eval metric: {}".format(
+                name, self.quant_layer_metrics[name]))
+
+        analysis_file = os.path.join(self.save_dir, "analysis.txt")
+        with open(analysis_file, "w") as analysis_ret_f:
+            for name in self.sensitivity_ranklist:
+                analysis_ret_f.write(
+                    "quant layer name: {}, eval metric: {}\n".format(
+                        name, self.quant_layer_metrics[name]))
+        _logger.info('Analysis file is saved in {}'.format(analysis_file))
+        self.calculate_histogram()
+        self.draw_pdf()
+
+    def save_checkpoint(self):
+        if not os.path.exists(self.save_dir):
+            os.makedirs(self.save_dir)
+        with open(self.checkpoint_name, 'wb') as f:
+            pickle.dump(self.quant_layer_metrics, f)
+        _logger.info('save checkpoint to {}'.format(self.checkpoint_name))
+
+    def load_checkpoint(self):
+        if not os.path.exists(self.checkpoint_name):
+            return False
+        with open(self.checkpoint_name, 'rb') as f:
+            self.quant_layer_metrics = pickle.load(f)
+        _logger.info('load checkpoint from {}'.format(self.checkpoint_name))
+        return True
+
+    def compute_quant_sensitivity(self):
+        '''
+        For each layer, quantize the weight op and evaluate the quantized model.
+        '''
+        for i, layer_name in enumerate(self.tobe_analyized_layer):
+            _logger.info('checking {}/{} quant model: quant layer {}'.format(
+                i + 1, len(self.tobe_analyized_layer), layer_name))
+            skip_list = copy.copy(list(self.quantized_weight_var_name))
+            skip_list.remove(layer_name)
+
+            executor = paddle.static.Executor(self.places)
+            post_training_quantization = PostTrainingQuantization(
+                executor=executor,
+                data_loader=self.data_loader,
+                model_dir=self.model_dir,
+                model_filename=self.model_filename,
+                params_filename=self.params_filename,
+                batch_size=self.batch_size,
+                batch_nums=self.batch_nums,
+                algo='avg',  # fastest
+                quantizable_op_type=self.quantizable_op_type,
+                weight_quantize_type=self.weight_quantize_type,
+                activation_quantize_type=self.activation_quantize_type,
+                is_full_quantize=self.is_full_quantize,
+                skip_tensor_list=skip_list, )
+            program = post_training_quantization.quantize()
+
+            _logger.info('Evaluating...')
+            quant_metric = self.eval_function(executor, program, self.feed_list,
+                                              self.fetch_list)
+            executor.close()
+            _logger.info(
+                "quant layer name: {}, eval metric: {}, the loss caused by this layer: {}".
+                format(layer_name, quant_metric, self.base_metric -
+                       quant_metric))
+            self.quant_layer_metrics[layer_name] = quant_metric
+            self.save_checkpoint()
+
+    def get_sensitive_ops_name(self, graph, program):
+        sensitive_weight_ops = self.sensitivity_ranklist[:self.num_histogram_plots]
+        sensitive_act_ops = []
+        persistable_var_names = []
+        for var in program.list_vars():
+            if var.persistable:
+                persistable_var_names.append(var.name)
+        for op_name in sensitive_weight_ops:
+            for block_id in range(len(program.blocks)):
+                for op in program.blocks[block_id].ops:
+                    var_name_list = _get_op_input_var_names(op)
+                    if op_name in var_name_list:
+                        for var_name in var_name_list:
+                            if var_name not in persistable_var_names:
+                                sensitive_act_ops.append(var_name)
+        return sensitive_act_ops, sensitive_weight_ops
+
+    def calculate_histogram(self):
+        '''
+        Sample histograms for the weight and corresponding act tensors
+        '''
+        devices = paddle.device.get_device().split(':')[0]
+        places = paddle.device._convert_to_place(devices)
+        executor = paddle.static.Executor(places)
+
+        [program, feed_list, fetch_list] = load_inference_model(
+            self.model_dir,
+            executor=executor,
+            model_filename=self.model_filename,
+            params_filename=self.params_filename)
+
+        scope = global_scope()
+
+        graph = IrGraph(core.Graph(program.desc), for_test=False)
+        self.sensitive_act_ops, self.sensitive_weight_ops = self.get_sensitive_ops_name(
+            graph, program)
+
+        for var in program.list_vars():
+            if var.name in self.quantized_act_var_name:
+                var.persistable = True
+
+        batch_id = 0
+        for data in self.data_loader():
+            executor.run(program=program,
+                         feed=data,
+                         fetch_list=fetch_list,
+                         return_numpy=False,
+                         scope=scope)
+            batch_id += 1
+            if batch_id >= self.batch_nums:
+                break
+
+        self.weight_histogram = {}
+        self.act_histogram = {}
+        for var_name in self.sensitive_act_ops:
+            var_tensor = load_variable_data(scope, var_name)
+            var_tensor = np.array(var_tensor)
+            min_v = float(np.min(var_tensor))
+            max_v = float(np.max(var_tensor))
+            var_tensor = var_tensor.flatten()
+            _, hist_edges = np.histogram(
+                var_tensor.copy(),
+                bins=self.histogram_bins,
+                range=(min_v, max_v))
+            self.act_histogram[var_name] = [var_tensor, hist_edges]
+
+        for var_name in self.sensitive_weight_ops:
+            var_tensor = load_variable_data(scope, var_name)
+            var_tensor = np.array(var_tensor)
+            min_v = 
float(np.min(var_tensor)) + max_v = float(np.max(var_tensor)) + var_tensor = var_tensor.flatten() + _, hist_edges = np.histogram( + var_tensor.copy(), + bins=self.histogram_bins, + range=(min_v, max_v)) + self.weight_histogram[var_name] = [var_tensor, hist_edges] + + def draw_pdf(self): + pdf_path_a = os.path.join(self.save_dir, 'act_hist_result.pdf') + pdf_path_w = os.path.join(self.save_dir, 'weight_hist_result.pdf') + with PdfPages(pdf_path_a) as pdf: + for name in self.act_histogram: + plt.hist( + self.act_histogram[name][0], + bins=self.act_histogram[name][1]) + plt.xlabel(name) + plt.ylabel("frequency") + plt.title("Hist of variable {}".format(name)) + plt.show() + pdf.savefig() + plt.close() + with PdfPages(pdf_path_w) as pdf: + for name in self.weight_histogram: + plt.hist( + self.weight_histogram[name][0], + bins=self.weight_histogram[name][1]) + plt.xlabel(name) + plt.ylabel("frequency") + plt.title("Hist of variable {}".format(name)) + plt.show() + pdf.savefig() + plt.close() + _logger.info('Histogram plots are saved in {} and {}'.format( + pdf_path_a, pdf_path_w)) -- GitLab
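For reference, a minimal sketch of driving the new AnalysisQuant API introduced
by this patch; the model directory, data loader, and eval function below are
illustrative assumptions, not files shipped in the patch:

    import paddle
    from paddleslim.quant.analysis import AnalysisQuant

    paddle.enable_static()

    def eval_function(exe, program, feed_names, fetch_list):
        # Hypothetical user-defined metric: run a validation loop over
        # `program` and return a scalar such as COCO mAP.
        ...

    # `val_loader` is an assumed data loader yielding calibration batches.
    analyzer = AnalysisQuant(
        model_dir='./yolov6s_infer',      # assumed inference model directory
        model_filename='model.pdmodel',
        params_filename='model.pdiparams',
        eval_function=eval_function,
        data_loader=val_loader,
        save_dir='analysis_results')
    # Writes analysis.txt and the act/weight histogram PDFs under save_dir.
    analyzer.analysis()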