未验证 提交 3e972ce0 编写于 作者: C Chang Xu 提交者: GitHub

Update Analysis & Analysis Demo/README (#1347)

上级 07e06985
# 量化分析工具详细教程
## 1. 量化分析工具功能
1. 遍历模型所有层,依次量化该层,计算量化后精度。为所有只量化一层的模型精度排序,可视化不适合量化的层,以供量化时可选择性跳过不适合量化的层。
2. 可视化量化效果最好和最差的层的权重和激活分布图,以供分析模型量化效果的原因。
3. 【敬请期待】输入预期精度,直接产出符合预期精度的量化模型。
## 2. paddleslim.quant.AnalysisQuant 可传入参数解析
```yaml
model_dir
model_filename: None
params_filename: None
eval_function: None
data_loader: None
save_dir: 'analysis_results'
checkpoint_name: 'analysis_checkpoint.pkl'
num_histogram_plots: 10
quantizable_op_type: ["conv2d", "depthwise_conv2d", "mul"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
batch_size: 10
batch_nums: 10
```
- model_dir: 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可。
- model_filename: 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入。
- params_filename: 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入。
- eval_function:目前不支持为None,需要传入自定义的验证函数。
- data_loader:模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader。
- save_dir:分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`
- checkpoint_name:由于模型可能存在大量层需要分析,因此分析过程中会中间保存结果,如果程序中断会自动加载已经分析好的结果,默认为`analysis_checkpoint.pkl`
- num_histogram_plots:需要可视化的直方分布图数量。可视化量化效果最好和最坏的该数量个权重和激活的分布图。默认为10。若不需要可视化直方图,设置为0即可。
注:以下参数均为需要传入离线量化中的参数,保持默认不影响模型进行量化分析。
- quantizable_op_type:需要进行量化的OP类型。通过以下代码可输出所有支持量化的OP类型:
```
from paddleslim.quant.quanter import TRANSFORM_PASS_OP_TYPES,QUANT_DEQUANT_PASS_OP_TYPES
print(TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
```
- weight_quantize_type:参数量化方式。可选 'abs_max' , 'channel_wise_abs_max' , 'range_abs_max' , 'moving_average_abs_max' 。 默认 'abs_max' 。
- activation_quantize_type:激活量化方式,可选 'abs_max' , 'range_abs_max' , 'moving_average_abs_max' 。默认为 'moving_average_abs_max'。
- is_full_quantize:是否对模型进行全量化,默认为False。
- batch_size:模型校准使用的batch size大小,默认为10。
- batch_nums:模型校准时的总batch数量,默认为10。
## 3. 量化分析工具产出内容
量化分析工具会默认会产出以下目录:
```
analysis_results/
├── analysis.txt
├── best_weight_hist_result.pdf
├── best_act_hist_result.pdf
├── worst_weight_hist_result.pdf
├── worst_act_hist_result.pdf
```
- 所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。
- 通过设置参数`num_histogram_plots`,可选择绘出该数量个量化效果最好和最差层的weight和activation的直方分布图,将以PDF形式保存在 `./analysis_results` 文件夹下, 分别保存为 `best_weight_hist_result.pdf``best_act_hist_result.pdf``worst_weight_hist_result.pdf``worst_act_hist_result.pdf` 中以供对比分析。
## 3. 根据分析结果执行离线量化
执行完量化分析工具后,可根据 `analysis.txt` 中的精度排序,在量化中去掉效果较差的层,具体操作为:在调用 `paddleslim.quant.quant_post_static` 时加入参数 `skip_tensor_list`,将需要去掉的层传入即可。
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
目录: 目录:
- [1.简介](#1简介) - [1.简介](#1简介)
- [2.Benchmark](#2Benchmark) - [2.Benchmark](#2Benchmark)
- [3.开始自动压缩](#离线量化流程) - [3.离线量化流程](#离线量化流程)
- [3.1 准备环境](#31-准备环境) - [3.1 准备环境](#31-准备环境)
- [3.2 准备数据集](#32-准备数据集) - [3.2 准备数据集](#32-准备数据集)
- [3.3 准备预测模型](#33-准备预测模型) - [3.3 准备预测模型](#33-准备预测模型)
...@@ -54,7 +54,7 @@ pip install paddleslim ...@@ -54,7 +54,7 @@ pip install paddleslim
``` ```
#### 3.2 准备数据集 #### 3.2 准备数据集
本示例默认以COCO数据进行自动压缩实验,可以从[MS COCO官网](https://cocodataset.org)下载[Train](http://images.cocodataset.org/zips/train2017.zip)[Val](http://images.cocodataset.org/zips/val2017.zip)[annotation](http://images.cocodataset.org/annotations/annotations_trainval2017.zip) 本示例默认以COCO数据进行自动压缩实验,可以从 [MS COCO官网](https://cocodataset.org) 下载 [Train](http://images.cocodataset.org/zips/train2017.zip)[Val](http://images.cocodataset.org/zips/val2017.zip)[annotation](http://images.cocodataset.org/annotations/annotations_trainval2017.zip)
目录格式如下: 目录格式如下:
``` ```
...@@ -75,28 +75,28 @@ dataset/coco/ ...@@ -75,28 +75,28 @@ dataset/coco/
#### 3.3 准备预测模型 #### 3.3 准备预测模型
(1)准备ONNX模型: (1)准备ONNX模型:
**yolov5**:可通过[ultralytics/yolov5](https://github.com/ultralytics/yolov5) 官方的[导出教程](https://github.com/ultralytics/yolov5/issues/251)来准备ONNX模型。也可以下载准备好的[yolov5s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) - YOLOv5:可通过[ultralytics/yolov5](https://github.com/ultralytics/yolov5) 官方的[导出教程](https://github.com/ultralytics/yolov5/issues/251)来准备ONNX模型,也可以下载准备好的[yolov5s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx)
**yolov6**:可通过[WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7)的导出脚本来准备ONNX模型。也可以直接下载我们已经准备好的[yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) - YOLOv6:可通过[WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7)的导出脚本来准备ONNX模型,也可以直接下载我们已经准备好的[yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx)
**yolov7**:可通过[meituan/YOLOv6](https://github.com/meituan/YOLOv6)官方的[导出教程](https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md)来准备ONNX模型。也可以下载已经准备好的[yolov6s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) - YOLOv7:可通过[meituan/YOLOv6](https://github.com/meituan/YOLOv6)官方的[导出教程](https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md)来准备ONNX模型,也可以下载已经准备好的[yolov6s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx)
#### 3.4 离线量化并产出模型 #### 3.4 离线量化并产出模型
离线量化示例通过post_quant.py脚本启动,会使用接口```paddleslim.quant.quant_post_static```对模型进行量化。配置config文件中模型路径、数据路径和量化相关的参数,配置完成后便可对模型进行离线量化。具体运行命令为: 离线量化示例通过post_quant.py脚本启动,会使用接口```paddleslim.quant.quant_post_static```对模型进行量化。配置config文件中模型路径、数据路径和量化相关的参数,配置完成后便可对模型进行离线量化。具体运行命令为:
- yolov5 - YOLOv5
```shell ```shell
python post_quant.py --config_path=./configs/yolov5s_ptq.yaml --save_dir=./yolov5s_ptq_out python post_quant.py --config_path=./configs/yolov5s_ptq.yaml --save_dir=./yolov5s_ptq_out
``` ```
- yolov6 - YOLOv6
```shell ```shell
python post_quant.py --config_path=./configs/yolov6s_ptq.yaml --save_dir=./yolov6s_ptq_out python post_quant.py --config_path=./configs/yolov6s_ptq.yaml --save_dir=./yolov6s_ptq_out
``` ```
- yolov7 - YOLOv7
```shell ```shell
python post_quant.py --config_path=./configs/yolov7s_ptq.yaml --save_dir=./yolov7s_ptq_out python post_quant.py --config_path=./configs/yolov7s_ptq.yaml --save_dir=./yolov7s_ptq_out
...@@ -105,7 +105,8 @@ python post_quant.py --config_path=./configs/yolov7s_ptq.yaml --save_dir=./yolov ...@@ -105,7 +105,8 @@ python post_quant.py --config_path=./configs/yolov7s_ptq.yaml --save_dir=./yolov
#### 3.5 测试模型精度 #### 3.5 测试模型精度
修改[yolov5s_ptq.yaml](./configs/yolov5s_ptq.yaml)`model_dir`字段为模型存储路径,然后使用eval.py脚本得到模型的mAP: 修改 [yolov5s_ptq.yaml](./configs/yolov5s_ptq.yaml)`model_dir`字段为模型存储路径,然后使用eval.py脚本得到模型的mAP:
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/yolov5s_ptq.yaml python eval.py --config_path=./configs/yolov5s_ptq.yaml
...@@ -113,26 +114,37 @@ python eval.py --config_path=./configs/yolov5s_ptq.yaml ...@@ -113,26 +114,37 @@ python eval.py --config_path=./configs/yolov5s_ptq.yaml
#### 3.6 提高离线量化精度 #### 3.6 提高离线量化精度
本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisQuant```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。由于yolov6离线量化效果较差,以yolov6为例,量化分析工具具体使用方法如下: 本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisQuant```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。
由于YOLOv6离线量化效果较差,以YOLOv6为例,量化分析工具具体使用方法如下:
```shell ```shell
python analysis.py --config_path=./configs/yolov6s_analysis.yaml python analysis.py --config_path=./configs/yolov6s_analysis.yaml
``` ```
经过分析之后,会产出模型每一层量化后的精度,和较差层的weight和activation的分布图。在进行离线量化时,可以跳过这些导致精度下降较多的层,如yolov6中,经过分析后,可跳过`conv2d_2.w_0``conv2d_11.w_0``conv2d_15.w_0``conv2d_46.w_0``conv2d_49.w_0`,可使用[yolov6s_analyzed_ptq.yaml](./configs/yolov6s_analyzed_ptq.yaml),然后再次进行离线量化。跳过这五层后,离线量化精度上升9.4个点。 如下图,经过量化分析之后,可以发现`conv2d_2.w_0``conv2d_11.w_0``conv2d_15.w_0``conv2d_46.w_0``conv2d_49.w_0` 这些层会导致较大的精度损失。
<p align="center">
<img src="./images/sensitivity_rank.png" width=849 hspace='10'/> <br />
</p>
```shell
python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_dir=./yolov6s_analyzed_ptq_out
```
注:
- 分析后,每层量化的精度会默认保存在`./analysis_results/analysis.txt`,直方分布图会默认保存在`./analysis_results/act_hist_result.pdf``./analysis_results/weight_hist_result.pdf` 对比权重直方分布图后,可以发现量化损失较小的层数值分布相对平稳,数值处于-0.25到0.25之间,而量化损失较大的层数值分布非常极端,绝大部分值趋近于0,且数值处于-0.1到0.1之间,尽管看上去都是正太分布,但大量值为0是不利于量化统计scale值的
<p align="center"> <p align="center">
<img src="./images/sensitivity_rank.png" width=849 hspace='10'/> <br /> <img src="./images/hist_compare.png" width=849 hspace='10'/> <br />
</p> </p>
经此分析,在进行离线量化时,可以跳过这些导致精度下降较多的层,可使用 [yolov6s_analyzed_ptq.yaml](./configs/yolov6s_analyzed_ptq.yaml),然后再次进行离线量化。跳过这些层后,离线量化精度上升9.4个点。
```shell
python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_dir=./yolov6s_analyzed_ptq_out
```
## 4.预测部署 ## 4.预测部署
## 5.FAQ ## 5.FAQ
- 如果想对模型进行自动压缩,可进入[YOLO系列模型自动压缩示例](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/pytorch_yolo_series)中进行实验。
...@@ -69,10 +69,13 @@ def main(): ...@@ -69,10 +69,13 @@ def main():
global config global config
config = load_config(FLAGS.config_path) config = load_config(FLAGS.config_path)
input_name = 'x2paddle_image_arrays' if config[
'arch'] == 'YOLOv6' else 'x2paddle_images'
dataset = COCOTrainDataset( dataset = COCOTrainDataset(
dataset_dir=config['dataset_dir'], dataset_dir=config['dataset_dir'],
image_dir=config['val_image_dir'], image_dir=config['val_image_dir'],
anno_path=config['val_anno_path']) anno_path=config['val_anno_path'],
input_name=input_name)
data_loader = paddle.io.DataLoader( data_loader = paddle.io.DataLoader(
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0) dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
......
arch: YOLOv5
model_dir: ./yolov5s.onnx model_dir: ./yolov5s.onnx
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
train_image_dir: train2017 train_image_dir: train2017
......
arch: YOLOv6
model_dir: ./yolov6s.onnx model_dir: ./yolov6s.onnx
save_dir: ./analysis_results save_dir: ./analysis_results
quantizable_op_type: ["conv2d", "depthwise_conv2d"] quantizable_op_type: ["conv2d", "depthwise_conv2d"]
weight_quantize_type: 'channel_wise_abs_max' weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max' activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False is_full_quantize: False
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
......
arch: YOLOv6
model_dir: ./yolov6s.onnx model_dir: ./yolov6s.onnx
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
train_image_dir: train2017 train_image_dir: train2017
......
arch: YOLOv6
model_dir: ./yolov6s.onnx model_dir: ./yolov6s.onnx
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
train_image_dir: train2017 train_image_dir: train2017
......
arch: YOLOv7
model_dir: ./yolov7s.onnx model_dir: ./yolov7s.onnx
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
train_image_dir: train2017 train_image_dir: train2017
......
...@@ -10,10 +10,12 @@ class COCOValDataset(paddle.io.Dataset): ...@@ -10,10 +10,12 @@ class COCOValDataset(paddle.io.Dataset):
dataset_dir=None, dataset_dir=None,
image_dir=None, image_dir=None,
anno_path=None, anno_path=None,
img_size=[640, 640]): img_size=[640, 640],
input_name='x2paddle_images'):
self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.image_dir = image_dir self.image_dir = image_dir
self.img_size = img_size self.img_size = img_size
self.input_name = input_name
self.ann_file = os.path.join(dataset_dir, anno_path) self.ann_file = os.path.join(dataset_dir, anno_path)
self.coco = COCO(self.ann_file) self.coco = COCO(self.ann_file)
ori_ids = list(sorted(self.coco.imgs.keys())) ori_ids = list(sorted(self.coco.imgs.keys()))
...@@ -110,4 +112,4 @@ class COCOTrainDataset(COCOValDataset): ...@@ -110,4 +112,4 @@ class COCOTrainDataset(COCOValDataset):
img_id = self.ids[idx] img_id = self.ids[idx]
img = self._get_img_data_from_img_id(img_id) img = self._get_img_data_from_img_id(img_id)
img, scale_factor = self.image_preprocess(img, self.img_size) img, scale_factor = self.image_preprocess(img, self.img_size)
return {'x2paddle_image_arrays': img} return {self.input_name: img}
\ No newline at end of file
...@@ -50,10 +50,13 @@ def main(): ...@@ -50,10 +50,13 @@ def main():
global config global config
config = load_config(FLAGS.config_path) config = load_config(FLAGS.config_path)
input_name = 'x2paddle_image_arrays' if config[
'arch'] == 'YOLOv6' else 'x2paddle_images'
dataset = COCOTrainDataset( dataset = COCOTrainDataset(
dataset_dir=config['dataset_dir'], dataset_dir=config['dataset_dir'],
image_dir=config['val_image_dir'], image_dir=config['val_image_dir'],
anno_path=config['val_anno_path']) anno_path=config['val_anno_path'],
input_name=input_name)
train_loader = paddle.io.DataLoader( train_loader = paddle.io.DataLoader(
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0) dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
...@@ -79,7 +82,9 @@ def main(): ...@@ -79,7 +82,9 @@ def main():
hist_percent=0.999, hist_percent=0.999,
is_full_quantize=False, is_full_quantize=False,
bias_correction=False, bias_correction=False,
onnx_format=True) onnx_format=True,
skip_tensor_list=config['skip_tensor_list']
if 'skip_tensor_list' in config else None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -45,33 +45,41 @@ class AnalysisQuant(object): ...@@ -45,33 +45,41 @@ class AnalysisQuant(object):
model_filename=None, model_filename=None,
params_filename=None, params_filename=None,
eval_function=None, eval_function=None,
data_loader=None,
save_dir='analysis_results',
checkpoint_name='analysis_checkpoint.pkl',
num_histogram_plots=10,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
weight_quantize_type='abs_max', weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max', activation_quantize_type='moving_average_abs_max',
is_full_quantize=False, is_full_quantize=False,
batch_size=10, batch_size=10,
batch_nums=10, batch_nums=10, ):
data_loader=None,
save_dir='analysis_results',
checkpoint_name='analysis_checkpoint.pkl',
num_histogram_plots=10, ):
""" """
AnalysisQuant provides to analysis the sensitivity of each op in the model. AnalysisQuant provides to analysis the sensitivity of each op in the model.
Args: Args:
model_dir(str): the path of fp32 model that will be quantized model_dir(str): the path of fp32 model that will be quantized, it can also be '.onnx'
model_filename(str): the model file name of the fp32 model model_filename(str, optional): the model file name of the fp32 model
params_filename(str): the parameter file name of the fp32 model params_filename(str, optional): the parameter file name of the fp32 model
eval_function(function): eval function, define by yourself to return the metric of the inference program, can be used to judge the metric of quantized model. (TODO: optional) eval_function(function): eval function, define by yourself to return the metric of the inference program, can be used to judge the metric of quantized model. (TODO: optional)
quantizable_op_type(list, optional): op types that can be quantized
batch_size(int, optional): the batch size of DataLoader, default is 10
data_loader(Python Generator, Paddle.io.DataLoader, optional): the data_loader(Python Generator, Paddle.io.DataLoader, optional): the
Generator or Dataloader provides calibrate data, and it could Generator or Dataloader provides calibrate data, and it could
return a batch every time return a batch every time
save_dir(str, optional): the output dir that stores the analyzed information save_dir(str, optional): the output dir that stores the analyzed information
checkpoint_name(str, optional): the name of checkpoint file that saves analyzed information and avoids break off while ananlyzing checkpoint_name(str, optional): the name of checkpoint file that saves analyzed information and avoids break off while ananlyzing
num_histogram_plots: the number histogram plots you want to visilize, the plots will show in one PDF file in the save_dir num_histogram_plots: the number histogram plots you want to visilize, the plots will show in four PDF files for both best and worst and for both weight and act ops in the save_dir
quantizable_op_type(list): op types that can be quantized
weight_quantize_type(str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'
activation_quantize_type(str): quantization type for activation, now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max'
is_full_quantize(bool): if True, apply quantization to all supported quantizable op type. If False, only apply quantization to the input quantizable_op_type. Default is False.
batch_size(int, optional): the batch size of DataLoader, default is 10
batch_nums(int, optional): the number of calibrate data is 'batch_size*batch_nums'
""" """
if model_filename is None:
model_filename = 'model.pdmodel'
if params_filename is None:
params_filename = 'model.pdiparams'
self.model_dir = model_dir self.model_dir = model_dir
self.model_filename = model_filename self.model_filename = model_filename
self.params_filename = params_filename self.params_filename = params_filename
...@@ -99,10 +107,10 @@ class AnalysisQuant(object): ...@@ -99,10 +107,10 @@ class AnalysisQuant(object):
# load model # load model
[program, self.feed_list, self.fetch_list]= load_inference_model( \ [program, self.feed_list, self.fetch_list]= load_inference_model( \
model_dir, \ self.model_dir, \
executor=executor, \ executor=executor, \
model_filename=model_filename, \ model_filename=self.model_filename, \
params_filename=params_filename) params_filename=self.params_filename)
# create data_loader # create data_loader
self.data_loader = wrap_dataloader(data_loader, self.feed_list) self.data_loader = wrap_dataloader(data_loader, self.feed_list)
...@@ -167,7 +175,6 @@ class AnalysisQuant(object): ...@@ -167,7 +175,6 @@ class AnalysisQuant(object):
name, self.quant_layer_metrics[name])) name, self.quant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file)) _logger.info('Analysis file is saved in {}'.format(analysis_file))
self.calculate_histogram() self.calculate_histogram()
self.draw_pdf()
def save_checkpoint(self): def save_checkpoint(self):
if not os.path.exists(self.save_dir): if not os.path.exists(self.save_dir):
...@@ -222,24 +229,52 @@ class AnalysisQuant(object): ...@@ -222,24 +229,52 @@ class AnalysisQuant(object):
self.quant_layer_metrics[layer_name] = quant_metric self.quant_layer_metrics[layer_name] = quant_metric
self.save_checkpoint() self.save_checkpoint()
def get_sensitive_ops_name(self, graph, program): def get_act_name_by_weight(self, program, weight_names,
sensitive_weight_ops = self.sensitivity_ranklist[:self. persistable_var_names):
num_histogram_plots] act_ops_names = []
sensitive_act_ops = [] for op_name in weight_names:
persistable_var_names = []
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
for op_name in sensitive_weight_ops:
for block_id in range(len(program.blocks)): for block_id in range(len(program.blocks)):
for op in program.blocks[block_id].ops: for op in program.blocks[block_id].ops:
var_name_list = _get_op_input_var_names(op) var_name_list = _get_op_input_var_names(op)
if op_name in var_name_list: if op_name in var_name_list:
for var_name in var_name_list: for var_name in var_name_list:
if var_name not in persistable_var_names: if var_name not in persistable_var_names:
sensitive_act_ops.append(var_name) act_ops_names.append(var_name)
return sensitive_act_ops, sensitive_weight_ops return act_ops_names
def get_hist_ops_name(self, graph, program):
if self.num_histogram_plots <= 0:
return []
best_weight_ops = self.sensitivity_ranklist[::-1][:self.
num_histogram_plots]
worst_weight_ops = self.sensitivity_ranklist[:self.num_histogram_plots]
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
best_act_ops = self.get_act_name_by_weight(program, best_weight_ops,
persistable_var_names)
worst_act_ops = self.get_act_name_by_weight(program, worst_weight_ops,
persistable_var_names)
return [best_weight_ops, best_act_ops, worst_weight_ops, worst_act_ops]
def collect_ops_histogram(self, scope, ops):
hist = {}
for var_name in ops:
var_tensor = load_variable_data(scope, var_name)
var_tensor = np.array(var_tensor)
min_v = float(np.min(var_tensor))
max_v = float(np.max(var_tensor))
var_tensor = var_tensor.flatten()
_, hist_edges = np.histogram(
var_tensor.copy(),
bins=self.histogram_bins,
range=(min_v, max_v))
hist[var_name] = [var_tensor, hist_edges]
return hist
def calculate_histogram(self): def calculate_histogram(self):
''' '''
...@@ -258,13 +293,15 @@ class AnalysisQuant(object): ...@@ -258,13 +293,15 @@ class AnalysisQuant(object):
scope = global_scope() scope = global_scope()
graph = IrGraph(core.Graph(program.desc), for_test=False) graph = IrGraph(core.Graph(program.desc), for_test=False)
self.sensitive_act_ops, self.sensitive_weight_ops = self.get_sensitive_ops_name( ops_tobe_draw_hist = self.get_hist_ops_name(graph, program)
graph, program) if not ops_tobe_draw_hist:
return
for var in program.list_vars(): for var in program.list_vars():
if var.name in self.quantized_act_var_name: if var.name in self.quantized_act_var_name:
var.persistable = True var.persistable = True
# sample before collect histogram
batch_id = 0 batch_id = 0
for data in self.data_loader(): for data in self.data_loader():
executor.run(program=program, executor.run(program=program,
...@@ -276,56 +313,25 @@ class AnalysisQuant(object): ...@@ -276,56 +313,25 @@ class AnalysisQuant(object):
if batch_id >= self.batch_nums: if batch_id >= self.batch_nums:
break break
self.weight_histogram = {} pdf_names = [
self.act_histogram = {} 'best_weight_hist_result.pdf',
for var_name in self.sensitive_act_ops: 'best_act_hist_result.pdf',
var_tensor = load_variable_data(scope, var_name) 'worst_weight_hist_result.pdf',
var_tensor = np.array(var_tensor) 'worst_act_hist_result.pdf',
min_v = float(np.min(var_tensor)) ]
max_v = float(np.max(var_tensor)) for ops, save_pdf_name in zip(ops_tobe_draw_hist, pdf_names):
var_tensor = var_tensor.flatten() hist_data = self.collect_ops_histogram(scope, ops)
_, hist_edges = np.histogram( self.draw_pdf(hist_data, save_pdf_name)
var_tensor.copy(),
bins=self.histogram_bins, def draw_pdf(self, hist_data, save_pdf_name):
range=(min_v, max_v)) pdf_path = os.path.join(self.save_dir, save_pdf_name)
self.act_histogram[var_name] = [var_tensor, hist_edges] with PdfPages(pdf_path) as pdf:
for name in hist_data:
for var_name in self.sensitive_weight_ops: plt.hist(hist_data[name][0], bins=hist_data[name][1])
var_tensor = load_variable_data(scope, var_name)
var_tensor = np.array(var_tensor)
min_v = float(np.min(var_tensor))
max_v = float(np.max(var_tensor))
var_tensor = var_tensor.flatten()
_, hist_edges = np.histogram(
var_tensor.copy(),
bins=self.histogram_bins,
range=(min_v, max_v))
self.weight_histogram[var_name] = [var_tensor, hist_edges]
def draw_pdf(self):
pdf_path_a = os.path.join(self.save_dir, 'act_hist_result.pdf')
pdf_path_w = os.path.join(self.save_dir, 'weight_hist_result.pdf')
with PdfPages(pdf_path_a) as pdf:
for name in self.act_histogram:
plt.hist(
self.act_histogram[name][0],
bins=self.act_histogram[name][1])
plt.xlabel(name)
plt.ylabel("frequency")
plt.title("Hist of variable {}".format(name))
plt.show()
pdf.savefig()
plt.close()
with PdfPages(pdf_path_w) as pdf:
for name in self.weight_histogram:
plt.hist(
self.weight_histogram[name][0],
bins=self.weight_histogram[name][1])
plt.xlabel(name) plt.xlabel(name)
plt.ylabel("frequency") plt.ylabel("frequency")
plt.title("Hist of variable {}".format(name)) plt.title("Hist of variable {}".format(name))
plt.show() plt.show()
pdf.savefig() pdf.savefig()
plt.close() plt.close()
_logger.info('Histogram plots are saved in {} and {}'.format( _logger.info('Histogram plot is saved in {}'.format(pdf_path))
pdf_path_a, pdf_path_w))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册