Unverified commit 71e0acb1 authored by Chang Xu, committed by GitHub

Add New Func in Quant Analysis (#1439)

Parent 9650ce42
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
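## 1. Features of the quantization analysis tool ## 1. Features of the quantization analysis tool
1. Traverse all layers of the model, quantize each layer in turn, and compute the accuracy after quantization. Rank the accuracies of the models in which only one layer is quantized and visualize the layers that are unsuitable for quantization, so that they can be selectively skipped during quantization. 1. Traverse all layers of the model, quantize each layer in turn, and compute the accuracy after quantization. Rank the accuracies of the models in which only one layer is quantized and visualize the layers that are unsuitable for quantization, so that they can be selectively skipped during quantization.
2. Visualize the weight and activation distributions of the layers with the best and worst quantization results, to help analyze the causes of the model's quantization results. 2. Visualize activation box plots, to help analyze how the activation distribution of each quantizable OP affects the quantization results.
3. [Coming soon] Take an expected accuracy as input and directly produce a quantized model that meets it. 3. Plot histograms of the weights and activations of the layers that quantize well and poorly, to help analyze their effect on the quantization results.
4. Take an expected accuracy as input and directly produce a quantized model that meets it.
## 2. Parameters accepted by paddleslim.quant.AnalysisQuant ## 2. Parameters accepted by paddleslim.quant.AnalysisQuant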
```yaml ```yaml
...@@ -30,7 +31,37 @@ ptq_config ...@@ -30,7 +31,37 @@ ptq_config
## 3. Output of the quantization analysis tool ## 3. Using the quantization analysis tool
1. Create the quantization analysis tool:
```
analyzer = AnalysisQuant(
model_dir=config["model_dir"],
model_filename=config["model_filename"],
params_filename=config["params_filename"],
eval_function=eval_function,
data_loader=data_loader,
save_dir=config['save_dir'],
ptq_config=config['PTQ'])
```
2. Plot activation box plots for all quantizable layers
```
analyzer.plot_activation_distribution()
```
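Taking the detection model picodet-s as an example, the activation box plots below (a subset of the layers) show that the activation inputs of `conv2d_7.w_0` and `conv2d_9.w_0` contain many outliers, which leads to poor quantization results.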
<p align="center">
<img src="./detection/images/act_distribution.png" width=849 hspace='10'/> <br />
</p>
3. Compute the quantization sensitivity of each layer and plot histograms
```
analyzer.compute_quant_sensitivity(plot_hist=True)
```
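`plot_hist` defaults to True. Set it to False if you do not need the weight and activation distribution plots of the layers that quantize well and poorly.
By default, the quantization analysis tool produces the following directory: By default, the quantization analysis tool produces the following directory: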
``` ```
...@@ -41,9 +72,20 @@ analysis_results/ ...@@ -41,9 +72,20 @@ analysis_results/
├── worst_weight_hist_result.pdf ├── worst_weight_hist_result.pdf
├── worst_act_hist_result.pdf ├── worst_act_hist_result.pdf
``` ```
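- The accuracy ranking of all models with only one layer quantized is saved to `./analysis_results/analysis.txt` by default. - The accuracy ranking of all models with only one layer quantized is saved to `./analysis_results/analysis.txt` by default.
- The `num_histogram_plots` parameter sets how many of the best and worst quantized layers have histograms of their weights and activations plotted. The histograms are saved as PDFs in the `./analysis_results` folder, as `best_weight_hist_result.pdf`, `best_act_hist_result.pdf`, `worst_weight_hist_result.pdf` and `worst_act_hist_result.pdf`, for comparison. - The `num_histogram_plots` parameter sets how many of the best and worst quantized layers have histograms of their weights and activations plotted. The histograms are saved as PDFs in the `./analysis_results` folder, as `best_weight_hist_result.pdf`, `best_act_hist_result.pdf`, `worst_weight_hist_result.pdf` and `worst_act_hist_result.pdf`, for comparison.
Taking the detection model picodet-s as an example, `analysis.txt` shows that the layers `conv2d_1.w_0`, `conv2d_3.w_0`, `conv2d_5.w_0`, `conv2d_7.w_0` and `conv2d_9.w_0` cause large accuracy loss. This matches what the activation box plots show.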
<p align="center">
<img src="./detection/images/picodet_analysis.png" width=849 hspace='10'/> <br />
</p>
4. Directly produce a quantized model that meets the expected accuracy
```
analyzer.get_target_quant_model(target_metric)
```
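The search behind `get_target_quant_model` can be pictured as a greedy loop: skip the most sensitive layer, re-quantize, evaluate, and repeat until the metric reaches the target. The sketch below is only an illustration under stated assumptions, not the PaddleSlim implementation. `find_skip_list`, `toy_evaluate`, the layer names and the per-layer losses are all hypothetical, the ranking is assumed to list the most sensitive layer first, and the evaluator stands in for quantizing and evaluating a real model.

```python
def find_skip_list(ranked_layers, evaluate, target_metric):
    """Greedily skip layers until evaluate(skip_list) reaches target_metric.

    ranked_layers: layer names, most sensitive first (assumption).
    evaluate: callable taking the current skip list and returning a metric.
    Returns (skip_list, metric), or (None, None) if the target is never reached.
    """
    skip_list = []
    for layer in ranked_layers:
        skip_list.append(layer)
        metric = evaluate(skip_list)
        if metric >= target_metric:
            return list(skip_list), metric
    return None, None


# Hypothetical numbers: accuracy points lost when a layer is quantized.
LAYER_LOSS = {"layer_a": 20, "layer_b": 5, "layer_c": 1}
BASE_METRIC = 100  # hypothetical accuracy of the unquantized model


def toy_evaluate(skip_list):
    # Stand-in for "quantize every layer not in skip_list, then evaluate".
    lost = sum(loss for name, loss in LAYER_LOSS.items() if name not in skip_list)
    return BASE_METRIC - lost


ranked = sorted(LAYER_LOSS, key=LAYER_LOSS.get, reverse=True)
skip, metric = find_skip_list(ranked, toy_evaluate, target_metric=95)
print(skip, metric)
```

This sketch returns `(None, None)` when every layer has been skipped without reaching the target, which is one reason to keep `target_metric` at or below the accuracy of the original model.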
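## 3. Run post-training quantization based on the analysis results ## 4. Run post-training quantization based on the analysis results
After running the quantization analysis tool, you can use the accuracy ranking in `analysis.txt` to exclude poorly performing layers from quantization: when calling `paddleslim.quant.quant_post_static`, add the `skip_tensor_list` parameter and pass in the layers to exclude. After running the quantization analysis tool, you can use the accuracy ranking in `analysis.txt` to exclude poorly performing layers from quantization: when calling `paddleslim.quant.quant_post_static`, add the `skip_tensor_list` parameter and pass in the layers to exclude.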
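A minimal sketch of turning the ranking file into a `skip_tensor_list`, assuming each line of `analysis.txt` has the form `quant layer name: <name>, eval metric: <value>` (the format string written by the analysis code in this commit). The helper `select_skip_tensors`, the `max_drop` threshold and the sample metrics are hypothetical and serve only as an illustration.

```python
import re

# Matches lines such as "quant layer name: conv2d_1.w_0, eval metric: 0.05".
LINE_RE = re.compile(r"quant layer name: ([^,\s]+), eval metric: ([-+0-9.eE]+)")


def select_skip_tensors(analysis_text, base_metric, max_drop):
    """Return layers whose single-layer quantization loses more than max_drop."""
    skip = []
    for match in LINE_RE.finditer(analysis_text):
        name, metric = match.group(1), float(match.group(2))
        if base_metric - metric > max_drop:
            skip.append(name)
    return skip


# Made-up sample content; real values come from ./analysis_results/analysis.txt.
sample = (
    "quant layer name: conv2d_1.w_0, eval metric: 0.05\n"
    "quant layer name: conv2d_20.w_0, eval metric: 0.29\n"
)
skip_tensor_list = select_skip_tensors(sample, base_metric=0.30, max_drop=0.05)
print(skip_tensor_list)

# The list would then be passed on as
# paddleslim.quant.quant_post_static(..., skip_tensor_list=skip_tensor_list)
```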
...@@ -138,18 +138,35 @@ python eval.py --config_path=./configs/ppyoloe_s_ptq.yaml ...@@ -138,18 +138,35 @@ python eval.py --config_path=./configs/ppyoloe_s_ptq.yaml
python analysis.py --config_path=./configs/picodet_s_analysis.yaml python analysis.py --config_path=./configs/picodet_s_analysis.yaml
``` ```
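As shown in the figure below, the quantization analysis reveals that the layers `conv2d_1.w_0`, `conv2d_3.w_0`, `conv2d_5.w_0`, `conv2d_7.w_0` and `conv2d_9.w_0` cause large accuracy loss. All of them are `depthwise_conv` layers in the early part of the backbone. As shown in the figure below, the quantization analysis reveals that the layers `conv2d_1.w_0`, `conv2d_3.w_0`, `conv2d_5.w_0`, `conv2d_7.w_0` and `conv2d_9.w_0` cause large accuracy loss. All of them are `depthwise_conv` layers in the early part of the backbone.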
<p align="center"> <p align="center">
<img src="./images/picodet_analysis.png" width=849 hspace='10'/> <br /> <img src="./images/picodet_analysis.png" width=849 hspace='10'/> <br />
</p> </p>
In the saved `activation_distribution.pdf`, you can also see that the activations of these layers contain many outliers, which leads to poor quantization results.
<p align="center">
<img src="./images/act_distribution.png" width=849 hspace='10'/> <br />
</p>
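To put a number on "many outliers", one option is to apply the same 1.5 × IQR whisker rule that matplotlib box plots use by default. The sketch below uses only the standard library. `outlier_fraction` and both sample lists are hypothetical, and the values are made up rather than taken from picodet-s.

```python
import statistics


def outlier_fraction(values, whisker=1.5):
    """Fraction of values outside [Q1 - whisker*IQR, Q3 + whisker*IQR]."""
    q1, _, q3 = statistics.quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    low, high = q1 - whisker * iqr, q3 + whisker * iqr
    outliers = [v for v in values if v < low or v > high]
    return len(outliers) / len(values)


# Made-up activation samples: one compact, one with two extreme values.
well_behaved = [float(i) for i in range(1, 21)]
heavy_tailed = well_behaved[:18] + [200.0, -150.0]

print(outlier_fraction(well_behaved))
print(outlier_fraction(heavy_tailed))
```

A layer whose activations have a noticeably higher outlier fraction than its neighbours is a candidate for closer inspection, since a few extreme values stretch the quantization range for all the others.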
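Based on this analysis, the layers that cause the largest accuracy drop can be skipped during post-training quantization. Use [picodet_s_analyzed_ptq.yaml](./configs/picodet_s_analyzed_ptq.yaml) and run post-training quantization again. After skipping these layers, post-training quantization accuracy rises by 24.9 points. Based on this analysis, the layers that cause the largest accuracy drop can be skipped during post-training quantization. Use [picodet_s_analyzed_ptq.yaml](./configs/picodet_s_analyzed_ptq.yaml) and run post-training quantization again. After skipping these layers, post-training quantization accuracy rises by 24.9 points.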
```shell ```shell
python post_quant.py --config_path=./configs/picodet_s_analyzed_ptq.yaml --save_dir=./picodet_s_analyzed_ptq_out python post_quant.py --config_path=./configs/picodet_s_analyzed_ptq.yaml --save_dir=./picodet_s_analyzed_ptq_out
``` ```
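To directly produce a quantized model that meets a target accuracy after the analysis, set `get_target_quant_model` to True in `picodet_s_analysis.yaml` and fill in `target_metric`. Note that `target_metric` cannot be higher than the accuracy of the original model.
**Speeding up the analysis**
The quantization analysis tool quantizes the model layer by layer and evaluates each result, so it can be slow. To speed it up, set `FastEvalDataset` in the config file to point to an annotation file that contains only a small number of images. Note that accuracy measured on a small dataset does not necessarily equal accuracy measured on the full dataset: if you only need the relative ranking of the layers' quantization results, a small dataset is enough; if you need exact accuracy, use the full validation dataset. To evaluate on the full validation data, delete the `FastEvalDataset` field.
Note: if you choose to directly produce a quantized model that meets the target accuracy after the analysis, the demo code does not evaluate on the small dataset; it automatically uses the full validation data.
For a detailed introduction to the quantization analysis tool, see [Introduction to the quantization analysis tool](../analysis.md)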
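One way to obtain a small annotation file for `FastEvalDataset` is to keep the first few images of a COCO-style annotation and only the annotations that refer to them. This is a sketch under assumptions: `make_small_coco_annotation` is a hypothetical helper, the input follows the usual COCO layout with `images`, `annotations` and `categories` keys, and the tiny dictionary below is made-up data.

```python
import json


def make_small_coco_annotation(full_anno, num_images):
    """Return a copy of a COCO-style dict restricted to the first num_images images."""
    images = full_anno["images"][:num_images]
    keep_ids = {img["id"] for img in images}
    annotations = [a for a in full_anno["annotations"] if a["image_id"] in keep_ids]
    small = dict(full_anno)  # shallow copy keeps categories and other top-level keys
    small["images"] = images
    small["annotations"] = annotations
    return small


# Made-up miniature annotation for illustration.
full = {
    "images": [{"id": 1, "file_name": "a.jpg"},
               {"id": 2, "file_name": "b.jpg"},
               {"id": 3, "file_name": "c.jpg"}],
    "annotations": [{"id": 10, "image_id": 1, "category_id": 1},
                    {"id": 11, "image_id": 3, "category_id": 1},
                    {"id": 12, "image_id": 2, "category_id": 1}],
    "categories": [{"id": 1, "name": "person"}],
}
small = make_small_coco_annotation(full, num_images=2)
print(json.dumps(small, indent=2))
```

In practice the input would be loaded with `json.load` from the full validation annotation file and the result written with `json.dump` to the path referenced by `anno_path` under `FastEvalDataset`.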
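## 4. Inference deployment ## 4. Inference deployment
For inference deployment, see the [Detection model auto-compression example](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/detection) For inference deployment, see the [Detection model auto-compression example](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/detection)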
......
...@@ -126,12 +126,16 @@ def main(): ...@@ -126,12 +126,16 @@ def main():
config = load_config(FLAGS.config_path) config = load_config(FLAGS.config_path)
ptq_config = config['PTQ'] ptq_config = config['PTQ']
# val dataset is sufficient for PTQ
data_loader = create('EvalReader')(config['EvalDataset'], data_loader = create('EvalReader')(config['EvalDataset'],
config['worker_num'], config['worker_num'],
return_list=True) return_list=True)
data_loader = reader_wrapper(data_loader, config['input_list']) ptq_data_loader = reader_wrapper(data_loader, config['input_list'])
dataset = config['EvalDataset'] # fast_val_anno_path, such as annotation path of several pictures can accelerate analysis
dataset = config[
'FastEvalDataset'] if 'FastEvalDataset' in config else config[
'EvalDataset']
global val_loader global val_loader
_eval_batch_sampler = paddle.io.BatchSampler( _eval_batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=config['EvalReader']['batch_size']) dataset, batch_size=config['EvalReader']['batch_size'])
...@@ -162,10 +166,23 @@ def main(): ...@@ -162,10 +166,23 @@ def main():
model_filename=config["model_filename"], model_filename=config["model_filename"],
params_filename=config["params_filename"], params_filename=config["params_filename"],
eval_function=eval_function, eval_function=eval_function,
data_loader=data_loader, data_loader=ptq_data_loader,
save_dir=config['save_dir'], save_dir=config['save_dir'],
ptq_config=ptq_config) ptq_config=ptq_config)
analyzer.analysis()
# plot the boxplot of activations of quantizable weights
analyzer.plot_activation_distribution()
# get the rank of sensitivity of each quantized layer
# plot the histogram plot of best and worst activations and weights if plot_hist is True
analyzer.compute_quant_sensitivity(plot_hist=config['plot_hist'])
if config['get_target_quant_model']:
if 'FastEvalDataset' in config:
# change fast_val_loader to full val_loader
val_loader = data_loader
# get the quantized model that satisfies target metric you set
analyzer.get_target_quant_model(target_metric=config['target_metric'])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -5,6 +5,9 @@ params_filename: model.pdiparams ...@@ -5,6 +5,9 @@ params_filename: model.pdiparams
save_dir: ./analysis_results save_dir: ./analysis_results
metric: COCO metric: COCO
num_classes: 80 num_classes: 80
plot_hist: True
get_target_quant_model: False
target_metric: None
PTQ: PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"] quantizable_op_type: ["conv2d", "depthwise_conv2d"]
...@@ -15,18 +18,21 @@ PTQ: ...@@ -15,18 +18,21 @@ PTQ:
batch_nums: 10 batch_nums: 10
# Datset configuration # Datset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: /dataset/coco/
EvalDataset: EvalDataset:
!COCODataSet !COCODataSet
image_dir: val2017 image_dir: val2017
anno_path: annotations/instances_val2017.json anno_path: annotations/instances_val2017.json
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
# Small Dataset to accelerate analysis
# If not exist, delete the dict of FastEvalDataset
FastEvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/small_instances_val2017.json
dataset_dir: /dataset/coco/
eval_height: &eval_height 416 eval_height: &eval_height 416
eval_width: &eval_width 416 eval_width: &eval_width 416
eval_size: &eval_size [*eval_height, *eval_width] eval_size: &eval_size [*eval_height, *eval_width]
...@@ -41,7 +47,7 @@ EvalReader: ...@@ -41,7 +47,7 @@ EvalReader:
- Resize: {interp: 2, target_size: *eval_size, keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_size: 32 batch_size: 1
input_list: ['image'] input_list: ['image']
arch: PPYOLOE # When export exclude_nms=True, need set arch: PPYOLOE arch: PPYOLOE # When export exclude_nms=True, need set arch: PPYOLOE
model_dir: ./ppyoloe_crn_s_300e_coco model_dir: ./ppyoloe_crn_s_300e_coco
model_filename: model.pdmodel model_filename: model.pdmodel
params_filename: model.pdiparams params_filename: model.pdiparams
save_dir: ./analysis_results_ppyoloe save_dir: ./analysis_results_ppyoloe
metric: COCO metric: COCO
num_classes: 80 num_classes: 80
plot_hist: True
get_target_quant_model: False
target_metric: None
PTQ: PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"] quantizable_op_type: ["conv2d", "depthwise_conv2d"]
...@@ -14,19 +17,20 @@ PTQ: ...@@ -14,19 +17,20 @@ PTQ:
is_full_quantize: False is_full_quantize: False
batch_size: 32 batch_size: 32
batch_nums: 10 batch_nums: 10
# Datset configuration # Datset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: /dataset/coco/
EvalDataset: EvalDataset:
!COCODataSet !COCODataSet
image_dir: val2017 image_dir: val2017
anno_path: annotations/instances_val2017.json anno_path: annotations/instances_val2017.json
dataset_dir: /paddle/dataset/coco/
# Small Dataset to accelerate analysis
# If not exist, delete the dict of FastEvalDataset
FastEvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/small_instances_val2017.json
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
worker_num: 0 worker_num: 0
...@@ -38,4 +42,4 @@ EvalReader: ...@@ -38,4 +42,4 @@ EvalReader:
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {} - Permute: {}
batch_size: 32 batch_size: 1
\ No newline at end of file \ No newline at end of file
...@@ -38,8 +38,10 @@ ...@@ -38,8 +38,10 @@
#### 3.1 Prepare the environment #### 3.1 Prepare the environment
- PaddlePaddle >= 2.3 (can be downloaded and installed from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)) - PaddlePaddle >= 2.3 (can be downloaded and installed from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
- PaddleSlim > 2.3 - PaddleSlim > 2.3
- X2Paddle >= 1.3.9
- opencv-python - opencv-python
(1) Install paddlepaddle: (1) Install paddlepaddle:
```shell ```shell
# CPU # CPU
...@@ -139,10 +141,21 @@ python analysis.py --config_path=./configs/yolov6s_analysis.yaml ...@@ -139,10 +141,21 @@ python analysis.py --config_path=./configs/yolov6s_analysis.yaml
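Based on this analysis, the layers that cause the largest accuracy drop can be skipped during post-training quantization. Use [yolov6s_analyzed_ptq.yaml](./configs/yolov6s_analyzed_ptq.yaml) and run post-training quantization again. After skipping these layers, post-training quantization accuracy rises by 9.4 points. Based on this analysis, the layers that cause the largest accuracy drop can be skipped during post-training quantization. Use [yolov6s_analyzed_ptq.yaml](./configs/yolov6s_analyzed_ptq.yaml) and run post-training quantization again. After skipping these layers, post-training quantization accuracy rises by 9.4 points.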
```shell ```shell
python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_dir=./yolov6s_analyzed_ptq_out python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_dir=./yolov6s_analyzed_ptq_out
``` ```
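To directly produce a quantized model that meets a target accuracy after the analysis, set `get_target_quant_model` to True in `yolov6s_analysis.yaml` and fill in `target_metric`. Note that `target_metric` cannot be higher than the accuracy of the original model.
**Speeding up the analysis**
The quantization analysis tool quantizes the model layer by layer and evaluates each result, so it can be slow. To speed it up, set `fast_val_anno_path` in the config file to the path of an annotation file that contains only a small number of images. Note that accuracy measured on a small dataset does not necessarily equal accuracy measured on the full dataset: if you only need the relative ranking of the layers' quantization results, a small dataset is enough; if you need exact accuracy, use the full validation dataset. To evaluate on the full validation data, set `fast_val_anno_path` to None.
Note: if you choose to directly produce a quantized model that meets the target accuracy after the analysis, the demo code does not evaluate on the small dataset; it automatically uses the full validation data.
For a detailed introduction to the quantization analysis tool, see [Introduction to the quantization analysis tool](../analysis.md)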
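The fallback from `fast_val_anno_path` to `val_anno_path` can be sketched as a small helper. One caveat worth hedging against: loaders that follow the YAML specification (PyYAML, for example) read a bare `None` as the string `"None"`, while the YAML null value is written `null` or `~`. Whether `load_config` normalizes this is not shown here, so the hypothetical `resolve_val_anno_path` below treats both forms as "not set". The dictionaries are made-up stand-ins for a loaded config.

```python
def resolve_val_anno_path(config):
    """Pick the small annotation file if one is configured, else the full one."""
    fast = config.get("fast_val_anno_path")
    # Treat a real None, an empty value, and the strings "None"/"null" as unset.
    if fast is None or str(fast).strip().lower() in ("", "none", "null"):
        return config["val_anno_path"]
    return fast


full_path = "annotations/instances_val2017.json"
small_path = "annotations/small_instances_val2017.json"

print(resolve_val_anno_path({"val_anno_path": full_path,
                             "fast_val_anno_path": small_path}))
print(resolve_val_anno_path({"val_anno_path": full_path,
                             "fast_val_anno_path": "None"}))
```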
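## 4. Inference deployment ## 4. Inference deployment
For inference deployment, see the [YOLO series model auto-compression example](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/pytorch_yolo_series) For inference deployment, see the [YOLO series model auto-compression example](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/pytorch_yolo_series)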
......
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
import argparse import argparse
import paddle import paddle
from tqdm import tqdm from tqdm import tqdm
from post_process import YOLOv6PostProcess, coco_metric from post_process import YOLOPostProcess, coco_metric
from dataset import COCOValDataset, COCOTrainDataset from dataset import COCOValDataset, COCOTrainDataset
from paddleslim.common import load_config, load_onnx_model from paddleslim.common import load_config, load_onnx_model
from paddleslim.quant.analysis import AnalysisQuant from paddleslim.quant.analysis import AnalysisQuant
...@@ -53,7 +53,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): ...@@ -53,7 +53,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
fetch_list=test_fetch_list, fetch_list=test_fetch_list,
return_numpy=False) return_numpy=False)
res = {} res = {}
postprocess = YOLOv6PostProcess( postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True) score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np.array(outs[0]), data_all['scale_factor']) res = postprocess(np.array(outs[0]), data_all['scale_factor'])
bboxes_list.append(res['bbox']) bboxes_list.append(res['bbox'])
...@@ -72,6 +72,8 @@ def main(): ...@@ -72,6 +72,8 @@ def main():
input_name = 'x2paddle_image_arrays' if config[ input_name = 'x2paddle_image_arrays' if config[
'arch'] == 'YOLOv6' else 'x2paddle_images' 'arch'] == 'YOLOv6' else 'x2paddle_images'
# val dataset is sufficient for PTQ
dataset = COCOTrainDataset( dataset = COCOTrainDataset(
dataset_dir=config['dataset_dir'], dataset_dir=config['dataset_dir'],
image_dir=config['val_image_dir'], image_dir=config['val_image_dir'],
...@@ -81,10 +83,12 @@ def main(): ...@@ -81,10 +83,12 @@ def main():
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0) dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
global val_loader global val_loader
# fast_val_anno_path, such as annotation path of several pictures can accelerate analysis
dataset = COCOValDataset( dataset = COCOValDataset(
dataset_dir=config['dataset_dir'], dataset_dir=config['dataset_dir'],
image_dir=config['val_image_dir'], image_dir=config['val_image_dir'],
anno_path=config['val_anno_path']) anno_path=config['fast_val_anno_path'] if
config['fast_val_anno_path'] is not None else config['val_anno_path'])
global anno_file global anno_file
anno_file = dataset.ann_file anno_file = dataset.ann_file
val_loader = paddle.io.DataLoader( val_loader = paddle.io.DataLoader(
...@@ -101,7 +105,30 @@ def main(): ...@@ -101,7 +105,30 @@ def main():
data_loader=data_loader, data_loader=data_loader,
save_dir=config['save_dir'], save_dir=config['save_dir'],
ptq_config=ptq_config) ptq_config=ptq_config)
analyzer.analysis()
# plot the boxplot of activations of quantizable weights
analyzer.plot_activation_distribution()
# get the rank of sensitivity of each quantized layer
# plot the histogram plot of best and worst activations and weights if plot_hist is True
analyzer.compute_quant_sensitivity(plot_hist=config['plot_hist'])
if config['get_target_quant_model']:
if config['fast_val_anno_path'] is not None:
# change fast_val_loader to full val_loader
dataset = COCOValDataset(
dataset_dir=config['dataset_dir'],
image_dir=config['val_image_dir'],
anno_path=config['val_anno_path'])
anno_file = dataset.ann_file
val_loader = paddle.io.DataLoader(
dataset,
batch_size=1,
shuffle=False,
drop_last=False,
num_workers=0)
# get the quantized model that satisfies target metric you set
analyzer.get_target_quant_model(config['target_metric'])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -4,6 +4,11 @@ save_dir: ./analysis_results ...@@ -4,6 +4,11 @@ save_dir: ./analysis_results
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
val_image_dir: val2017 val_image_dir: val2017
val_anno_path: annotations/instances_val2017.json val_anno_path: annotations/instances_val2017.json
# Small Dataset to accelerate analysis
fast_val_anno_path: annotations/small_instances_val2017.json # if not exist, please set None
get_target_quant_model: False
target_metric: None
plot_hist: True
PTQ: PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"] quantizable_op_type: ["conv2d", "depthwise_conv2d"]
......
arch: YOLOv7
model_dir: ./yolov7.onnx
save_dir: ./analysis_results
dataset_dir: /dataset/coco/
val_image_dir: val2017
val_anno_path: annotations/instances_val2017.json
# Small Dataset to accelerate analysis
fast_val_anno_path: annotations/small_instances_val2017.json # if not exist, please set None
get_target_quant_model: False
target_metric: None
plot_hist: True
PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
batch_size: 10
batch_nums: 10
\ No newline at end of file
arch: YOLOv7 arch: YOLOv7
model_dir: ./yolov7s.onnx model_dir: ./yolov7.onnx
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
train_image_dir: train2017 train_image_dir: train2017
val_image_dir: val2017 val_image_dir: val2017
......
...@@ -20,7 +20,7 @@ import logging ...@@ -20,7 +20,7 @@ import logging
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages from matplotlib.backends.backend_pdf import PdfPages
import numpy as np import numpy as np
import random
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
...@@ -105,7 +105,7 @@ class AnalysisQuant(object): ...@@ -105,7 +105,7 @@ class AnalysisQuant(object):
if self.eval_function is not None: if self.eval_function is not None:
self.base_metric = self.eval_function( self.base_metric = self.eval_function(
executor, program, self.feed_list, self.fetch_list) executor, program, self.feed_list, self.fetch_list)
_logger.info('before quantized, the accuracy of the model is: {}'. _logger.info('Before quantized, the accuracy of the model is: {}'.
format(self.base_metric)) format(self.base_metric))
# quant and evaluate after quant (skip_list = None) # quant and evaluate after quant (skip_list = None)
...@@ -121,7 +121,7 @@ class AnalysisQuant(object): ...@@ -121,7 +121,7 @@ class AnalysisQuant(object):
program = post_training_quantization.quantize() program = post_training_quantization.quantize()
self.quant_metric = self.eval_function(executor, program, self.quant_metric = self.eval_function(executor, program,
self.feed_list, self.fetch_list) self.feed_list, self.fetch_list)
_logger.info('after quantized, the accuracy of the model is: {}'.format( _logger.info('After quantized, the accuracy of the model is: {}'.format(
self.quant_metric)) self.quant_metric))
# get quantized weight and act var name # get quantized weight and act var name
...@@ -135,8 +135,13 @@ class AnalysisQuant(object): ...@@ -135,8 +135,13 @@ class AnalysisQuant(object):
list(self.quant_layer_metrics.keys())) list(self.quant_layer_metrics.keys()))
self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer)) self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer))
def analysis(self): def compute_quant_sensitivity(self, plot_hist=True):
self.compute_quant_sensitivity() '''
compute the sensitivity of quantized layers by eval function
'''
assert self.data_loader is not None, "When computing the sensitivity of quantized layers, the data loader is needed"
assert self.eval_function is not None, "When computing the sensitivity of quantized layers, the eval function is needed"
self.eval_quant_model()
self.sensitivity_ranklist = sorted( self.sensitivity_ranklist = sorted(
self.quant_layer_metrics, self.quant_layer_metrics,
key=self.quant_layer_metrics.get, key=self.quant_layer_metrics.get,
...@@ -154,7 +159,9 @@ class AnalysisQuant(object): ...@@ -154,7 +159,9 @@ class AnalysisQuant(object):
"quant layer name: {}, eval metric: {}\n".format( "quant layer name: {}, eval metric: {}\n".format(
name, self.quant_layer_metrics[name])) name, self.quant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file)) _logger.info('Analysis file is saved in {}'.format(analysis_file))
self.calculate_histogram()
if plot_hist:
self.calculate_histogram()
def save_checkpoint(self): def save_checkpoint(self):
if not os.path.exists(self.save_dir): if not os.path.exists(self.save_dir):
...@@ -171,12 +178,76 @@ class AnalysisQuant(object): ...@@ -171,12 +178,76 @@ class AnalysisQuant(object):
_logger.info('load checkpoint from {}'.format(self.checkpoint_name)) _logger.info('load checkpoint from {}'.format(self.checkpoint_name))
return True return True
def compute_quant_sensitivity(self): def plot_activation_distribution(self, axis=None):
'''
Collect and plot the distribution of the activation of each weight layer.
'''
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
executor = paddle.static.Executor(places)
[program, feed_list, fetch_list]= load_inference_model( \
self.model_dir, \
executor=executor, \
model_filename=self.model_filename, \
params_filename=self.params_filename)
scope = global_scope()
graph = IrGraph(core.Graph(program.desc), for_test=False)
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
weight_names = sorted(list(self.quantized_weight_var_name))
acts_weight_map = self.get_weight_act_map(program, weight_names,
persistable_var_names)
all_acts = list(acts_weight_map.keys())
all_weights = [acts_weight_map[act] for act in all_acts]
act_distribution = []
for var_name in all_acts:
var_tensor = load_variable_data(scope, var_name)
if axis is None:
var_tensor = var_tensor.flatten()
else:
var_tensor = var_tensor.reshape(
[-1, var_tensor.shape[axis]]).abs().max(axis=-1)
sample_num = len(var_tensor) if len(var_tensor) < 1000 else 1000
var_tensor = random.sample(list(var_tensor), sample_num)
act_distribution.append(var_tensor)
all_values = sum(act_distribution, [])
max_value = np.max(all_values)
min_value = np.min(all_values)
pdf_path = os.path.join(self.save_dir, 'activation_distribution.pdf')
with PdfPages(pdf_path) as pdf:
for i in range(0, len(act_distribution), 20):
r = i + 20 if i + 20 < len(act_distribution) else len(
act_distribution)
plt.boxplot(
act_distribution[i:r],
labels=all_weights[i:r],
showbox=True,
patch_artist=True)
plt.xticks(rotation=90)
plt.tick_params(axis='x')
plt.ylim([min_value, max_value])
plt.xlabel('Weight Name')
plt.ylabel("Activation Distribution")
plt.tight_layout()
plt.show()
pdf.savefig()
plt.close()
_logger.info('Distribution plots is saved in {}'.format(pdf_path))
def eval_quant_model(self):
''' '''
For each layer, quantize the weight op and evaluate the quantized model. For each layer, quantize the weight op and evaluate the quantized model.
''' '''
for i, layer_name in enumerate(self.tobe_analyized_layer): for i, layer_name in enumerate(self.tobe_analyized_layer):
_logger.info('checking {}/{} quant model: quant layer {}'.format( _logger.info('Checking {}/{} quant model: quant layer {}'.format(
i + 1, len(self.tobe_analyized_layer), layer_name)) i + 1, len(self.tobe_analyized_layer), layer_name))
skip_list = copy.copy(list(self.quantized_weight_var_name)) skip_list = copy.copy(list(self.quantized_weight_var_name))
skip_list.remove(layer_name) skip_list.remove(layer_name)
...@@ -198,15 +269,14 @@ class AnalysisQuant(object): ...@@ -198,15 +269,14 @@ class AnalysisQuant(object):
self.fetch_list) self.fetch_list)
executor.close() executor.close()
_logger.info( _logger.info(
"quant layer name: {}, eval metric: {}, the loss caused by this layer: {}". "Quantized layer name: {}, eval metric: {}, the loss caused by this layer: {}".
format(layer_name, quant_metric, self.base_metric - format(layer_name, quant_metric, self.base_metric -
quant_metric)) quant_metric))
self.quant_layer_metrics[layer_name] = quant_metric self.quant_layer_metrics[layer_name] = quant_metric
self.save_checkpoint() self.save_checkpoint()
def get_act_name_by_weight(self, program, weight_names, def get_weight_act_map(self, program, weight_names, persistable_var_names):
persistable_var_names): act_names = {}
act_ops_names = []
for op_name in weight_names: for op_name in weight_names:
for block_id in range(len(program.blocks)): for block_id in range(len(program.blocks)):
for op in program.blocks[block_id].ops: for op in program.blocks[block_id].ops:
...@@ -214,8 +284,8 @@ class AnalysisQuant(object): ...@@ -214,8 +284,8 @@ class AnalysisQuant(object):
if op_name in var_name_list: if op_name in var_name_list:
for var_name in var_name_list: for var_name in var_name_list:
if var_name not in persistable_var_names: if var_name not in persistable_var_names:
act_ops_names.append(var_name) act_names[var_name] = op_name
return act_ops_names return act_names
def get_hist_ops_name(self, graph, program): def get_hist_ops_name(self, graph, program):
if self.num_histogram_plots <= 0: if self.num_histogram_plots <= 0:
...@@ -230,13 +300,13 @@ class AnalysisQuant(object): ...@@ -230,13 +300,13 @@ class AnalysisQuant(object):
if var.persistable: if var.persistable:
persistable_var_names.append(var.name) persistable_var_names.append(var.name)
best_act_ops = self.get_act_name_by_weight(program, best_weight_ops, best_acts = self.get_weight_act_map(program, best_weight_ops,
persistable_var_names) persistable_var_names)
worst_act_ops = self.get_act_name_by_weight(program, worst_weight_ops, worst_acts = self.get_weight_act_map(program, worst_weight_ops,
persistable_var_names) persistable_var_names)
return [best_weight_ops, best_act_ops, worst_weight_ops, worst_act_ops] return [best_weight_ops, best_acts, worst_weight_ops, worst_acts]
def collect_ops_histogram(self, scope, ops): def collect_tensors_histogram(self, scope, ops):
hist = {} hist = {}
for var_name in ops: for var_name in ops:
var_tensor = load_variable_data(scope, var_name) var_tensor = load_variable_data(scope, var_name)
...@@ -268,8 +338,8 @@ class AnalysisQuant(object): ...@@ -268,8 +338,8 @@ class AnalysisQuant(object):
scope = global_scope() scope = global_scope()
graph = IrGraph(core.Graph(program.desc), for_test=False) graph = IrGraph(core.Graph(program.desc), for_test=False)
ops_tobe_draw_hist = self.get_hist_ops_name(graph, program) tensors_tobe_draw_hist = self.get_hist_ops_name(graph, program)
if not ops_tobe_draw_hist: if not tensors_tobe_draw_hist:
return return
for var in program.list_vars(): for var in program.list_vars():
...@@ -294,19 +364,72 @@ class AnalysisQuant(object): ...@@ -294,19 +364,72 @@ class AnalysisQuant(object):
'worst_weight_hist_result.pdf', 'worst_weight_hist_result.pdf',
'worst_act_hist_result.pdf', 'worst_act_hist_result.pdf',
] ]
for ops, save_pdf_name in zip(ops_tobe_draw_hist, pdf_names): for tensors, save_pdf_name in zip(tensors_tobe_draw_hist, pdf_names):
hist_data = self.collect_ops_histogram(scope, ops) if isinstance(tensors, list):
self.draw_pdf(hist_data, save_pdf_name) hist_data = self.collect_tensors_histogram(scope, tensors)
self.draw_hist_pdf(hist_data, save_pdf_name, None)
def draw_pdf(self, hist_data, save_pdf_name): else:
hist_data = self.collect_tensors_histogram(scope,
list(tensors.keys()))
self.draw_hist_pdf(hist_data, save_pdf_name, tensors)
def draw_hist_pdf(self, hist_data, save_pdf_name, weight_act_map):
pdf_path = os.path.join(self.save_dir, save_pdf_name) pdf_path = os.path.join(self.save_dir, save_pdf_name)
with PdfPages(pdf_path) as pdf: with PdfPages(pdf_path) as pdf:
for name in hist_data: for name in hist_data:
plt.hist(hist_data[name][0], bins=hist_data[name][1]) plt.hist(hist_data[name][0], bins=hist_data[name][1])
plt.xlabel(name) plt.xlabel(name)
plt.ylabel("frequency") plt.ylabel("Frequency")
plt.title("Hist of variable {}".format(name)) if 'act' in save_pdf_name:
plt.title("Hist of Activation {}/Input of Weight {}".format(
name, weight_act_map[name]))
else:
plt.title("Hist of Weight {}".format(name))
plt.show() plt.show()
pdf.savefig() pdf.savefig()
plt.close() plt.close()
_logger.info('Histogram plot is saved in {}'.format(pdf_path)) _logger.info('Histogram plot is saved in {}'.format(pdf_path))
def get_target_quant_model(self, target_metric):
_logger.info(
'Start to Find quant model that satisfies the target metric.')
_logger.info(
'Make sure that you are using full eval dataset to get target quantized model.'
)
skip_list = []
rank_list = copy.copy(self.sensitivity_ranklist)
while True:
skip_list.append(rank_list.pop(0))
_logger.info('Skip Ops: {}'.format(skip_list))
executor = paddle.static.Executor(self.places)
post_training_quantization = PostTrainingQuantization(
executor=executor,
data_loader=self.data_loader,
model_dir=self.model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
skip_tensor_list=skip_list,
**self.ptq_config)
program = post_training_quantization.quantize()
_logger.info('Evaluating...')
quant_metric = self.eval_function(executor, program, self.feed_list,
self.fetch_list)
_logger.info("Current eval metric: {}, the target metric: {}".
format(quant_metric, target_metric))
if quant_metric >= target_metric:
quantize_model_path = os.path.join(self.save_dir,
'target_quant_model')
_logger.info(
'The quantized model satisfies the target metric and is saved to {}'.
format(quantize_model_path))
post_training_quantization.save_quantized_model(
quantize_model_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
break
else:
_logger.info(
'The quantized model does not satisfy the target metric. Skip next Op...'
)
executor.close()