未验证 提交 18d3524a 编写于 作者: C Chang Xu 提交者: GitHub

Update Quant Analysis (#1463)

上级 12ec1002
# 量化分析工具详细教程 # 量化分析工具详细教程
## 1. 量化分析工具功能 ## 1. 量化分析工具功能
1. 遍历模型所有层,依次量化该层,计算量化后精度。为所有只量化一层的模型精度排序,可视化不适合量化的层,以供量化时可选择性跳过不适合量化的层。 1. statistical_analyse:
2. 可视化激活箱状图,以供分析每个可量化OP的激活分布对量化效果的影响。 - 可视化激活和权重箱状图。箱状图可发现是否出现离群点。
3. 量化效果较好和较差的层的权重和激活直方分布图,以供分析其对量化效果的影响。 - 可视化权重和激活直方分布图。直方分布图可观察更具体的数值分布。
4. 输入预期精度,直接产出符合预期精度的量化模型。 - 提供量化前后权重和激活的具体数据信息,包括min,max,mean,std等
2. metric_error_analyse:
- 遍历量化模型的每层,并计算量化后精度。该功能可以定位具体某层导致的量化损失。
3. get_target_quant_model:
- 输入预期精度,直接产出符合预期精度的量化模型。
## 2. paddleslim.quant.AnalysisQuant 可传入参数解析 ## 2. paddleslim.quant.AnalysisQuant 可传入参数解析
```yaml ```yaml
...@@ -14,25 +21,23 @@ params_filename: None ...@@ -14,25 +21,23 @@ params_filename: None
eval_function: None eval_function: None
data_loader: None data_loader: None
save_dir: 'analysis_results' save_dir: 'analysis_results'
checkpoint_name: 'analysis_checkpoint.pkl' resume: False
num_histogram_plots: 10
ptq_config ptq_config
``` ```
- model_dir: 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可。 - model_dir: 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可。
- model_filename: 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入。 - model_filename: 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入。
- params_filename: 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入。 - params_filename: 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入。
- eval_function:目前不支持为None,需要传入自定义的验证函数。 - eval_function:若需要验证精度,需要传入自定义的验证函数。
- data_loader:模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader。 - data_loader:模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader。
- save_dir:分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results` - save_dir:分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`
- checkpoint_name:由于模型可能存在大量层需要分析,因此分析过程中会中间保存结果,如果程序中断会自动加载已经分析好的结果,默认为`analysis_checkpoint.pkl` - resume:是否加载中间分析文件
- num_histogram_plots:需要可视化的直方分布图数量。可视化量化效果最好和最坏的该数量个权重和激活的分布图。默认为10。若不需要可视化直方图,设置为0即可。
- ptq_config:可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post) - ptq_config:可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post)
## 3. 量化分析工具的使用 ## 3. 量化分析工具的使用
1. 创建量化分析工具 **创建量化分析工具**
``` ```
analyzer = AnalysisQuant( analyzer = AnalysisQuant(
model_dir=config["model_dir"], model_dir=config["model_dir"],
...@@ -44,45 +49,47 @@ analyzer = AnalysisQuant( ...@@ -44,45 +49,47 @@ analyzer = AnalysisQuant(
ptq_config=config['PTQ']) ptq_config=config['PTQ'])
``` ```
2. 绘制所有可量化层的激活箱状图 **统计分析**
``` ```
analyzer.plot_activation_distribution() analyzer.statistical_analyse()
``` ```
以检测模型中的picodet-s为例,从以下激活箱状图(部分层)中可以发现,`conv2d_7.w_0``conv2d_9.w_0` 这两层的激活输入有大量离群点,会导致量化效果较差。 调用该接口,会统计量化前和量化后每一个可量化权重和其对应激活的数据。只使用该接口可以不输入Eval Function,但需要输入DataLoader,少量数据即可。会产出以下文件:
- `fp_activation_boxplot.pdf`:量化前Float数据类型的模型激活箱状图
<p align="center"> - `fp_weight_boxplot.pdf`:量化前Float数据类型的模型权重箱状图
<img src="./detection/images/act_distribution.png" width=849 hspace='10'/> <br /> - `quantized_activation_boxplot.pdf`:量化后INT数据类型的模型激活箱状图
</p> - `quantized_weight_boxplot.pdf`:量化后INT数据类型的模型权重箱状图
- `fp_activation_histplot.pdf`:量化前Float数据类型的模型激活直方图
3. 计算每层的量化敏感度并且绘制直方分布图 - `fp_weight_histplot.pdf`:量化前Float数据类型的模型权重直方图
- `quantized_activation_histplot.pdf`:量化后INT数据类型的模型激活直方图
- `quantized_weight_histplot.pdf`:量化后INT数据类型的模型权重直方图
- `statistic.csv`:量化前后权重和激活的具体数据信息,表格中会保存的信息有:
- Var Name: Variable的名称
- Var Type:Variable的类型,Weight或Activation
- Corresponding Weight Name:如果为Activation,其对应的Weight名称
- FP32 Min:量化前Float数据类型的最小值
- FP32 Max:量化前Float数据类型的最大值
- FP32 Mean:量化前Float数据类型的平均值
- FP32 Std:量化前Float数据类型的方差值
- Quantized Min:量化后INT数据类型的最小值
- Quantized Max:量化后INT数据类型的最大值
- Quantized Mean:量化后INT数据类型的平均值
- Quantized Std:量化后INT数据类型的方差值
- Diff Min:量化前后该Variable的相差的最小值
- Diff Max:量化前后该Variable的相差的最大值
- Diff Mean:量化前后该Variable的相差的平均值
- Diff Std:量化前后该Variable的相差的方差值
**精度误差分析**
``` ```
analyzer.compute_quant_sensitivity(plot_hist=True) analyzer.metric_error_analyse()
``` ```
`plot_hist` 默认为True,如不需要获得量化效果较好和较差的层的权重和激活分布图,可设置为False。 调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。
量化分析工具会默认会产出以下目录:
```
analysis_results/
├── analysis.txt
├── best_weight_hist_result.pdf
├── best_act_hist_result.pdf
├── worst_weight_hist_result.pdf
├── worst_act_hist_result.pdf
```
- 所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。
- 通过设置参数`num_histogram_plots`,可选择绘出该数量个量化效果最好和最差层的weight和activation的直方分布图,将以PDF形式保存在 `./analysis_results` 文件夹下, 分别保存为 `best_weight_hist_result.pdf``best_act_hist_result.pdf``worst_weight_hist_result.pdf``worst_act_hist_result.pdf` 中以供对比分析。
以检测模型中的picodet-s为例,从`analysis.txt`可以发现`conv2d_1.w_0``conv2d_3.w_0``conv2d_5.w_0``conv2d_7.w_0``conv2d_9.w_0` 这些层会导致较大的精度损失。这一现象符合对激活箱状图的观察。
<p align="center">
<img src="./detection/images/picodet_analysis.png" width=849 hspace='10'/> <br />
</p>
4. 直接产出符合预期精度的量化模型 **直接产出符合预期精度的量化模型**
``` ```
analyzer.get_target_quant_model(target_metric) analyzer.get_target_quant_model(target_metric)
``` ```
......
...@@ -130,7 +130,8 @@ python eval.py --config_path=./configs/ppyoloe_s_ptq.yaml ...@@ -130,7 +130,8 @@ python eval.py --config_path=./configs/ppyoloe_s_ptq.yaml
- 要测试的模型路径可以在配置文件中`model_dir`字段下进行修改。 - 要测试的模型路径可以在配置文件中`model_dir`字段下进行修改。
#### 3.6 提高离线量化精度 #### 3.6 提高离线量化精度
本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisQuant```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。 本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisQuant```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。```paddleslim.quant.AnalysisQuant```详解见[AnalysisQuant.md](../../../../docs/zh_cn/tutorials/quant/AnalysisQuant.md)
经过多个实验,包括尝试多种激活算法(avg,KL等)、weight的量化方式(abs_max,channel_wise_abs_max),对PicoDet-s进行离线量化后精度均为0,以PicoDet-s为例,量化分析工具具体使用方法如下: 经过多个实验,包括尝试多种激活算法(avg,KL等)、weight的量化方式(abs_max,channel_wise_abs_max),对PicoDet-s进行离线量化后精度均为0,以PicoDet-s为例,量化分析工具具体使用方法如下:
......
...@@ -168,14 +168,11 @@ def main(): ...@@ -168,14 +168,11 @@ def main():
eval_function=eval_function, eval_function=eval_function,
data_loader=ptq_data_loader, data_loader=ptq_data_loader,
save_dir=config['save_dir'], save_dir=config['save_dir'],
ptq_config=ptq_config) ptq_config=ptq_config,
resume=True, )
# plot the boxplot of activations of quantizable weights analyzer.statistical_analyse()
analyzer.plot_activation_distribution() analyzer.metric_error_analyse()
# get the rank of sensitivity of each quantized layer
# plot the histogram plot of best and worst activations and weights if plot_hist is True
analyzer.compute_quant_sensitivity(plot_hist=config['plot_hist'])
if config['get_target_quant_model']: if config['get_target_quant_model']:
if 'FastEvalDataset' in config: if 'FastEvalDataset' in config:
......
...@@ -116,7 +116,7 @@ python eval.py --config_path=./configs/yolov5s_ptq.yaml ...@@ -116,7 +116,7 @@ python eval.py --config_path=./configs/yolov5s_ptq.yaml
#### 3.6 提高离线量化精度 #### 3.6 提高离线量化精度
本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisQuant```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。 本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisQuant```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。```paddleslim.quant.AnalysisQuant```详解见[AnalysisQuant.md](../../../../docs/zh_cn/tutorials/quant/AnalysisQuant.md)
由于YOLOv6离线量化效果较差,以YOLOv6为例,量化分析工具具体使用方法如下: 由于YOLOv6离线量化效果较差,以YOLOv6为例,量化分析工具具体使用方法如下:
...@@ -148,6 +148,8 @@ python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_di ...@@ -148,6 +148,8 @@ python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_di
如想分析之后直接产出符合目标精度的量化模型,可在 `yolov6s_analysis.yaml` 中将`get_target_quant_model`设置为True,并填写 `target_metric`,注意 `target_metric` 不能比原模型精度高。 如想分析之后直接产出符合目标精度的量化模型,可在 `yolov6s_analysis.yaml` 中将`get_target_quant_model`设置为True,并填写 `target_metric`,注意 `target_metric` 不能比原模型精度高。
**加速分析过程** **加速分析过程**
使用量化分析工具时,因需要逐层量化模型并进行验证,因此过程可能较慢,若想加速分析过程,可以在配置文件中设置 `fast_val_anno_path` ,输入一个图片数量较少的annotation文件路径。注意,用少量数据验证的模型精度不一定等于全量数据验证的模型精度,若只需分析时获得不同层量化效果的相对排序,可以使用少量数据集;若要求准确精度,请使用全量验证数据集。如需要全量验证数据,将 `fast_val_anno_path` 设置为None即可。 使用量化分析工具时,因需要逐层量化模型并进行验证,因此过程可能较慢,若想加速分析过程,可以在配置文件中设置 `fast_val_anno_path` ,输入一个图片数量较少的annotation文件路径。注意,用少量数据验证的模型精度不一定等于全量数据验证的模型精度,若只需分析时获得不同层量化效果的相对排序,可以使用少量数据集;若要求准确精度,请使用全量验证数据集。如需要全量验证数据,将 `fast_val_anno_path` 设置为None即可。
......
...@@ -113,12 +113,8 @@ def main(): ...@@ -113,12 +113,8 @@ def main():
resume=FLAGS.resume, resume=FLAGS.resume,
ptq_config=ptq_config) ptq_config=ptq_config)
# plot the boxplot of activations of quantizable weights analyzer.statistical_analyse()
analyzer.plot_activation_distribution() analyzer.metric_error_analyse()
# get the rank of sensitivity of each quantized layer
# plot the histogram plot of best and worst activations and weights if plot_hist is True
analyzer.compute_quant_sensitivity(plot_hist=config['plot_hist'])
if config['get_target_quant_model']: if config['get_target_quant_model']:
if config['fast_val_anno_path'] is not None: if config['fast_val_anno_path'] is not None:
......
...@@ -19,6 +19,7 @@ import copy ...@@ -19,6 +19,7 @@ import copy
import logging import logging
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages from matplotlib.backends.backend_pdf import PdfPages
import csv
import numpy as np import numpy as np
import random import random
import tempfile import tempfile
...@@ -28,7 +29,7 @@ from paddle.fluid import framework ...@@ -28,7 +29,7 @@ from paddle.fluid import framework
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid.executor import global_scope from paddle.fluid.executor import global_scope
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization.utils import _get_op_input_var_names, load_variable_data from paddle.fluid.contrib.slim.quantization.utils import _get_op_input_var_names, _get_op_output_var_names, load_variable_data
from .quanter import quant_post from .quanter import quant_post
from ..core import GraphWrapper from ..core import GraphWrapper
from ..common import get_logger from ..common import get_logger
...@@ -47,7 +48,6 @@ class AnalysisQuant(object): ...@@ -47,7 +48,6 @@ class AnalysisQuant(object):
eval_function=None, eval_function=None,
data_loader=None, data_loader=None,
save_dir='analysis_results', save_dir='analysis_results',
num_histogram_plots=10,
resume=False, resume=False,
ptq_config=None): ptq_config=None):
""" """
...@@ -79,7 +79,6 @@ class AnalysisQuant(object): ...@@ -79,7 +79,6 @@ class AnalysisQuant(object):
self.quant_layer_names = [] self.quant_layer_names = []
self.checkpoint_name = os.path.join(save_dir, 'analysis_checkpoint.pkl') self.checkpoint_name = os.path.join(save_dir, 'analysis_checkpoint.pkl')
self.quant_layer_metrics = {} self.quant_layer_metrics = {}
self.num_histogram_plots = num_histogram_plots
self.ptq_config = ptq_config self.ptq_config = ptq_config
self.batch_nums = ptq_config[ self.batch_nums = ptq_config[
'batch_nums'] if 'batch_nums' in ptq_config else 10 'batch_nums'] if 'batch_nums' in ptq_config else 10
...@@ -112,25 +111,12 @@ class AnalysisQuant(object): ...@@ -112,25 +111,12 @@ class AnalysisQuant(object):
# create data_loader # create data_loader
self.data_loader = wrap_dataloader(data_loader, self.feed_list) self.data_loader = wrap_dataloader(data_loader, self.feed_list)
# evaluate before quant # quant model to get quantizable ops
# TODO: self.eval_function can be None post_training_quantization = self.create_ptq(executor, None, 'avg')
if self.eval_function is not None:
self.base_metric = self.eval_function( _logger.info('Run PTQ before analysis.')
executor, program, self.feed_list, self.fetch_list)
_logger.info('Before quantized, the accuracy of the model is: {}'.
format(self.base_metric))
# quant and evaluate after quant (skip_list = None)
post_training_quantization = PostTrainingQuantization(
executor=executor,
data_loader=self.data_loader,
model_dir=self.model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
skip_tensor_list=None,
onnx_format=self.onnx_format,
**self.ptq_config)
program = post_training_quantization.quantize() program = post_training_quantization.quantize()
if self.onnx_format: if self.onnx_format:
post_training_quantization.save_quantized_model( post_training_quantization.save_quantized_model(
self.temp_save_path, self.temp_save_path,
...@@ -141,16 +127,14 @@ class AnalysisQuant(object): ...@@ -141,16 +127,14 @@ class AnalysisQuant(object):
executor, executor,
model_filename='model.pdmodel', model_filename='model.pdmodel',
params_filename='model.pdiparams') params_filename='model.pdiparams')
self.quant_metric = self.eval_function(executor, program,
self.feed_list, self.fetch_list)
_logger.info('After quantized, the accuracy of the model is: {}'.format(
self.quant_metric))
# get quantized weight and act var name # get quantized weight and act var name
self.quantized_weight_var_name = post_training_quantization._quantized_weight_var_name self.quantized_weight_var_name = post_training_quantization._quantized_weight_var_name
self.quantized_act_var_name = post_training_quantization._quantized_act_var_name self.quantized_act_var_name = post_training_quantization._quantized_act_var_name
self.support_quant_val_name_list = self.quantized_weight_var_name if not self.is_full_quantize else list( self.support_quant_val_name_list = self.quantized_weight_var_name if not self.is_full_quantize else list(
self.quantized_act_var_name) self.quantized_act_var_name)
self.weight_names = list(self.quantized_weight_var_name)
self.act_names = list(self.quantized_act_var_name)
executor.close() executor.close()
# load tobe_analyized_layer from checkpoint # load tobe_analyized_layer from checkpoint
...@@ -160,146 +144,110 @@ class AnalysisQuant(object): ...@@ -160,146 +144,110 @@ class AnalysisQuant(object):
list(self.quant_layer_metrics.keys())) list(self.quant_layer_metrics.keys()))
self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer)) self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer))
def compute_quant_sensitivity(self, plot_hist=True):
'''
compute the sensitivity of quantized layers by eval function
'''
assert self.data_loader is not None, "When computing the sensitivity of quantized layers, the data loader is needed"
assert self.eval_function is not None, "When computing the sensitivity of quantized layers, the eval function is needed"
self.eval_quant_model()
self.sensitivity_ranklist = sorted(
self.quant_layer_metrics,
key=self.quant_layer_metrics.get,
reverse=False)
_logger.info('Finished computing the sensitivity of the model.')
for name in self.sensitivity_ranklist:
_logger.info("quant layer name: {}, eval metric: {}".format(
name, self.quant_layer_metrics[name]))
analysis_file = os.path.join(self.save_dir, "analysis.txt")
with open(analysis_file, "w") as analysis_ret_f:
for name in self.sensitivity_ranklist:
analysis_ret_f.write(
"quant layer name: {}, eval metric: {}\n".format(
name, self.quant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file))
if plot_hist:
self.calculate_histogram()
def save_checkpoint(self): def save_checkpoint(self):
if not os.path.exists(self.save_dir): if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir) os.makedirs(self.save_dir)
with open(self.checkpoint_name, 'wb') as f: with open(self.checkpoint_name, 'wb') as f:
pickle.dump(self.quant_layer_metrics, f) pickle.dump(self.quant_layer_metrics, f)
_logger.info('save checkpoint to {}'.format(self.checkpoint_name)) _logger.info('Save checkpoint to {}.'.format(self.checkpoint_name))
def load_checkpoint(self): def load_checkpoint(self):
if not os.path.exists(self.checkpoint_name): if not os.path.exists(self.checkpoint_name):
_logger.info('Checkpoint path {} does not exist.'.format(
self.checkpoint_name))
return False return False
with open(self.checkpoint_name, 'rb') as f: with open(self.checkpoint_name, 'rb') as f:
self.quant_layer_metrics = pickle.load(f) self.quant_layer_metrics = pickle.load(f)
_logger.info('load checkpoint from {}'.format(self.checkpoint_name)) _logger.info('Load checkpoint from {}.'.format(self.checkpoint_name))
return True return True
def plot_activation_distribution(self, axis=None): def save_csv(self, data, save_name, csv_columns):
save_path = os.path.join(self.save_dir, save_name)
with open(save_path, 'w') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
writer.writeheader()
for d in data:
writer.writerow(d)
_logger.info('Activation Statistic is saved in {}'.format(save_path))
def create_ptq(self, executor, skip_tensor_list, algo):
return PostTrainingQuantization(
executor=executor,
data_loader=self.data_loader,
model_dir=self.model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
skip_tensor_list=skip_tensor_list,
algo=algo, # avg fastest
onnx_format=self.onnx_format,
**self.ptq_config)
def sampling(self, executor, program, scope):
batch_id = 0
for data in self.data_loader():
executor.run(program=program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False,
scope=scope)
batch_id += 1
if batch_id >= self.batch_nums:
break
def eval_quant_model(self, skip_list):
executor = paddle.static.Executor(self.places)
post_training_quantization = self.create_ptq(
executor, skip_list, algo='avg')
program = post_training_quantization.quantize()
_logger.info('Evaluating...')
if self.onnx_format:
post_training_quantization.save_quantized_model(
self.temp_save_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
program, _, _ = load_inference_model(
self.temp_save_path,
executor,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
quant_metric = self.eval_function(executor, program, self.feed_list,
self.fetch_list)
executor.close()
return quant_metric
def metric_error_analyse(self):
''' '''
Collect and plot the distribution of the activation of each weight layer. Evaluate the quantized models, which are generated by quantizing each weight operator one by one. The results will be saved into analysis.txt.
''' '''
devices = paddle.device.get_device().split(':')[0] assert self.data_loader is not None, "When computing the sensitivity of quantized layers, the data loader is needed"
places = paddle.device._convert_to_place(devices) assert self.eval_function is not None, "When computing the sensitivity of quantized layers, the eval function is needed"
executor = paddle.static.Executor(places)
# evaluate before quant
_logger.info('Start to evaluate the base model.')
executor = paddle.static.Executor(self.places)
[program, feed_list, fetch_list]= load_inference_model( \ [program, feed_list, fetch_list]= load_inference_model( \
self.model_dir, \ self.model_dir, \
executor=executor, \ executor=executor, \
model_filename=self.model_filename, \ model_filename=self.model_filename, \
params_filename=self.params_filename) params_filename=self.params_filename)
self.base_metric = self.eval_function(executor, program, feed_list,
fetch_list)
_logger.info('Before quantized, the accuracy of the model is: {}'.
format(self.base_metric))
scope = global_scope() # evaluate before quant
_logger.info('Start to evaluate the quantized model.')
persistable_var_names = [] self.quant_metric = self.eval_quant_model(None)
for var in program.list_vars(): _logger.info('After quantized, the accuracy of the model is: {}'.format(
if var.persistable: self.quant_metric))
persistable_var_names.append(var.name)
weight_names = sorted(list(self.quantized_weight_var_name))
acts_weight_map = self.get_weight_act_map(program, weight_names,
persistable_var_names)
all_acts = list(acts_weight_map.keys())
all_weights = [acts_weight_map[act] for act in all_acts]
act_distribution = []
for var_name in all_acts:
var_tensor = load_variable_data(scope, var_name)
if axis is None:
var_tensor = var_tensor.flatten()
else:
var_tensor = var_tensor.reshape(
[-1, var_tensor.shape[axis]]).abs().max(axis=-1)
sample_num = len(var_tensor) if len(var_tensor) < 1000 else 1000
var_tensor = random.sample(list(var_tensor), sample_num)
act_distribution.append(var_tensor)
all_values = sum(act_distribution, [])
max_value = np.max(all_values)
min_value = np.min(all_values)
pdf_path = os.path.join(self.save_dir, 'activation_distribution.pdf')
with PdfPages(pdf_path) as pdf:
for i in range(0, len(act_distribution), 20):
r = i + 20 if i + 20 < len(act_distribution) else len(
act_distribution)
plt.boxplot(
act_distribution[i:r],
labels=all_weights[i:r],
showbox=True,
patch_artist=True)
plt.xticks(rotation=90)
plt.tick_params(axis='x')
plt.ylim([min_value, max_value])
plt.xlabel('Weight Name')
plt.ylabel("Activation Distribution")
plt.tight_layout()
plt.show()
pdf.savefig()
plt.close()
_logger.info('Distribution plots is saved in {}'.format(pdf_path))
def eval_quant_model(self): # For each layer, quantize the weight op and evaluate the quantized model.
'''
For each layer, quantize the weight op and evaluate the quantized model.
'''
for i, layer_name in enumerate(self.tobe_analyized_layer): for i, layer_name in enumerate(self.tobe_analyized_layer):
_logger.info('Checking {}/{} quant model: quant layer {}'.format( _logger.info('Checking {}/{} quant model: quant layer {}'.format(
i + 1, len(self.tobe_analyized_layer), layer_name)) i + 1, len(self.tobe_analyized_layer), layer_name))
skip_list = copy.copy(list(self.support_quant_val_name_list)) skip_list = copy.copy(list(self.support_quant_val_name_list))
skip_list.remove(layer_name) skip_list.remove(layer_name)
quant_metric = self.eval_quant_model(skip_list)
executor = paddle.static.Executor(self.places)
post_training_quantization = PostTrainingQuantization(
executor=executor,
data_loader=self.data_loader,
model_dir=self.model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
skip_tensor_list=skip_list,
onnx_format=self.onnx_format,
**self.ptq_config)
program = post_training_quantization.quantize()
_logger.info('Evaluating...')
if self.onnx_format:
post_training_quantization.save_quantized_model(
self.temp_save_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
program, _, _ = load_inference_model(
self.temp_save_path,
executor,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
quant_metric = self.eval_function(executor, program, self.feed_list,
self.fetch_list)
executor.close()
_logger.info( _logger.info(
"Quantized layer name: {}, eval metric: {}, the loss caused by this layer: {}". "Quantized layer name: {}, eval metric: {}, the loss caused by this layer: {}".
format(layer_name, format(layer_name,
...@@ -307,117 +255,261 @@ class AnalysisQuant(object): ...@@ -307,117 +255,261 @@ class AnalysisQuant(object):
round(self.base_metric - quant_metric, 4))) round(self.base_metric - quant_metric, 4)))
self.quant_layer_metrics[layer_name] = quant_metric self.quant_layer_metrics[layer_name] = quant_metric
self.save_checkpoint() self.save_checkpoint()
if self.onnx_format: if self.onnx_format:
self.temp_root_path.cleanup() self.temp_root_path.cleanup()
def get_weight_act_map(self, program, weight_names, persistable_var_names): self.sensitivity_ranklist = sorted(
act_names = {} self.quant_layer_metrics,
for op_name in weight_names: key=self.quant_layer_metrics.get,
for block_id in range(len(program.blocks)): reverse=False)
for op in program.blocks[block_id].ops:
var_name_list = _get_op_input_var_names(op)
if op_name in var_name_list:
for var_name in var_name_list:
if var_name not in persistable_var_names:
act_names[var_name] = op_name
return act_names
def get_hist_ops_name(self, graph, program):
if self.num_histogram_plots <= 0:
return []
best_weight_ops = self.sensitivity_ranklist[::-1][:self.
num_histogram_plots]
worst_weight_ops = self.sensitivity_ranklist[:self.num_histogram_plots]
persistable_var_names = [] _logger.info('Finished computing the sensitivity of the model.')
for var in program.list_vars(): for name in self.sensitivity_ranklist:
if var.persistable: _logger.info("quant layer name: {}, eval metric: {}".format(
persistable_var_names.append(var.name) name, self.quant_layer_metrics[name]))
best_acts = self.get_weight_act_map(program, best_weight_ops, analysis_file = os.path.join(self.save_dir, "analysis.txt")
persistable_var_names) with open(analysis_file, "w") as analysis_ret_f:
worst_acts = self.get_weight_act_map(program, worst_weight_ops, for name in self.sensitivity_ranklist:
persistable_var_names) analysis_ret_f.write(
return [best_weight_ops, best_acts, worst_weight_ops, worst_acts] "quant layer name: {}, eval metric: {}\n".format(
name, self.quant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file))
def collect_tensors_histogram(self, scope, ops): def collect_vars(self, scope, var_names):
hist = {} all_vars = {}
for var_name in ops: for var_name in var_names:
var_tensor = load_variable_data(scope, var_name) var_tensor = load_variable_data(scope, var_name)
var_tensor = np.array(var_tensor) all_vars[var_name] = var_tensor
min_v = float(np.min(var_tensor)) return all_vars
max_v = float(np.max(var_tensor))
var_tensor = var_tensor.flatten()
_, hist_edges = np.histogram(
var_tensor.copy(),
bins=self.histogram_bins,
range=(min_v, max_v))
hist[var_name] = [var_tensor, hist_edges]
return hist
def calculate_histogram(self):
'''
Sample histograms for the weight and corresponding act tensors
'''
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
executor = paddle.static.Executor(places)
def collect_base_stat(self):
_logger.info('Collecting Statistic Before PTQ...')
executor = paddle.static.Executor(self.places)
[program, feed_list, fetch_list]= load_inference_model( \ [program, feed_list, fetch_list]= load_inference_model( \
self.model_dir, \ self.model_dir, \
executor=executor, \ executor=executor, \
model_filename=self.model_filename, \ model_filename=self.model_filename, \
params_filename=self.params_filename) params_filename=self.params_filename)
scope = global_scope()
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
self.acts_weight_map = self.get_weight_act_map(
program, self.weight_names, persistable_var_names)
activations_names = list(self.acts_weight_map.keys())
# sample
self.sampling(executor, program, scope)
before_act_data = self.collect_vars(scope, activations_names)
before_weight_data = self.collect_vars(scope, self.weight_names)
executor.close()
return before_act_data, before_weight_data
def collect_quant_stat(self):
_logger.info('Collecting Statistic After PTQ...')
executor = paddle.static.Executor(self.places)
scope = global_scope() scope = global_scope()
post_training_quantization = self.create_ptq(executor, None, algo='avg')
program = post_training_quantization.quantize()
graph = IrGraph(core.Graph(program.desc), for_test=False) persistable_var_names = []
tensors_tobe_draw_hist = self.get_hist_ops_name(graph, program) for var in program.list_vars():
if not tensors_tobe_draw_hist: if var.persistable:
return persistable_var_names.append(var.name)
quant_weight_names = self.weight_names
dequant_act_names = ["%s.quantized" % (n) for n in self.acts_weight_map]
for var in program.list_vars(): for var in program.list_vars():
if var.name in self.quantized_act_var_name: if var.name in dequant_act_names:
var.persistable = True var.persistable = True
# sample before collect histogram self.sampling(executor, program, scope)
batch_id = 0
for data in self.data_loader(): after_act_data = self.collect_vars(scope, dequant_act_names)
executor.run(program=program, after_weight_data = self.collect_vars(scope, quant_weight_names)
feed=data, executor.close()
fetch_list=fetch_list, return after_act_data, after_weight_data
return_numpy=False,
scope=scope) def statistical_analyse(self, analysis_axis=None):
batch_id += 1
if batch_id >= self.batch_nums: self.act_data, self.weight_data = self.collect_base_stat()
break self.quant_act_data, self.dequant_weight_data = self.collect_quant_stat(
)
pdf_names = [ fp_q_act_name_map = {
'best_weight_hist_result.pdf', n: "%s.quantized" % (n)
'best_act_hist_result.pdf', for n in self.acts_weight_map
'worst_weight_hist_result.pdf', }
'worst_act_hist_result.pdf', act_statistic, box_fp_dist, box_q_dist, hist_fp_dist, hist_q_dist = self.collect_statistic(
self.act_data,
self.quant_act_data,
fp_q_act_name_map,
is_weight=False,
axis=analysis_axis)
self.plot_box_distribution(box_fp_dist,
list(self.acts_weight_map.keys()),
'fp_activation_boxplot.pdf')
self.plot_box_distribution(box_q_dist,
list(self.acts_weight_map.keys()),
'quantized_activation_boxplot.pdf')
self.plot_hist_distribution(hist_fp_dist, 'fp_activation_histplot.pdf')
self.plot_hist_distribution(hist_q_dist,
'quantized_activation_histplot.pdf')
weight_statistic, box_fp_dist, box_q_dist, hist_fp_dist, hist_q_dist = self.collect_statistic(
self.weight_data,
self.dequant_weight_data,
None,
is_weight=True,
axis=analysis_axis)
self.plot_box_distribution(box_fp_dist,
list(self.quantized_weight_var_name),
'fp_weight_boxplot.pdf')
self.plot_box_distribution(box_q_dist,
list(self.quantized_weight_var_name),
'quantized_weight_boxplot.pdf')
self.plot_hist_distribution(hist_fp_dist, 'fp_weight_histplot.pdf')
self.plot_hist_distribution(hist_q_dist,
'quantized_weight_histplot.pdf')
statistic = act_statistic + weight_statistic
csv_columns = [
'Var Name', 'Var Type', 'Corresponding Weight Name', 'FP32 Min',
'FP32 Max', 'FP32 Mean', 'FP32 Std', 'Quantized Min',
'Quantized Max', 'Quantized Mean', 'Quantized Std', 'Diff Min',
'Diff Max', 'Diff Mean', 'Diff Std'
] ]
for tensors, save_pdf_name in zip(tensors_tobe_draw_hist, pdf_names): self.save_csv(statistic, 'statistic.csv', csv_columns)
if isinstance(tensors, list):
hist_data = self.collect_tensors_histogram(scope, tensors) def get_weight_act_map(self, program, weight_names, persistable_var_names):
self.draw_hist_pdf(hist_data, save_pdf_name, None) weight_act_map = {}
for op_name in weight_names:
for block_id in range(len(program.blocks)):
for op in program.blocks[block_id].ops:
var_name_list = _get_op_input_var_names(op)
if op_name in var_name_list:
for var_name in var_name_list:
if var_name not in persistable_var_names:
weight_act_map[var_name] = op_name
return weight_act_map
def collect_statistic(self,
fp_tensors,
quant_tensors,
var_name_map,
is_weight,
axis=None):
statistic = []
box_fp_dist, box_q_dist = [], []
hist_fp_dist, hist_q_dist = {}, {}
for var_name in fp_tensors:
fp_tensor = fp_tensors[var_name]
quant_name = var_name_map[
var_name] if var_name_map is not None else var_name
quant_tensor = quant_tensors[quant_name]
diff = fp_tensor - quant_tensor
fp_min = round(fp_tensor.min(), 4)
fp_max = round(fp_tensor.max(), 4)
fp_mean = round(fp_tensor.mean(), 4)
fp_std = round(fp_tensor.std(), 4)
q_min = round(quant_tensor.min(), 4)
q_max = round(quant_tensor.max(), 4)
q_mean = round(quant_tensor.mean(), 4)
q_std = round(quant_tensor.std(), 4)
diff_min = round(diff.min(), 4)
diff_max = round(diff.max(), 4)
diff_mean = round(diff.mean(), 4)
diff_std = round(diff.std(), 4)
stat = {
'Var Name': var_name,
'Var Type': 'Weight' if is_weight else 'Activation',
'Corresponding Weight Name': self.acts_weight_map[var_name]
if not is_weight else None,
'FP32 Min': fp_min,
'FP32 Max': fp_max,
'FP32 Mean': fp_mean,
'FP32 Std': fp_std,
'Quantized Min': q_min,
'Quantized Max': q_max,
'Quantized Mean': q_mean,
'Quantized Std': q_std,
'Diff Min': diff_min,
'Diff Max': diff_max,
'Diff Mean': diff_mean,
'Diff Std': diff_std,
}
statistic.append(stat)
# for boxplot
if axis is None:
box_fp_tensor = fp_tensor.flatten()
box_q_tensor = quant_tensor.flatten()
else: else:
hist_data = self.collect_tensors_histogram(scope, box_fp_tensor = fp_tensor.reshape(
list(tensors.keys())) [-1, fp_tensor.shape[axis]]).abs().max(axis=-1)
self.draw_hist_pdf(hist_data, save_pdf_name, tensors) box_q_tensor = quant_tensor.reshape(
[-1, quant_tensor.shape[axis]]).abs().max(axis=-1)
sample_num = len(box_fp_tensor) if len(
box_fp_tensor) < 1000 else 1000
box_fp_tensor = random.sample(list(box_fp_tensor), sample_num)
box_q_tensor = random.sample(list(box_q_tensor), sample_num)
box_fp_dist.append(box_fp_tensor)
box_q_dist.append(box_q_tensor)
# for histplot
_, hist_edges = np.histogram(
fp_tensor.copy(), bins=50, range=(fp_min, fp_max))
hist_fp_dist[var_name] = [fp_tensor.flatten(), hist_edges]
_, hist_edges = np.histogram(
quant_tensor.copy(), bins=50, range=(q_min, q_max))
hist_q_dist[quant_name] = [quant_tensor.flatten(), hist_edges]
return statistic, box_fp_dist, box_q_dist, hist_fp_dist, hist_q_dist
def plot_box_distribution(self, distribution, labels, save_name):
all_values = sum(distribution, [])
max_value = np.max(all_values)
min_value = np.min(all_values)
pdf_path = os.path.join(self.save_dir, save_name)
with PdfPages(pdf_path) as pdf:
for i in range(0, len(distribution), 20):
r = i + 20 if i + 20 < len(distribution) else len(distribution)
plt.boxplot(
distribution[i:r],
labels=labels[i:r],
showbox=True,
patch_artist=True)
plt.xticks(rotation=90)
plt.tick_params(axis='x')
plt.ylim([min_value, max_value])
if 'act' in save_name:
plt.xlabel('Activation Name')
else:
plt.xlabel('Weight Name')
plt.ylabel("Box Distribution")
plt.tight_layout()
plt.show()
pdf.savefig()
plt.close()
_logger.info('Distribution plots is saved in {}'.format(pdf_path))
def draw_hist_pdf(self, hist_data, save_pdf_name, weight_act_map): def plot_hist_distribution(self, hist_data, save_name):
pdf_path = os.path.join(self.save_dir, save_pdf_name) pdf_path = os.path.join(self.save_dir, save_name)
with PdfPages(pdf_path) as pdf: with PdfPages(pdf_path) as pdf:
for name in hist_data: for name in hist_data:
plt.hist(hist_data[name][0], bins=hist_data[name][1]) plt.hist(hist_data[name][0], bins=hist_data[name][1])
plt.xlabel(name) plt.xlabel(name)
plt.ylabel("Frequency") plt.ylabel("Frequency")
if 'act' in save_pdf_name: if 'act' in save_name:
plt.title("Hist of Activation {}/Input of Weight {}".format( plt.title("Hist of Activation {}".format(name))
name, weight_act_map[name]))
else: else:
plt.title("Hist of Weight {}".format(name)) plt.title("Hist of Weight {}".format(name))
plt.show() plt.show()
...@@ -427,25 +519,30 @@ class AnalysisQuant(object): ...@@ -427,25 +519,30 @@ class AnalysisQuant(object):
def get_target_quant_model(self, target_metric): def get_target_quant_model(self, target_metric):
_logger.info( _logger.info(
'Start to Find quant model that satisfies the target metric.') 'Start to Find quantized model that satisfies the target metric.')
_logger.info( _logger.info(
'Make sure that you are using full eval dataset to get target quantized model.' 'Make sure that you are using full eval dataset to get target quantized model.'
) )
skip_list = [] skip_list = []
rank_list = copy.copy(self.sensitivity_ranklist) if self.quant_layer_metrics:
rank_list = sorted(
self.quant_layer_metrics,
key=self.quant_layer_metrics.get,
reverse=False)
else:
_logger.info(
'Analyse metric error before get target quantized model.')
self.metric_error_analyse()
while True: while True:
skip_list.append(rank_list.pop(0)) skip_list.append(rank_list.pop(0))
_logger.info('Skip Ops: {}'.format(skip_list)) _logger.info('Skip Ops: {}'.format(skip_list))
executor = paddle.static.Executor(self.places) executor = paddle.static.Executor(self.places)
post_training_quantization = PostTrainingQuantization( post_training_quantization = self.create_ptq(
executor=executor, executor,
data_loader=self.data_loader, skip_list,
model_dir=self.model_dir, algo=self.ptq_config['algo']
model_filename=self.model_filename, if 'algo' in self.ptq_config else 'KL')
params_filename=self.params_filename,
onnx_format=self.onnx_format,
skip_tensor_list=skip_list,
**self.ptq_config)
program = post_training_quantization.quantize() program = post_training_quantization.quantize()
_logger.info('Evaluating...') _logger.info('Evaluating...')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册