Unverified · Commit 71e0acb1 · Author: Chang Xu · Committer: GitHub

Add New Func in Quant Analysis (#1439)

Parent 9650ce42
......@@ -2,8 +2,9 @@
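## 1. Features of the quantization analysis tool
1. Traverse all layers of the model, quantize each layer in turn, and compute the accuracy after quantization. Rank the accuracy of the models in which only one layer is quantized and visualize the layers that are unsuitable for quantization, so that they can be selectively skipped during quantization.
2. Visualize the weight and activation distributions of the layers with the best and worst quantization results, to help analyze what drives the model's quantization quality.
3. [Coming soon] Given an expected accuracy, directly produce a quantized model that meets it.
2. Plot activation boxplots to analyze how the activation distribution of each quantizable OP affects quantization quality.
3. Plot weight and activation histograms of the layers that quantize well and poorly, to analyze their effect on quantization quality.
4. Given an expected accuracy, directly produce a quantized model that meets it.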
## 2. Parameters accepted by paddleslim.quant.AnalysisQuant
```yaml
......@@ -30,7 +31,37 @@ ptq_config
## 3. Output of the quantization analysis tool
## 3. Using the quantization analysis tool
1. Create the quantization analysis tool:
```
analyzer = AnalysisQuant(
model_dir=config["model_dir"],
model_filename=config["model_filename"],
params_filename=config["params_filename"],
eval_function=eval_function,
data_loader=data_loader,
save_dir=config['save_dir'],
ptq_config=config['PTQ'])
```
2. Plot activation boxplots for all quantizable layers
```
analyzer.plot_activation_distribution()
```
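Taking the detection model picodet-s as an example, the activation boxplots below (a subset of layers) show that the activation inputs of `conv2d_7.w_0` and `conv2d_9.w_0` contain many outliers, which leads to poor quantization results.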
<p align="center">
<img src="./detection/images/act_distribution.png" width=849 hspace='10'/> <br />
</p>
3. Compute the quantization sensitivity of each layer and plot histograms
```
analyzer.compute_quant_sensitivity(plot_hist=True)
```
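`plot_hist` defaults to True; set it to False if you do not need the weight and activation histograms of the layers that quantize well and poorly.
By default, the quantization analysis tool produces the following directory: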
```
......@@ -41,9 +72,20 @@ analysis_results/
├── worst_weight_hist_result.pdf
├── worst_act_hist_result.pdf
```
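- The accuracy ranking of all models in which only one layer is quantized is saved to `./analysis_results/analysis.txt` by default.
- The parameter `num_histogram_plots` sets how many of the best and worst quantized layers have their weight and activation histograms plotted. The plots are saved as PDFs under `./analysis_results`, as `best_weight_hist_result.pdf`, `best_act_hist_result.pdf`, `worst_weight_hist_result.pdf`, and `worst_act_hist_result.pdf`, for comparison.
Taking the detection model picodet-s as an example, `analysis.txt` shows that the layers `conv2d_1.w_0`, `conv2d_3.w_0`, `conv2d_5.w_0`, `conv2d_7.w_0`, and `conv2d_9.w_0` cause a large accuracy loss. This is consistent with what the activation boxplots show.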
<p align="center">
<img src="./detection/images/picodet_analysis.png" width=849 hspace='10'/> <br />
</p>
4. Directly produce a quantized model that meets the expected accuracy
```
analyzer.get_target_quant_model(target_metric)
```
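## 3. Run post-training quantization based on the analysis results
## 4. Run post-training quantization based on the analysis results
After running the quantization analysis tool, use the accuracy ranking in `analysis.txt` to exclude poorly performing layers from quantization: when calling `paddleslim.quant.quant_post_static`, add the `skip_tensor_list` parameter and pass in the layers to exclude.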
......@@ -138,18 +138,35 @@ python eval.py --config_path=./configs/ppyoloe_s_ptq.yaml
python analysis.py --config_path=./configs/picodet_s_analysis.yaml
```
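As the figure below shows, the quantization analysis reveals that the layers `conv2d_1.w_0`, `conv2d_3.w_0`, `conv2d_5.w_0`, `conv2d_7.w_0`, and `conv2d_9.w_0` cause a large accuracy loss; all of them are `depthwise_conv` layers in the early part of the backbone.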
<p align="center">
<img src="./images/picodet_analysis.png" width=849 hspace='10'/> <br />
</p>
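The ranking can also be read programmatically. The following is a minimal sketch, assuming each line of `analysis.txt` follows the `quant layer name: <name>, eval metric: <value>` pattern written by the tool and that the metric is a single float; `parse_analysis` is a hypothetical helper, and the numbers are made up for illustration.

```python
import re

# Assumed line format: "quant layer name: <name>, eval metric: <value>"
LINE_RE = re.compile(r"quant layer name: ([^,\s]+), eval metric: ([-+0-9.eE]+)")


def parse_analysis(text):
    """Return (layer_name, metric) pairs, lowest metric first."""
    pairs = []
    for line in text.splitlines():
        match = LINE_RE.search(line)
        if match:
            pairs.append((match.group(1), float(match.group(2))))
    # Lowest metric first: these layers hurt accuracy the most when quantized.
    return sorted(pairs, key=lambda pair: pair[1])


# Made-up numbers for illustration only; real values come from analysis.txt.
sample = """\
quant layer name: conv2d_1.w_0, eval metric: 0.051
quant layer name: conv2d_20.w_0, eval metric: 0.302
quant layer name: conv2d_3.w_0, eval metric: 0.112
"""
ranking = parse_analysis(sample)
print([name for name, _ in ranking[:2]])
```

In practice the text would come from reading `./analysis_results/analysis.txt`.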
The saved `activation_distribution.pdf` also shows that the `activation` of these layers contains many outliers, which leads to poor quantization results.
<p align="center">
<img src="./images/act_distribution.png" width=849 hspace='10'/> <br />
</p>
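To put a number on "many outliers", one option is to apply the same rule a boxplot uses for its whiskers: a point is an outlier when it lies more than 1.5 times the interquartile range beyond the quartiles. The sketch below is illustrative only; `outlier_fraction` is a hypothetical helper, and the data is synthetic rather than taken from picodet-s.

```python
import random
import statistics


def outlier_fraction(values, whisker=1.5):
    """Fraction of samples outside [Q1 - w*IQR, Q3 + w*IQR]."""
    data = sorted(float(v) for v in values)
    q1, _, q3 = statistics.quantiles(data, n=4, method="inclusive")
    iqr = q3 - q1
    low, high = q1 - whisker * iqr, q3 + whisker * iqr
    outliers = sum(1 for v in data if v < low or v > high)
    return outliers / len(data)


rng = random.Random(0)
# Synthetic stand-ins for activation samples.
well_behaved = [rng.gauss(0.0, 1.0) for _ in range(1000)]
heavy_tailed = well_behaved + [50.0] * 100
print(outlier_fraction(well_behaved), outlier_fraction(heavy_tailed))
```

A layer whose activations give a much larger fraction than its neighbours is a reasonable candidate to inspect first.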
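Based on this analysis, you can skip the layers that cause the largest accuracy drop during post-training quantization by using [picodet_s_analyzed_ptq.yaml](./configs/picodet_s_analyzed_ptq.yaml) and running post-training quantization again. After skipping these layers, post-training quantization accuracy rises by 24.9 points.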
```shell
python post_quant.py --config_path=./configs/picodet_s_analyzed_ptq.yaml --save_dir=./picodet_s_analyzed_ptq_out
```
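If you prefer to choose the skipped layers by a threshold rather than by eye, the selection could look like the sketch below. It assumes you already have the per-layer metrics and the unquantized baseline; `select_skip_tensors` and `max_drop` are hypothetical names, and the numbers are made up.

```python
def select_skip_tensors(layer_metrics, base_metric, max_drop):
    """Pick layers whose single-layer quantization loses more than max_drop."""
    losses = {name: base_metric - metric for name, metric in layer_metrics.items()}
    picked = [name for name, loss in losses.items() if loss > max_drop]
    # Largest loss first, so the most harmful layers are listed first.
    return sorted(picked, key=lambda name: losses[name], reverse=True)


# Made-up metrics for illustration only.
layer_metrics = {"conv2d_1.w_0": 0.05, "conv2d_3.w_0": 0.11, "conv2d_20.w_0": 0.30}
skip_tensor_list = select_skip_tensors(layer_metrics, base_metric=0.31, max_drop=0.05)
print(skip_tensor_list)
```

The resulting list is what would be supplied as `skip_tensor_list` to `paddleslim.quant.quant_post_static`.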
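To directly produce a quantized model that meets a target accuracy after the analysis, set `get_target_quant_model` to True in `picodet_s_analysis.yaml` and fill in `target_metric`. Note that `target_metric` must not be higher than the accuracy of the original model.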
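The search behind this option can be pictured as a greedy loop: skip the most sensitive layer, re-quantize, evaluate, and repeat until the target is met. The sketch below shows only that control flow. `find_skip_list` and `toy_evaluate` are hypothetical stand-ins for the real quantize-and-evaluate step, and unlike the tool's loop this version returns `(None, None)` when the ranking is exhausted.

```python
def find_skip_list(rank_list, evaluate, target_metric):
    """Skip the most sensitive layers one at a time until the target is met."""
    skip_list = []
    for layer in rank_list:  # most sensitive layer first
        skip_list.append(layer)
        metric = evaluate(skip_list)
        if metric >= target_metric:
            return list(skip_list), metric
    # Every layer was skipped and the target is still unmet.
    return None, None


def toy_evaluate(skip_list):
    # Stand-in metric in percent: each skipped layer recovers 8 points.
    return 10 + 8 * len(skip_list)


rank_list = ["conv2d_1.w_0", "conv2d_3.w_0", "conv2d_5.w_0", "conv2d_7.w_0"]
print(find_skip_list(rank_list, toy_evaluate, target_metric=30))
```

The loop also shows why an unreachable target is a problem: skipping layers cannot be expected to push accuracy past the original model, so the search would run out of layers without succeeding.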
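**Speeding up the analysis**
Because the quantization analysis tool has to quantize the model layer by layer and evaluate each result, it can be slow. To speed it up, set `FastEvalDataset` in the config file and point it to an annotation file that contains only a few images. Note that accuracy measured on a small dataset does not necessarily equal accuracy measured on the full dataset: if you only need the relative ranking of layers, a small dataset is enough; if you need accurate metrics, use the full validation set. To evaluate on the full validation set, delete the `FastEvalDataset` field.
Note: if you ask for a quantized model that meets a target accuracy after the analysis, the demo code does not evaluate on the small dataset; it automatically switches to the full validation set.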
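A small annotation file can be derived from the full one. The sketch below assumes the standard COCO layout, where `annotations` entries refer to `images` through `image_id`; `subset_coco_annotations` is a hypothetical helper, and the tiny dictionary stands in for `instances_val2017.json`.

```python
import json


def subset_coco_annotations(coco, num_images):
    """Keep the first num_images images and only their annotations."""
    images = coco["images"][:num_images]
    kept_ids = {image["id"] for image in images}
    subset = dict(coco)  # shallow copy keeps categories and other top-level keys
    subset["images"] = images
    subset["annotations"] = [
        ann for ann in coco["annotations"] if ann["image_id"] in kept_ids
    ]
    return subset


# Tiny in-memory stand-in for the full annotation file.
coco = {
    "images": [{"id": 1}, {"id": 2}, {"id": 3}],
    "annotations": [
        {"id": 10, "image_id": 1},
        {"id": 11, "image_id": 3},
        {"id": 12, "image_id": 2},
    ],
    "categories": [{"id": 1, "name": "person"}],
}
small = subset_coco_annotations(coco, num_images=2)
print(json.dumps(small["annotations"]))
```

With real data, load the full file with `json.load`, pick a subset size that suits your time budget, and write the result with `json.dump` to the path referenced by the small dataset entry.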
For details, see [Introduction to the quantization analysis tool](../analysis.md)
## 4. Deployment
For deployment, see the [Detection model auto-compression example](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/detection)
......
......@@ -126,12 +126,16 @@ def main():
config = load_config(FLAGS.config_path)
ptq_config = config['PTQ']
# val dataset is sufficient for PTQ
data_loader = create('EvalReader')(config['EvalDataset'],
config['worker_num'],
return_list=True)
data_loader = reader_wrapper(data_loader, config['input_list'])
ptq_data_loader = reader_wrapper(data_loader, config['input_list'])
dataset = config['EvalDataset']
# fast_val_anno_path, such as annotation path of several pictures can accelerate analysis
dataset = config[
'FastEvalDataset'] if 'FastEvalDataset' in config else config[
'EvalDataset']
global val_loader
_eval_batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=config['EvalReader']['batch_size'])
......@@ -162,10 +166,23 @@ def main():
model_filename=config["model_filename"],
params_filename=config["params_filename"],
eval_function=eval_function,
data_loader=data_loader,
data_loader=ptq_data_loader,
save_dir=config['save_dir'],
ptq_config=ptq_config)
analyzer.analysis()
# plot the boxplot of activations of quantizable weights
analyzer.plot_activation_distribution()
# get the rank of sensitivity of each quantized layer
# plot the histogram plot of best and worst activations and weights if plot_hist is True
analyzer.compute_quant_sensitivity(plot_hist=config['plot_hist'])
if config['get_target_quant_model']:
if 'FastEvalDataset' in config:
# change fast_val_loader to full val_loader
val_loader = data_loader
# get the quantized model that satisfies target metric you set
analyzer.get_target_quant_model(target_metric=config['target_metric'])
if __name__ == '__main__':
......
......@@ -5,6 +5,9 @@ params_filename: model.pdiparams
save_dir: ./analysis_results
metric: COCO
num_classes: 80
plot_hist: True
get_target_quant_model: False
target_metric: None
PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
......@@ -15,18 +18,21 @@ PTQ:
batch_nums: 10
# Dataset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: /dataset/coco/
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: /dataset/coco/
# Small Dataset to accelerate analysis
# If not exist, delete the dict of FastEvalDataset
FastEvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/small_instances_val2017.json
dataset_dir: /dataset/coco/
eval_height: &eval_height 416
eval_width: &eval_width 416
eval_size: &eval_size [*eval_height, *eval_width]
......@@ -41,7 +47,7 @@ EvalReader:
- Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_size: 32
batch_size: 1
input_list: ['image']
arch: PPYOLOE # When export exclude_nms=True, need set arch: PPYOLOE
arch: PPYOLOE # When export exclude_nms=True, need set arch: PPYOLOE
model_dir: ./ppyoloe_crn_s_300e_coco
model_filename: model.pdmodel
params_filename: model.pdiparams
save_dir: ./analysis_results_ppyoloe
metric: COCO
num_classes: 80
plot_hist: True
get_target_quant_model: False
target_metric: None
PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
......@@ -14,19 +17,20 @@ PTQ:
is_full_quantize: False
batch_size: 32
batch_nums: 10
# Dataset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: /dataset/coco/
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: /paddle/dataset/coco/
# Small Dataset to accelerate analysis
# If not exist, delete the dict of FastEvalDataset
FastEvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/small_instances_val2017.json
dataset_dir: /dataset/coco/
worker_num: 0
......@@ -38,4 +42,4 @@ EvalReader:
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 32
\ No newline at end of file
batch_size: 1
\ No newline at end of file
......@@ -38,8 +38,10 @@
#### 3.1 准备环境
- PaddlePaddle >= 2.3 (can be downloaded and installed from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
- PaddleSlim > 2.3
- X2Paddle >= 1.3.9
- opencv-python
(1) Install paddlepaddle:
```shell
# CPU
......@@ -139,10 +141,21 @@ python analysis.py --config_path=./configs/yolov6s_analysis.yaml
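Based on this analysis, you can skip the layers that cause the largest accuracy drop during post-training quantization by using [yolov6s_analyzed_ptq.yaml](./configs/yolov6s_analyzed_ptq.yaml) and running post-training quantization again. After skipping these layers, post-training quantization accuracy rises by 9.4 points.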
```shell
python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_dir=./yolov6s_analyzed_ptq_out
```
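To directly produce a quantized model that meets a target accuracy after the analysis, set `get_target_quant_model` to True in `yolov6s_analysis.yaml` and fill in `target_metric`. Note that `target_metric` must not be higher than the accuracy of the original model.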
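**Speeding up the analysis**
Because the quantization analysis tool has to quantize the model layer by layer and evaluate each result, it can be slow. To speed it up, set `fast_val_anno_path` in the config file to the path of an annotation file that contains only a few images. Note that accuracy measured on a small dataset does not necessarily equal accuracy measured on the full dataset: if you only need the relative ranking of layers, a small dataset is enough; if you need accurate metrics, use the full validation set. To evaluate on the full validation set, set `fast_val_anno_path` to None.
Note: if you ask for a quantized model that meets a target accuracy after the analysis, the demo code does not evaluate on the small dataset; it automatically switches to the full validation set.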
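One detail worth checking when setting a field to None: in plain YAML the null value is written `null` or `~`, and a bare `None` is read as the string `"None"` by standard loaders. Whether `load_config` converts it is not shown here, so the sketch below is a defensive assumption rather than a description of the demo code; `normalize_none` and `pick_anno_path` are hypothetical helpers.

```python
def normalize_none(value):
    """Map the strings 'None', 'null', '~' and '' to Python None."""
    if isinstance(value, str) and value.strip().lower() in {"none", "null", "~", ""}:
        return None
    return value


def pick_anno_path(config):
    """Use the small annotation file when one is configured, else the full one."""
    fast = normalize_none(config.get("fast_val_anno_path"))
    return fast if fast is not None else config["val_anno_path"]


full = "annotations/instances_val2017.json"
small = "annotations/small_instances_val2017.json"
print(pick_anno_path({"fast_val_anno_path": "None", "val_anno_path": full}))
print(pick_anno_path({"fast_val_anno_path": small, "val_anno_path": full}))
```

Using `config.get` also covers the case where the field is removed from the file entirely.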
For details, see [Introduction to the quantization analysis tool](../analysis.md)
## 4. Deployment
For deployment, see the [YOLO series model auto-compression example](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/pytorch_yolo_series)
......
......@@ -18,7 +18,7 @@ import numpy as np
import argparse
import paddle
from tqdm import tqdm
from post_process import YOLOv6PostProcess, coco_metric
from post_process import YOLOPostProcess, coco_metric
from dataset import COCOValDataset, COCOTrainDataset
from paddleslim.common import load_config, load_onnx_model
from paddleslim.quant.analysis import AnalysisQuant
......@@ -53,7 +53,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
postprocess = YOLOv6PostProcess(
postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
bboxes_list.append(res['bbox'])
......@@ -72,6 +72,8 @@ def main():
input_name = 'x2paddle_image_arrays' if config[
'arch'] == 'YOLOv6' else 'x2paddle_images'
# val dataset is sufficient for PTQ
dataset = COCOTrainDataset(
dataset_dir=config['dataset_dir'],
image_dir=config['val_image_dir'],
......@@ -81,10 +83,12 @@ def main():
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
global val_loader
# fast_val_anno_path, such as annotation path of several pictures can accelerate analysis
dataset = COCOValDataset(
dataset_dir=config['dataset_dir'],
image_dir=config['val_image_dir'],
anno_path=config['val_anno_path'])
anno_path=config['fast_val_anno_path'] if
config['fast_val_anno_path'] is not None else config['val_anno_path'])
global anno_file
anno_file = dataset.ann_file
val_loader = paddle.io.DataLoader(
......@@ -101,7 +105,30 @@ def main():
data_loader=data_loader,
save_dir=config['save_dir'],
ptq_config=ptq_config)
analyzer.analysis()
# plot the boxplot of activations of quantizable weights
analyzer.plot_activation_distribution()
# get the rank of sensitivity of each quantized layer
# plot the histogram plot of best and worst activations and weights if plot_hist is True
analyzer.compute_quant_sensitivity(plot_hist=config['plot_hist'])
if config['get_target_quant_model']:
if config['fast_val_anno_path'] is not None:
# change fast_val_loader to full val_loader
dataset = COCOValDataset(
dataset_dir=config['dataset_dir'],
image_dir=config['val_image_dir'],
anno_path=config['val_anno_path'])
anno_file = dataset.ann_file
val_loader = paddle.io.DataLoader(
dataset,
batch_size=1,
shuffle=False,
drop_last=False,
num_workers=0)
# get the quantized model that satisfies target metric you set
analyzer.get_target_quant_model(config['target_metric'])
if __name__ == '__main__':
......
......@@ -4,6 +4,11 @@ save_dir: ./analysis_results
dataset_dir: /dataset/coco/
val_image_dir: val2017
val_anno_path: annotations/instances_val2017.json
# Small Dataset to accelerate analysis
fast_val_anno_path: annotations/small_instances_val2017.json # if not exist, please set None
get_target_quant_model: False
target_metric: None
plot_hist: True
PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
......
arch: YOLOv7
model_dir: ./yolov7.onnx
save_dir: ./analysis_results
dataset_dir: /dataset/coco/
val_image_dir: val2017
val_anno_path: annotations/instances_val2017.json
# Small Dataset to accelerate analysis
fast_val_anno_path: annotations/small_instances_val2017.json # if not exist, please set None
get_target_quant_model: False
target_metric: None
plot_hist: True
PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
batch_size: 10
batch_nums: 10
\ No newline at end of file
arch: YOLOv7
model_dir: ./yolov7s.onnx
model_dir: ./yolov7.onnx
dataset_dir: /dataset/coco/
train_image_dir: train2017
val_image_dir: val2017
......
......@@ -20,7 +20,7 @@ import logging
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import random
import paddle
from paddle.fluid import core
from paddle.fluid import framework
......@@ -105,7 +105,7 @@ class AnalysisQuant(object):
if self.eval_function is not None:
self.base_metric = self.eval_function(
executor, program, self.feed_list, self.fetch_list)
_logger.info('before quantized, the accuracy of the model is: {}'.
_logger.info('Before quantized, the accuracy of the model is: {}'.
format(self.base_metric))
# quant and evaluate after quant (skip_list = None)
......@@ -121,7 +121,7 @@ class AnalysisQuant(object):
program = post_training_quantization.quantize()
self.quant_metric = self.eval_function(executor, program,
self.feed_list, self.fetch_list)
_logger.info('after quantized, the accuracy of the model is: {}'.format(
_logger.info('After quantized, the accuracy of the model is: {}'.format(
self.quant_metric))
# get quantized weight and act var name
......@@ -135,8 +135,13 @@ class AnalysisQuant(object):
list(self.quant_layer_metrics.keys()))
self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer))
def analysis(self):
self.compute_quant_sensitivity()
def compute_quant_sensitivity(self, plot_hist=True):
'''
compute the sensitivity of quantized layers by eval function
'''
assert self.data_loader is not None, "When computing the sensitivity of quantized layers, the data loader is needed"
assert self.eval_function is not None, "When computing the sensitivity of quantized layers, the eval function is needed"
self.eval_quant_model()
self.sensitivity_ranklist = sorted(
self.quant_layer_metrics,
key=self.quant_layer_metrics.get,
......@@ -154,7 +159,9 @@ class AnalysisQuant(object):
"quant layer name: {}, eval metric: {}\n".format(
name, self.quant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file))
self.calculate_histogram()
if plot_hist:
self.calculate_histogram()
def save_checkpoint(self):
if not os.path.exists(self.save_dir):
......@@ -171,12 +178,76 @@ class AnalysisQuant(object):
_logger.info('load checkpoint from {}'.format(self.checkpoint_name))
return True
def compute_quant_sensitivity(self):
def plot_activation_distribution(self, axis=None):
'''
Collect and plot the distribution of the activation of each weight layer.
'''
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
executor = paddle.static.Executor(places)
[program, feed_list, fetch_list]= load_inference_model( \
self.model_dir, \
executor=executor, \
model_filename=self.model_filename, \
params_filename=self.params_filename)
scope = global_scope()
graph = IrGraph(core.Graph(program.desc), for_test=False)
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
weight_names = sorted(list(self.quantized_weight_var_name))
acts_weight_map = self.get_weight_act_map(program, weight_names,
persistable_var_names)
all_acts = list(acts_weight_map.keys())
all_weights = [acts_weight_map[act] for act in all_acts]
act_distribution = []
for var_name in all_acts:
var_tensor = load_variable_data(scope, var_name)
if axis is None:
var_tensor = var_tensor.flatten()
else:
var_tensor = var_tensor.reshape(
[-1, var_tensor.shape[axis]]).abs().max(axis=-1)
sample_num = len(var_tensor) if len(var_tensor) < 1000 else 1000
var_tensor = random.sample(list(var_tensor), sample_num)
act_distribution.append(var_tensor)
all_values = sum(act_distribution, [])
max_value = np.max(all_values)
min_value = np.min(all_values)
pdf_path = os.path.join(self.save_dir, 'activation_distribution.pdf')
with PdfPages(pdf_path) as pdf:
for i in range(0, len(act_distribution), 20):
r = i + 20 if i + 20 < len(act_distribution) else len(
act_distribution)
plt.boxplot(
act_distribution[i:r],
labels=all_weights[i:r],
showbox=True,
patch_artist=True)
plt.xticks(rotation=90)
plt.tick_params(axis='x')
plt.ylim([min_value, max_value])
plt.xlabel('Weight Name')
plt.ylabel("Activation Distribution")
plt.tight_layout()
plt.show()
pdf.savefig()
plt.close()
_logger.info('Distribution plots is saved in {}'.format(pdf_path))
def eval_quant_model(self):
'''
For each layer, quantize the weight op and evaluate the quantized model.
'''
for i, layer_name in enumerate(self.tobe_analyized_layer):
_logger.info('checking {}/{} quant model: quant layer {}'.format(
_logger.info('Checking {}/{} quant model: quant layer {}'.format(
i + 1, len(self.tobe_analyized_layer), layer_name))
skip_list = copy.copy(list(self.quantized_weight_var_name))
skip_list.remove(layer_name)
......@@ -198,15 +269,14 @@ class AnalysisQuant(object):
self.fetch_list)
executor.close()
_logger.info(
"quant layer name: {}, eval metric: {}, the loss caused by this layer: {}".
"Quantized layer name: {}, eval metric: {}, the loss caused by this layer: {}".
format(layer_name, quant_metric, self.base_metric -
quant_metric))
self.quant_layer_metrics[layer_name] = quant_metric
self.save_checkpoint()
def get_act_name_by_weight(self, program, weight_names,
persistable_var_names):
act_ops_names = []
def get_weight_act_map(self, program, weight_names, persistable_var_names):
act_names = {}
for op_name in weight_names:
for block_id in range(len(program.blocks)):
for op in program.blocks[block_id].ops:
......@@ -214,8 +284,8 @@ class AnalysisQuant(object):
if op_name in var_name_list:
for var_name in var_name_list:
if var_name not in persistable_var_names:
act_ops_names.append(var_name)
return act_ops_names
act_names[var_name] = op_name
return act_names
def get_hist_ops_name(self, graph, program):
if self.num_histogram_plots <= 0:
......@@ -230,13 +300,13 @@ class AnalysisQuant(object):
if var.persistable:
persistable_var_names.append(var.name)
best_act_ops = self.get_act_name_by_weight(program, best_weight_ops,
persistable_var_names)
worst_act_ops = self.get_act_name_by_weight(program, worst_weight_ops,
persistable_var_names)
return [best_weight_ops, best_act_ops, worst_weight_ops, worst_act_ops]
best_acts = self.get_weight_act_map(program, best_weight_ops,
persistable_var_names)
worst_acts = self.get_weight_act_map(program, worst_weight_ops,
persistable_var_names)
return [best_weight_ops, best_acts, worst_weight_ops, worst_acts]
def collect_ops_histogram(self, scope, ops):
def collect_tensors_histogram(self, scope, ops):
hist = {}
for var_name in ops:
var_tensor = load_variable_data(scope, var_name)
......@@ -268,8 +338,8 @@ class AnalysisQuant(object):
scope = global_scope()
graph = IrGraph(core.Graph(program.desc), for_test=False)
ops_tobe_draw_hist = self.get_hist_ops_name(graph, program)
if not ops_tobe_draw_hist:
tensors_tobe_draw_hist = self.get_hist_ops_name(graph, program)
if not tensors_tobe_draw_hist:
return
for var in program.list_vars():
......@@ -294,19 +364,72 @@ class AnalysisQuant(object):
'worst_weight_hist_result.pdf',
'worst_act_hist_result.pdf',
]
for ops, save_pdf_name in zip(ops_tobe_draw_hist, pdf_names):
hist_data = self.collect_ops_histogram(scope, ops)
self.draw_pdf(hist_data, save_pdf_name)
def draw_pdf(self, hist_data, save_pdf_name):
for tensors, save_pdf_name in zip(tensors_tobe_draw_hist, pdf_names):
if isinstance(tensors, list):
hist_data = self.collect_tensors_histogram(scope, tensors)
self.draw_hist_pdf(hist_data, save_pdf_name, None)
else:
hist_data = self.collect_tensors_histogram(scope,
list(tensors.keys()))
self.draw_hist_pdf(hist_data, save_pdf_name, tensors)
def draw_hist_pdf(self, hist_data, save_pdf_name, weight_act_map):
pdf_path = os.path.join(self.save_dir, save_pdf_name)
with PdfPages(pdf_path) as pdf:
for name in hist_data:
plt.hist(hist_data[name][0], bins=hist_data[name][1])
plt.xlabel(name)
plt.ylabel("frequency")
plt.title("Hist of variable {}".format(name))
plt.ylabel("Frequency")
if 'act' in save_pdf_name:
plt.title("Hist of Activation {}/Input of Weight {}".format(
name, weight_act_map[name]))
else:
plt.title("Hist of Weight {}".format(name))
plt.show()
pdf.savefig()
plt.close()
_logger.info('Histogram plot is saved in {}'.format(pdf_path))
def get_target_quant_model(self, target_metric):
_logger.info(
'Start to Find quant model that satisfies the target metric.')
_logger.info(
'Make sure that you are using full eval dataset to get target quantized model.'
)
skip_list = []
rank_list = copy.copy(self.sensitivity_ranklist)
while True:
skip_list.append(rank_list.pop(0))
_logger.info('Skip Ops: {}'.format(skip_list))
executor = paddle.static.Executor(self.places)
post_training_quantization = PostTrainingQuantization(
executor=executor,
data_loader=self.data_loader,
model_dir=self.model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
skip_tensor_list=skip_list,
**self.ptq_config)
program = post_training_quantization.quantize()
_logger.info('Evaluating...')
quant_metric = self.eval_function(executor, program, self.feed_list,
self.fetch_list)
_logger.info("Current eval metric: {}, the target metric: {}".
format(quant_metric, target_metric))
if quant_metric >= target_metric:
quantize_model_path = os.path.join(self.save_dir,
'target_quant_model')
_logger.info(
'The quantized model satisfies the target metric and is saved to {}'.
format(quantize_model_path))
post_training_quantization.save_quantized_model(
quantize_model_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
break
else:
_logger.info(
'The quantized model does not satisfy the target metric. Skip next Op...'
)
executor.close()