未验证 提交 3b2ed2cf 编写于 作者: C Chang Xu 提交者: GitHub

Merge AnalysisPTQ & AnalysisQAT to Analysis (#1692)

上级 2bb09da6
# PTQ(Post Training Quantization)量化分析工具详细教程
# 量化分析工具详细教程
## 1. 量化分析工具功能
1. 统计分析(statistical_analyse):
......@@ -13,17 +13,18 @@
- 输入预期精度,直接产出符合预期精度的量化模型。
## 2. paddleslim.quant.AnalysisPTQ 可传入参数解析
## 2. paddleslim.quant.Analysis 可传入参数解析
| **参数名** | **参数释义** |
|-----------------------------|-----------------------------------------|
| model_dir | 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可 |
| float_model_dir | 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可 |
| quant_model_dir | 默认为None,传入的量化模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可; 若不传入,分析工具将使用PTQ进行量化并分析|
| model_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| eval_function | 若需要验证精度,需要传入自定义的验证函数;若不传入,精度误差分析将根据Cosine Similarity计算得出 |
| data_loader | 模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader |
| save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`|
| resume | 是否加载中间分析文件,默认为False|
| ptq_config | 可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post) |
| quant_config | 可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post) |
......@@ -45,7 +46,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
......@@ -72,12 +73,12 @@ image = paddle.static.data(
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
analyzer = AnalysisPTQ(
model_dir="./MobileNetV1_infer",
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
ptq_config={
quant_config={
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
......@@ -124,22 +125,17 @@ analyzer.statistical_analyse()
```shell
analyzer.metric_error_analyse()
```
调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。
若不传入quant_model_dir,并且调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。
若传入quant_model_dir,并且调用该接口,会遍历量化模型中的每一层,去掉量化节点并计算当前层不量化的模型精度。调用该接口时,需要输入Eval Function。会产出所有去掉一层量化的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。具体使用可参考[GPT量化训练敏感度分析DEMO](../../../../example/quantization_analysis/GPT/README.md)
**直接产出符合预期精度的目标量化模型**
```shell
analyzer.get_target_quant_model(target_metric=70.0)
analyzer.get_target_quant_model(target_metric=0.70)
```
## 4. 根据分析结果执行离线量化
执行完量化分析工具后,可根据 `analysis.txt` 中的精度排序,在量化中去掉效果较差的层,具体操作为:在调用 `paddleslim.quant.quant_post_static` 时加入参数 `skip_tensor_list`,将需要去掉的层传入即可。
## FAQ:
- 与QAT(Quantization-Aware Training)量化分析工具的区别:与QAT量化分析工具不同的是,PTQ量化分析工具则是加载待量化的原模型,对模型所有层依次进行量化,每次量化一层,进行验证获取精度误差分析。而QAT量化分析工具加载量化训练后的量化模型,遍历所有量化的层,依次去掉量化层,加载Float模型的参数,并进行验证获取精度误差分析。
- PTQ量化分析工具设计的原因:PTQ量化分析工具依次量化模型中的每一层,而不是依次去掉量化层是由于PTQ本身的高效性。依次量化一层进行验证,查看对模型精度的损失十分直观。
- 量化分析工具为什么要区分PTQ和QAT:实验证明PTQ和QAT后的量化模型的敏感层并不完全一致,将两种算法分开,敏感度分析结果更加准确。
# QAT(Quantization-Aware Training)量化分析工具详细教程
## 1. 量化分析工具功能
精度误差分析(metric_error_analyse):
- 遍历量化训练后模型的每层,去掉量化节点并计算当前层不量化的模型精度。该功能可以定位具体某层导致的量化损失。
## 2. paddleslim.quant.AnalysisQAT 可传入参数解析
| **参数名** | **参数释义** |
|-----------------------------|-----------------------------------------|
| quant_model_dir | 必须传入的量化后的模型文件路径 |
| float_model_dir | 必须传入的量化前的模型文件路径 |
| model_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称 |
| params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称 |
| quantizable_op_type | 需分析的量化的op类型,默认为`conv2d`, `depthwise_conv2d`, `mul` |
| qat_metric | 量化模型的精度,可不传入,默认为None,不传入时会自动计算 |
| eval_function | 若需要验证精度,需要传入自定义的验证函数;若不传入,精度误差分析将根据Cosine Similarity计算得出 |
| data_loader | 模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader |
| save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`|
| resume | 是否加载中间分析文件,默认为False|
## 3. 量化分析工具的使用
**创建量化分析工具**
```shell
# 下载Inference模型
wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar
tar -xf MobileNetV1_infer.tar
wget -q https://paddle-slim-models.bj.bcebos.com/act/MobileNetV1_QAT.tar
tar -xf MobileNetV1_QAT.tar
# 下载demo数据集
wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz
tar -xf ILSVRC2012_data_demo.tar.gz
```
```shell
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddle.quantization import PostTrainingQuantization
from paddleslim.quant.analysis_qat import AnalysisQAT
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
def __init__(self, path, image_size=224):
super(ImageNetDataset, self).__init__(path)
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
def __getitem__(self, idx):
img_path, _ = self.samples[idx]
return self.transform(Image.open(img_path).convert('RGB'))
def __len__(self):
return len(self.samples)
train_dataset = ImageNetDataset(
"./ILSVRC2012_data_demo/ILSVRC2012/train/")
image = paddle.static.data(
name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
analyzer = AnalysisQAT(
float_model_dir="./MobileNetV1_infer",
quant_model_dir="./MobileNetV1_QAT",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
data_loader=train_loader)
```
**精度误差分析**
```shell
analyzer.metric_error_analyse()
```
调用该接口,会遍历量化模型中的每一层,去掉量化节点并计算当前层不量化的模型精度。调用该接口时,需要输入Eval Function。会产出所有去掉一层量化的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。具体使用可参考[GPT量化训练敏感度分析DEMO](../../../../example/quantization_analysis/GPT/README.md)
## FAQ:
- 与PTQ(Post Training Quantization)量化分析工具的区别:与PTQ量化分析工具不同的是,QAT量化分析工具加载量化训练后的量化模型,遍历所有量化的层,依次去掉量化层,加载Float模型的参数,并进行验证获取精度误差分析。而PTQ量化分析工具则是加载待量化的原模型,对模型所有层依次进行量化,每次量化一层,进行验证获取精度误差分析。
- QAT量化分析工具设计的原因:QAT量化分析工具依次去掉量化层,而不是依次量化一层是由于QAT需要训练的特性。遍历每层进行量化训练再验证精度比较耗时,直接加载量化训练后的量化模型,依次去掉量化层更高效。
- 量化分析工具为什么要区分PTQ和QAT:实验证明PTQ和QAT后的量化模型的敏感层并不完全一致,将两种算法分开,敏感度分析结果更加准确。
......@@ -23,7 +23,7 @@ from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from keypoint_utils import keypoint_post_process
from post_process import PPYOLOEPostProcess
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
def argsparser():
......@@ -87,10 +87,11 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
elif isinstance(config['input_list'], dict):
if k in config['input_list'].keys():
data_input[config['input_list'][k]] = np.array(v)
outs = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
outs = exe.run(
compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
if 'arch' in config and config['arch'] == 'keypoint':
res = keypoint_post_process(data, data_input, exe,
......@@ -115,8 +116,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric.log()
map_res = metric.get_results()
metric.reset()
map_key = 'keypoint' if 'arch' in config and config[
'arch'] == 'keypoint' else 'bbox'
map_key = 'keypoint' if 'arch' in config and config['arch'] == 'keypoint' else 'bbox'
return map_res[map_key][0]
......@@ -127,9 +127,8 @@ def main():
ptq_config = config['PTQ']
# val dataset is sufficient for PTQ
data_loader = create('EvalReader')(config['EvalDataset'],
config['worker_num'],
return_list=True)
data_loader = create('EvalReader')(
config['EvalDataset'], config['worker_num'], return_list=True)
ptq_data_loader = reader_wrapper(data_loader, config['input_list'])
# fast_val_anno_path, such as annotation path of several pictures can accelerate analysis
......@@ -139,10 +138,11 @@ def main():
global val_loader
_eval_batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=config['EvalReader']['batch_size'])
val_loader = create('EvalReader')(dataset,
config['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
val_loader = create('EvalReader')(
dataset,
config['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
global metric
if config['metric'] == 'COCO':
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
......@@ -161,14 +161,14 @@ def main():
else:
raise ValueError("metric currently only supports COCO and VOC.")
analyzer = AnalysisPTQ(
model_dir=config["model_dir"],
analyzer = Analysis(
float_model_dir=config["model_dir"],
model_filename=config["model_filename"],
params_filename=config["params_filename"],
eval_function=eval_function,
data_loader=ptq_data_loader,
save_dir=config['save_dir'],
ptq_config=ptq_config,
quant_config=ptq_config,
resume=True, )
analyzer.statistical_analyse()
......
......@@ -21,7 +21,7 @@ from tqdm import tqdm
from post_process import YOLOPostProcess, coco_metric
from dataset import COCOValDataset, COCOTrainDataset
from paddleslim.common import load_config, load_onnx_model
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
def argsparser():
......@@ -41,7 +41,8 @@ def argsparser():
'--resume',
type=bool,
default=False,
help="When break off while ananlyzing, could resume analysis program and load already analyzed information."
help=
"When break off while ananlyzing, could resume analysis program and load already analyzed information."
)
return parser
......@@ -54,10 +55,11 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
ncols=80) as t:
for data in val_loader:
data_all = {k: np.array(v) for k, v in data.items()}
outs = exe.run(compiled_test_program,
feed={test_feed_names[0]: data_all['image']},
fetch_list=test_fetch_list,
return_numpy=False)
outs = exe.run(
compiled_test_program,
feed={test_feed_names[0]: data_all['image']},
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
......@@ -103,15 +105,15 @@ def main():
load_onnx_model(config["model_dir"])
inference_model_path = config["model_dir"].rstrip().rstrip(
'.onnx') + '_infer'
analyzer = AnalysisPTQ(
model_dir=inference_model_path,
analyzer = Analysis(
float_model_dir=inference_model_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams',
eval_function=eval_function,
data_loader=data_loader,
save_dir=config['save_dir'],
resume=FLAGS.resume,
ptq_config=ptq_config)
quant_config=ptq_config)
analyzer.statistical_analyse()
analyzer.metric_error_analyse()
......
......@@ -21,7 +21,7 @@ import time
import paddle
from paddleslim.common import load_config as load_slim_config
from paddleslim.quant.analysis_qat import AnalysisQAT
from paddleslim.quant.analysis import Analysis
from ppfleetx.data import build_dataloader
from ppfleetx.distributed.apis import env
from utils import parse_config
......@@ -164,17 +164,15 @@ def main():
global eval_loader
eval_loader = eval_reader_wrapper(valid_data_loader)
analyzer = AnalysisQAT(
analyzer = Analysis(
quant_model_dir=global_config["quant_model_dir"],
float_model_dir=global_config["float_model_dir"],
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
quantizable_op_type=global_config['quantizable_op_type'],
qat_metric=global_config['qat_metric']
if 'qat_metric' in global_config else None,
eval_function=eval_function,
data_loader=eval_loader,
save_dir=FLAGS.save_dir,
quant_config=all_config['quant_config'],
resume=global_config['resume'], )
analyzer.metric_error_analyse()
......
......@@ -5,11 +5,16 @@ Global:
float_model_dir: ./GPT_345M_Baseline
model_filename: model.pdmodel
params_filename: model.pdiparams
quantizable_op_type: ["mul", "matmul", "matmul_v2"]
resume: False
reader_config: ./configs/gpt_reader.yaml
cloze_eval: True # True for LAMBADA Dataset; False for WikiText
quant_config:
quantizable_op_type: ["mul", "matmul", "matmul_v2"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
batch_size: 8
batch_nums: 10
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,132 +17,111 @@ import sys
import pickle
import copy
import logging
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import csv
import numpy as np
import random
import tempfile
import paddle
import paddle.nn.functional as F
from ..core import GraphWrapper
from ..common import get_logger
from ..common import get_feed_vars, wrap_dataloader, load_inference_model, get_model_dir
from ..common import get_logger, load_inference_model
from paddle.fluid.framework import IrGraph
from paddle.framework import core
from paddle.static.quantization import PostTrainingQuantization
from .analysis_utils import *
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ["AnalysisPTQ"]
SUPPORT_WEIGHT_OP_DICT = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"conv2d_transpose": [["Input", "Filter"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"matmul_v2": [["X", "Y"], ["Out"]]
}
class AnalysisPTQ(object):
class Analysis(object):
def __init__(self,
model_dir,
float_model_dir,
quant_model_dir=None,
model_filename=None,
params_filename=None,
eval_function=None,
data_loader=None,
save_dir='analysis_results',
eval_function=None,
resume=False,
ptq_config=None):
"""
AnalysisPTQ provides to analysis the sensitivity of each op in the model.
save_dir='analysis_results',
quant_config=None):
'''
Analysis provides to analysis the sensitivity of each op in the model.
Args:
model_dir(str): the path of fp32 model that will be quantized, it can also be '.onnx'
model_filename(str, optional): the model file name of the fp32 model
params_filename(str, optional): the parameter file name of the fp32 model
float_model_dir(str, required): the path of fp32 model, it can also be '.onnx'
quant_model_dir(str, optional):the path of quantized model, if is None, float model will be quantized by PTQ
model_filename(str, optional): the model file name of the fp32 and quantized model
params_filename(str, optional): the parameter file name of the fp32 and quantized model
eval_function(function): eval function, define by yourself to return the metric of the inference program, can be used to judge the metric of quantized model. (TODO: optional)
data_loader(Python Generator, Paddle.io.DataLoader, optional): the
Generator or Dataloader provides calibrate data, and it could
return a batch every time
save_dir(str, optional): the output dir that stores the analyzed information
resume(bool, optional): When break off while ananlyzing, could resume analysis program and load already analyzed information.
ptq_config(dict, optional): the args that can initialize PostTrainingQuantization
"""
quant_config(dict, optional): the args that can initialize PostTrainingQuantization
Examples:
.. code-block:: python
from paddleslim.quant.analysis import Analysis
analyzer = Analysis(quant_model_dir=quant_model_dir)
analyzer.metric_error_analyse()
'''
if model_filename is None:
model_filename = 'model.pdmodel'
if params_filename is None:
params_filename = 'model.pdiparams'
self.model_dir = model_dir
self.float_model_dir = float_model_dir
self.quant_model_dir = quant_model_dir
self.model_filename = model_filename
self.params_filename = params_filename
self.histogram_bins = 1000
self.save_dir = save_dir
self.eval_function = eval_function
self.quant_layer_names = []
self.checkpoint_name = os.path.join(save_dir, 'analysis_checkpoint.pkl')
self.quant_layer_metrics = {}
self.ptq_config = ptq_config
self.batch_nums = ptq_config[
'batch_nums'] if 'batch_nums' in ptq_config else 10
self.is_full_quantize = ptq_config[
'is_full_quantize'] if 'is_full_quantize' in ptq_config else False
self.onnx_format = ptq_config[
'onnx_format'] if 'onnx_format' in ptq_config else False
ptq_config['onnx_format'] = self.onnx_format
if 'algo' not in ptq_config:
ptq_config['algo'] = 'avg'
self.data_loader = data_loader
self.eval_function = eval_function
self.quant_config = quant_config
self.batch_nums = quant_config.get("batch_nums", 10)
self.is_full_quantize = quant_config.get("is_full_quantize", False)
self.onnx_format = quant_config.get("onnx_format", False)
self.quantizable_op_type = quant_config.get(
"quantizable_op_type", list(SUPPORT_WEIGHT_OP_DICT.keys()))
self.skip_tensor_list = quant_config.get("skip_tensor_list", [])
if self.skip_tensor_list:
del self.quant_config['skip_tensor_list']
quant_config['onnx_format'] = self.onnx_format
quant_config['algo'] = quant_config.get("algo", 'avg')
if not os.path.exists(self.save_dir):
os.mkdir(self.save_dir)
if self.onnx_format:
self.temp_root_path = tempfile.TemporaryDirectory(dir=self.save_dir)
self.temp_save_path = os.path.join(self.temp_root_path.name, "ptq")
if not os.path.exists(self.temp_save_path):
os.makedirs(self.temp_save_path)
if not os.path.exists(self.save_dir):
os.mkdir(self.save_dir)
devices = paddle.device.get_device().split(':')[0]
self.places = paddle.device._convert_to_place(devices)
executor = paddle.static.Executor(self.places)
# load model
[program, self.feed_list, self.fetch_list]= load_inference_model( \
self.model_dir, \
executor=executor, \
model_filename=self.model_filename, \
params_filename=self.params_filename)
# create data_loader
self.data_loader = wrap_dataloader(data_loader, self.feed_list)
# quant model to get quantizable ops
post_training_quantization = self.create_ptq(executor, None)
_logger.info('Run PTQ before analysis.')
program = post_training_quantization.quantize()
if self.onnx_format:
post_training_quantization.save_quantized_model(
self.temp_save_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
program, _, _ = load_inference_model(
self.temp_save_path,
executor,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
# get quantized weight and act var name
self.quantized_weight_var_name = post_training_quantization._quantized_weight_var_name
self.quantized_act_var_name = post_training_quantization._quantized_act_var_name
self.support_quant_val_name_list = self.quantized_weight_var_name if not self.is_full_quantize else list(
self.quantized_act_var_name)
self.weight_names = list(self.quantized_weight_var_name)
self.act_names = list(self.quantized_act_var_name)
executor.close()
# load tobe_analyized_layer from checkpoint
self.layer_metrics = {}
if resume:
self.load_checkpoint()
self.tobe_analyized_layer = sorted(
list(self.support_quant_val_name_list))
def save_checkpoint(self):
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
with open(self.checkpoint_name, 'wb') as f:
pickle.dump(self.quant_layer_metrics, f)
pickle.dump(self.layer_metrics, f)
_logger.info('Save checkpoint to {}.'.format(self.checkpoint_name))
def load_checkpoint(self):
......@@ -151,308 +130,117 @@ class AnalysisPTQ(object):
self.checkpoint_name))
return False
with open(self.checkpoint_name, 'rb') as f:
self.quant_layer_metrics = pickle.load(f)
self.layer_metrics = pickle.load(f)
_logger.info('Load checkpoint from {}.'.format(self.checkpoint_name))
return True
def save_csv(self, data, save_name, csv_columns):
save_path = os.path.join(self.save_dir, save_name)
with open(save_path, 'w') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
writer.writeheader()
for d in data:
writer.writerow(d)
_logger.info('Activation Statistic is saved in {}'.format(save_path))
def create_ptq(self, executor, skip_tensor_list):
return paddle.static.quantization.PostTrainingQuantization(
def get_weight_act_info(self, program, persistable=True):
self.persistable_var_names = []
for var in program.list_vars():
if var.persistable:
self.persistable_var_names.append(var.name)
graph = IrGraph(core.Graph(program.desc), for_test=True)
weight_act_dict = {}
act_weight_dict = {}
ops = graph.all_op_nodes()
for op_node in ops:
if op_node.name() in self.quantizable_op_type:
in_x, in_y = SUPPORT_WEIGHT_OP_DICT[op_node.name()][0]
input_name_x = op_node.input(in_x)[0]
input_name_y = op_node.input(in_y)[0]
if not persistable:
weight_act_dict[input_name_y] = input_name_x
act_weight_dict[input_name_x] = input_name_y
else:
if input_name_y in self.persistable_var_names and input_name_y not in self.skip_tensor_list:
weight_act_dict[input_name_y] = input_name_x
act_weight_dict[input_name_x] = input_name_y
return weight_act_dict, act_weight_dict
def create_ptq(self, executor, skip_tensor_list=[]):
skip_tensor_list += self.skip_tensor_list
return PostTrainingQuantization(
executor=executor,
data_loader=self.data_loader,
model_dir=self.model_dir,
model_dir=self.float_model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
skip_tensor_list=skip_tensor_list,
**self.ptq_config)
**self.quant_config)
def sampling(self, executor, program, scope):
def sampling(self, executor, program, scope, fetch_list):
batch_id = 0
for data in self.data_loader():
executor.run(program=program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False,
scope=scope)
executor.run(
program=program,
feed=data,
fetch_list=fetch_list,
return_numpy=False,
scope=scope)
batch_id += 1
if batch_id >= self.batch_nums:
break
def fp_int_cosine_similarity(self, executor, float_program, quant_program,
float_scope, quant_scope):
cosine_similarity = []
for step, data in enumerate(self.data_loader()):
with paddle.static.scope_guard(float_scope):
float_preds = executor.run(program=float_program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False)
float_preds = float_preds[0]
with paddle.static.scope_guard(quant_scope):
quant_preds = executor.run(program=quant_program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False)
quant_preds = quant_preds[0]
paddle.disable_static()
float_preds = paddle.to_tensor(float_preds)
quant_preds = paddle.to_tensor(quant_preds)
cos_sim = F.cosine_similarity(float_preds, quant_preds).mean()
cos_sim = cos_sim.numpy()
cosine_similarity.append(cos_sim)
if step != 0 and (step % 10 == 0):
_logger.info("[step]: %d, cosine similarity: %.9f" %
(step, np.array(cosine_similarity).mean()))
paddle.enable_static()
return np.array(cosine_similarity).mean()
def get_sensitive_metric(self, skip_list, layer_name):
executor = paddle.static.Executor(self.places)
if self.eval_function is not None:
post_training_quantization = self.create_ptq(executor, skip_list)
program = post_training_quantization.quantize()
_logger.info('Evaluating...')
if self.onnx_format:
post_training_quantization.save_quantized_model(
self.temp_save_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
program, _, _ = load_inference_model(
self.temp_save_path,
executor,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
metric = self.eval_function(executor, program, self.feed_list,
self.fetch_list)
if skip_list is None:
executor.close()
return metric
sensitive_metric = self.base_metric - metric
_logger.info(
"Quantized layer name: %s, the accuracy: %.4f, the sensitive metric: %.4f"
% (layer_name, metric, sensitive_metric))
else:
float_scope = paddle.static.Scope()
quant_scope = paddle.static.Scope()
with paddle.static.scope_guard(float_scope):
[float_program, _, _] = load_inference_model(
self.model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
with paddle.static.scope_guard(quant_scope):
post_training_quantization = self.create_ptq(executor,
skip_list)
quant_program = post_training_quantization.quantize()
metric = self.fp_int_cosine_similarity(executor, float_program,
quant_program, float_scope,
quant_scope)
sensitive_metric = 1.0 - metric
_logger.info(
"Quantized layer name: %s, the cosine similarity: %.4f, the sensitive metric: %.4f"
% (layer_name, metric, sensitive_metric))
executor.close()
return sensitive_metric
def metric_error_analyse(self):
'''
Evaluate the quantized models, which are generated by quantizing each weight operator one by one. The results will be saved into analysis.txt.
'''
assert self.data_loader is not None, "When computing the sensitivity of quantized layers, the data loader is needed"
if self.eval_function is not None:
# evaluate before quant
_logger.info('Start to evaluate the base model.')
executor = paddle.static.Executor(self.places)
[program, feed_list, fetch_list]= load_inference_model( \
self.model_dir, \
executor=executor, \
model_filename=self.model_filename, \
params_filename=self.params_filename)
self.base_metric = self.eval_function(executor, program, feed_list,
fetch_list)
_logger.info('Before quantized, the accuracy of the model is: {}'.
format(self.base_metric))
executor.close()
# evaluate before quant
_logger.info('Start to evaluate the quantized model.')
self.quant_metric = self.get_sensitive_metric(
None, 'all quantizable layers')
_logger.info('After quantized, the accuracy of the model is: {}'.
format(self.quant_metric))
# For each layer, quantize the weight op and evaluate the quantized model.
for i, layer_name in enumerate(self.tobe_analyized_layer):
if layer_name in self.quant_layer_metrics:
continue
_logger.info('Checking {}/{} quant model: quant layer {}'.format(
i + 1, len(self.tobe_analyized_layer), layer_name))
skip_list = copy.copy(list(self.support_quant_val_name_list))
skip_list.remove(layer_name)
sensitive_metric = self.get_sensitive_metric(skip_list, layer_name)
self.quant_layer_metrics[layer_name] = sensitive_metric
self.save_checkpoint()
if self.onnx_format:
self.temp_root_path.cleanup()
self.sensitivity_ranklist = sorted(
self.quant_layer_metrics,
key=self.quant_layer_metrics.get,
reverse=True)
_logger.info('Finished computing the sensitivity of the model.')
for name in self.sensitivity_ranklist:
_logger.info("Quantized layer name: {}, sensitivity metric: {}".
format(name, self.quant_layer_metrics[name]))
analysis_file = os.path.join(self.save_dir, "analysis.txt")
with open(analysis_file, "w") as analysis_ret_f:
for name in self.sensitivity_ranklist:
analysis_ret_f.write(
"Quantized layer name: {}, sensitivity metric: {}\n".format(
name, self.quant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file))
def collect_vars(self, scope, var_names):
all_vars = {}
for var_name in var_names:
var_tensor = paddle.static.quantization.utils.load_variable_data(
scope, var_name)
all_vars[var_name] = var_tensor
return all_vars
def collect_base_stat(self):
_logger.info('Collecting Statistic Before PTQ...')
_logger.info('Collecting fp model statistic...')
executor = paddle.static.Executor(self.places)
[program, feed_list, fetch_list]= load_inference_model( \
self.model_dir, \
self.float_model_dir, \
executor=executor, \
model_filename=self.model_filename, \
params_filename=self.params_filename)
scope = paddle.static.global_scope()
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
self.acts_weight_map = self.get_weight_act_map(
program, self.weight_names, persistable_var_names)
activations_names = list(self.acts_weight_map.keys())
self.fp_weight_act_dict, self.fp_act_weight_dict = self.get_weight_act_info(
program)
self.fp_weight_names = list(self.fp_weight_act_dict.keys())
self.fp_act_names = list(self.fp_weight_act_dict.values())
for var in program.list_vars():
if var.name in activations_names:
if var.name in self.fp_act_names:
var.persistable = True
# sample
self.sampling(executor, program, scope)
before_act_data = self.collect_vars(scope, activations_names)
before_weight_data = self.collect_vars(scope, self.weight_names)
# sample
self.sampling(executor, program, scope, fetch_list)
fp_act = collect_vars(scope, self.fp_act_names)
fp_weight = collect_vars(scope, self.fp_weight_names)
executor.close()
return before_act_data, before_weight_data
return fp_act, fp_weight
def collect_quant_stat(self):
_logger.info('Collecting Statistic After PTQ...')
executor = paddle.static.Executor(self.places)
scope = paddle.static.global_scope()
post_training_quantization = self.create_ptq(executor, None)
program = post_training_quantization.quantize()
_logger.info('Collecting quant model statistic...')
if self.quant_model_dir is None:
executor = paddle.static.Executor(self.places)
scope = paddle.static.global_scope()
ptq = self.create_ptq(executor)
program = ptq.quantize()
feed_list, fetch_list = ptq._feed_list, ptq._fetch_list
else:
executor = paddle.static.Executor(self.places)
[program, feed_list, fetch_list]= load_inference_model( \
self.quant_model_dir, \
executor=executor, \
model_filename=self.model_filename, \
params_filename=self.params_filename)
scope = paddle.static.global_scope()
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
self.quant_weight_act_dict, self.quant_act_weight_dict = self.get_weight_act_info(
program)
self.quant_weight_names = list(self.quant_weight_act_dict.keys())
self.quant_act_names = list(self.quant_weight_act_dict.values())
quant_weight_names = self.weight_names
dequant_act_names = ["%s.quantized" % (n) for n in self.acts_weight_map]
for var in program.list_vars():
if var.name in dequant_act_names:
if var.name in self.quant_act_names:
var.persistable = True
self.sampling(executor, program, scope)
self.sampling(executor, program, scope, fetch_list)
after_act_data = self.collect_vars(scope, dequant_act_names)
after_weight_data = self.collect_vars(scope, quant_weight_names)
quant_act = collect_vars(scope, self.quant_act_names)
quant_weight = collect_vars(scope, self.quant_weight_names)
executor.close()
return after_act_data, after_weight_data
def statistical_analyse(self, analysis_axis=None):
self.act_data, self.weight_data = self.collect_base_stat()
self.quant_act_data, self.dequant_weight_data = self.collect_quant_stat(
)
fp_q_act_name_map = {
n: "%s.quantized" % (n)
for n in self.acts_weight_map
}
act_statistic, box_fp_dist, box_q_dist, hist_fp_dist, hist_q_dist = self.collect_statistic(
self.act_data,
self.quant_act_data,
fp_q_act_name_map,
is_weight=False,
axis=analysis_axis)
self.plot_box_distribution(box_fp_dist,
list(self.acts_weight_map.keys()),
'fp_activation_boxplot.pdf')
self.plot_box_distribution(box_q_dist,
list(self.acts_weight_map.keys()),
'quantized_activation_boxplot.pdf')
self.plot_hist_distribution(hist_fp_dist, 'fp_activation_histplot.pdf')
self.plot_hist_distribution(hist_q_dist,
'quantized_activation_histplot.pdf')
weight_statistic, box_fp_dist, box_q_dist, hist_fp_dist, hist_q_dist = self.collect_statistic(
self.weight_data,
self.dequant_weight_data,
None,
is_weight=True,
axis=analysis_axis)
self.plot_box_distribution(box_fp_dist,
list(self.quantized_weight_var_name),
'fp_weight_boxplot.pdf')
self.plot_box_distribution(box_q_dist,
list(self.quantized_weight_var_name),
'quantized_weight_boxplot.pdf')
self.plot_hist_distribution(hist_fp_dist, 'fp_weight_histplot.pdf')
self.plot_hist_distribution(hist_q_dist,
'quantized_weight_histplot.pdf')
statistic = act_statistic + weight_statistic
csv_columns = [
'Var Name', 'Var Type', 'Corresponding Weight Name', 'FP32 Min',
'FP32 Max', 'FP32 Mean', 'FP32 Std', 'Quantized Min',
'Quantized Max', 'Quantized Mean', 'Quantized Std', 'Diff Min',
'Diff Max', 'Diff Mean', 'Diff Std'
]
self.save_csv(statistic, 'statistic.csv', csv_columns)
def get_weight_act_map(self, program, weight_names, persistable_var_names):
weight_act_map = {}
for op_name in weight_names:
for block_id in range(len(program.blocks)):
for op in program.blocks[block_id].ops:
var_name_list = paddle.static.quantization.utils._get_op_input_var_names(
op)
if op_name in var_name_list:
for var_name in var_name_list:
if var_name not in persistable_var_names:
weight_act_map[var_name] = op_name
return weight_act_map
return quant_act, quant_weight
def collect_statistic(self,
fp_tensors,
......@@ -461,7 +249,7 @@ class AnalysisPTQ(object):
is_weight,
axis=None):
statistic = []
box_fp_dist, box_q_dist = [], []
box_fp_dist, box_q_dist = {}, {}
hist_fp_dist, hist_q_dist = {}, {}
fp_tensor_names = sorted(list(fp_tensors.keys()))
for var_name in fp_tensor_names:
......@@ -487,25 +275,39 @@ class AnalysisPTQ(object):
diff_std = round(diff.std(), 4)
stat = {
'Var Name': var_name,
'Var Type': 'Weight' if is_weight else 'Activation',
'Corresponding Weight Name': self.acts_weight_map[var_name]
if not is_weight else None,
'FP32 Min': fp_min,
'FP32 Max': fp_max,
'FP32 Mean': fp_mean,
'FP32 Std': fp_std,
'Quantized Min': q_min,
'Quantized Max': q_max,
'Quantized Mean': q_mean,
'Quantized Std': q_std,
'Diff Min': diff_min,
'Diff Max': diff_max,
'Diff Mean': diff_mean,
'Diff Std': diff_std,
'Var Name':
var_name,
'Var Type':
'Weight' if is_weight else 'Activation',
'Corresponding Weight Name':
self.fp_act_weight_dict[var_name] if not is_weight else None,
'FP32 Min':
fp_min,
'FP32 Max':
fp_max,
'FP32 Mean':
fp_mean,
'FP32 Std':
fp_std,
'Quantized Min':
q_min,
'Quantized Max':
q_max,
'Quantized Mean':
q_mean,
'Quantized Std':
q_std,
'Diff Min':
diff_min,
'Diff Max':
diff_max,
'Diff Mean':
diff_mean,
'Diff Std':
diff_std,
}
statistic.append(stat)
# for boxplot
# for boxplot
if axis is None:
box_fp_tensor = fp_tensor.flatten()
box_q_tensor = quant_tensor.flatten()
......@@ -514,12 +316,12 @@ class AnalysisPTQ(object):
[-1, fp_tensor.shape[axis]]).abs().max(axis=-1)
box_q_tensor = quant_tensor.reshape(
[-1, quant_tensor.shape[axis]]).abs().max(axis=-1)
sample_num = len(box_fp_tensor) if len(
box_fp_tensor) < 1000 else 1000
sample_num = len(
box_fp_tensor) if len(box_fp_tensor) < 1000 else 1000
box_fp_tensor = random.sample(list(box_fp_tensor), sample_num)
box_q_tensor = random.sample(list(box_q_tensor), sample_num)
box_fp_dist.append(box_fp_tensor)
box_q_dist.append(box_q_tensor)
box_fp_dist[var_name] = box_fp_tensor
box_q_dist[quant_name] = box_q_tensor
# for histplot
_, hist_edges = np.histogram(
......@@ -531,50 +333,253 @@ class AnalysisPTQ(object):
return statistic, box_fp_dist, box_q_dist, hist_fp_dist, hist_q_dist
def plot_box_distribution(self, distribution, labels, save_name):
all_values = sum(distribution, [])
max_value = np.max(all_values)
min_value = np.min(all_values)
pdf_path = os.path.join(self.save_dir, save_name)
with PdfPages(pdf_path) as pdf:
for i in range(0, len(distribution), 20):
r = i + 20 if i + 20 < len(distribution) else len(distribution)
plt.boxplot(
distribution[i:r],
labels=labels[i:r],
showbox=True,
patch_artist=True)
plt.xticks(rotation=90)
plt.tick_params(axis='x')
plt.ylim([min_value, max_value])
if 'act' in save_name:
plt.xlabel('Activation Name')
else:
plt.xlabel('Weight Name')
plt.ylabel("Box Distribution")
plt.tight_layout()
plt.show()
pdf.savefig()
plt.close()
_logger.info('Distribution plots is saved in {}'.format(pdf_path))
def plot_hist_distribution(self, hist_data, save_name):
pdf_path = os.path.join(self.save_dir, save_name)
with PdfPages(pdf_path) as pdf:
for name in hist_data:
plt.hist(hist_data[name][0], bins=hist_data[name][1])
plt.xlabel(name)
plt.ylabel("Probability")
locs, _ = plt.yticks()
plt.yticks(locs, np.round(locs / len(hist_data[name][0]), 3))
if 'act' in save_name:
plt.title("Hist of Activation {}".format(name))
else:
plt.title("Hist of Weight {}".format(name))
plt.show()
pdf.savefig()
plt.close()
_logger.info('Histogram plot is saved in {}'.format(pdf_path))
def statistical_analyse(self, analysis_axis=None):
fp_act, fp_weight = self.collect_base_stat()
quant_act, quant_weight = self.collect_quant_stat()
fp_q_act_dict = {
self.fp_weight_act_dict[n]: self.quant_weight_act_dict[n]
for n in self.fp_weight_act_dict
}
act_statistic, box_fp_dist, box_q_dist, hist_fp_dist, hist_q_dist = self.collect_statistic(
fp_act,
quant_act,
fp_q_act_dict,
is_weight=False,
axis=analysis_axis)
plot_box_distribution(box_fp_dist, self.save_dir,
'fp_activation_boxplot.pdf')
plot_box_distribution(box_q_dist, self.save_dir,
'quantized_activation_boxplot.pdf')
plot_hist_distribution(hist_fp_dist, self.save_dir,
'fp_activation_histplot.pdf')
plot_hist_distribution(hist_q_dist, self.save_dir,
'quantized_activation_histplot.pdf')
weight_statistic, box_fp_dist, box_q_dist, hist_fp_dist, hist_q_dist = self.collect_statistic(
fp_weight, quant_weight, None, is_weight=True, axis=analysis_axis)
plot_box_distribution(box_fp_dist, self.save_dir,
'fp_weight_boxplot.pdf')
plot_box_distribution(box_q_dist, self.save_dir,
'quantized_weight_boxplot.pdf')
plot_hist_distribution(hist_fp_dist, self.save_dir,
'fp_weight_histplot.pdf')
plot_hist_distribution(hist_q_dist, self.save_dir,
'quantized_weight_histplot.pdf')
statistic = act_statistic + weight_statistic
csv_columns = [
'Var Name', 'Var Type', 'Corresponding Weight Name', 'FP32 Min',
'FP32 Max', 'FP32 Mean', 'FP32 Std', 'Quantized Min',
'Quantized Max', 'Quantized Mean', 'Quantized Std', 'Diff Min',
'Diff Max', 'Diff Mean', 'Diff Std'
]
save_csv(statistic, self.save_dir, 'statistic.csv', csv_columns)
def get_quant_sensitive_metric(self, skip_list, layer_name):
executor = paddle.static.Executor(self.places)
if self.eval_function is not None:
ptq = self.create_ptq(executor, skip_list)
program = ptq.quantize()
_logger.info('Evaluating...')
if self.onnx_format:
post_training_quantization.save_quantized_model(
self.temp_save_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
program, feed_list, fetch_list = load_inference_model(
self.temp_save_path,
executor,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
metric = self.eval_function(executor, program, ptq._feed_list,
ptq._fetch_list)
sensitive_metric = self.fp_metric - metric
_logger.info(
"Quantized layer name: %s, the accuracy: %.4f, the sensitive metric: %.4f"
% (layer_name, metric, sensitive_metric))
else:
float_scope = paddle.static.Scope()
quant_scope = paddle.static.Scope()
with paddle.static.scope_guard(float_scope):
[float_program, float_feed_list,
float_fetch_list] = load_inference_model(
self.float_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
with paddle.static.scope_guard(quant_scope):
ptq = self.create_ptq(executor, skip_list)
quant_program = ptq.quantize()
metric = fp_quant_cosine_similarity(
executor, self.data_loader, float_program, quant_program,
float_scope, quant_scope, float_fetch_list, ptq._fetch_list)
sensitive_metric = 1.0 - metric
_logger.info(
"Quantized layer name: %s, the cosine similarity: %.4f, the sensitive metric: %.4f"
% (layer_name, metric, sensitive_metric))
executor.close()
return sensitive_metric
def get_dequant_sensitive_metric(self, executor, float_scope, quant_scope,
layer_name):
weight_name = layer_name.split('.quantized.dequantized')[0]
with paddle.static.scope_guard(float_scope):
[float_program, float_feed_list,
float_fetch_list] = load_inference_model(
self.float_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
with paddle.static.scope_guard(quant_scope):
[program, quant_feed_list, quant_fetch_list] = load_inference_model(
self.quant_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
program_copy = program.clone()
graph = IrGraph(core.Graph(program_copy.desc), for_test=True)
input_rename_map, output_rename_map, removed_ops = get_new_in_out_map(
self.weight_act_dict[layer_name], graph, float_scope, quant_scope,
self.places)
saved_program = relink_graph(graph, input_rename_map, output_rename_map,
removed_ops)
if self.eval_function is not None:
with paddle.static.scope_guard(quant_scope):
_logger.info(
'Skip quant {}, evaluating....'.format(weight_name))
metric = self.eval_function(executor, saved_program,
quant_feed_list, quant_fetch_list)
sensitive_metric = self.quant_metric - metric
_logger.info(
'When skip quant %s, the eval metric is %.4f, the sensitive metric is %.4f'
% (weight_name, metric, self.quant_metric - metric))
else:
metric = fp_quant_cosine_similarity(
executor, self.data_loader, float_program, saved_program,
float_scope, quant_scope, float_fetch_list, quant_fetch_list)
sensitive_metric = 1 - metric
_logger.info(
'When skip quant %s, the cosine similarity is %.4f, the sensitive metric is %.4f'
% (weight_name, metric, 1 - metric))
return sensitive_metric
def prepare_error_analyse(self, dequant_layer_by_layer):
if not dequant_layer_by_layer:
executor = paddle.static.Executor(self.places)
[program, feed_list, fetch_list]= load_inference_model( \
self.float_model_dir, \
executor=executor, \
model_filename=self.model_filename, \
params_filename=self.params_filename)
self.weight_act_dict, _ = self.get_weight_act_info(program)
self.support_quant_name_list = list(self.weight_act_dict.keys())
self.tobe_analyized_layer = sorted(
list(
set(self.support_quant_name_list) -
set(self.skip_tensor_list)))
if self.eval_function is not None:
_logger.info('Start to evaluate the FP model.')
self.fp_metric = self.eval_function(executor, program,
feed_list, fetch_list)
_logger.info(
'The accuracy of the FP model is: %.4f' % self.fp_metric)
executor.close()
_logger.info('Start to evaluate the quantized model.')
executor = paddle.static.Executor(self.places)
ptq = self.create_ptq(executor, self.skip_tensor_list)
program = ptq.quantize()
self.quant_metric = self.eval_function(executor, program,
feed_list, fetch_list)
_logger.info('The accuracy of the quantized model is: %.4f' %
self.quant_metric)
else:
executor = paddle.static.Executor(self.places)
[program, feed_list, fetch_list] = load_inference_model(
self.quant_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
graph = IrGraph(core.Graph(program.desc), for_test=True)
self.weight_act_dict, _ = self.get_weight_act_info(
program, persistable=False)
if self.eval_function is not None:
_logger.info('Start to evaluate the quantized model.')
self.quant_metric = self.eval_function(executor, program,
feed_list, fetch_list)
_logger.info('The accuracy of the quantized model is: %.4f' %
self.quant_metric)
executor.close()
def metric_error_analyse(self):
assert self.data_loader is not None, \
"When computing the sensitivity of quantized layers, the data loader is needed"
dequant_layer_by_layer = False if self.quant_model_dir is None else True
self.prepare_error_analyse(dequant_layer_by_layer)
if not dequant_layer_by_layer:
_logger.info(
'For each layer, quantize the weight op and evaluate the quantized model.'
)
# For each layer, quantize the weight op and evaluate the quantized model.
for i, layer_name in enumerate(self.tobe_analyized_layer):
if layer_name in self.layer_metrics:
continue
_logger.info(
'Checking {}/{} quant model: quant layer {}'.format(
i + 1, len(self.tobe_analyized_layer), layer_name))
skip_list = copy.copy(list(self.support_quant_name_list))
skip_list.remove(layer_name)
sensitive_metric = self.get_quant_sensitive_metric(
skip_list, layer_name)
self.layer_metrics[layer_name] = sensitive_metric
self.save_checkpoint()
if self.onnx_format:
self.temp_root_path.cleanup()
else:
_logger.info(
'For each layer, dequantize the weight op and evaluate the quantized model.'
)
executor = paddle.static.Executor(self.places)
float_scope = paddle.static.Scope()
quant_scope = paddle.static.Scope()
for idx, name in enumerate(self.weight_act_dict):
weight_name = name.split('.quantized.dequantized')[0]
if weight_name in self.layer_metrics:
continue
_logger.info(
'Checking {}/{} quant model: without quant layer {}'.format(
idx + 1, len(self.weight_act_dict), weight_name))
sensitive_metric = self.get_dequant_sensitive_metric(
executor, float_scope, quant_scope, name)
self.layer_metrics[weight_name] = sensitive_metric
self.save_checkpoint()
executor.close()
self.sensitivity_ranklist = sorted(
self.layer_metrics, key=self.layer_metrics.get, reverse=True)
_logger.info('Finished computing the sensitivity of the model.')
for name in self.sensitivity_ranklist:
_logger.info("layer name: {}, sensitivity metric: {}".format(
name, self.layer_metrics[name]))
analysis_file = os.path.join(self.save_dir, "analysis.txt")
with open(analysis_file, "w") as analysis_ret_f:
for name in self.sensitivity_ranklist:
analysis_ret_f.write("layer name: {}, sensitivity metric: {}\n".
format(name, self.layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file))
def get_target_quant_model(self, target_metric):
_logger.info(
......@@ -583,11 +588,9 @@ class AnalysisPTQ(object):
'Make sure that you are using full eval dataset to get target quantized model.'
)
skip_list = []
if self.quant_layer_metrics:
if self.layer_metrics:
rank_list = sorted(
self.quant_layer_metrics,
key=self.quant_layer_metrics.get,
reverse=True)
self.layer_metrics, key=self.layer_metrics.get, reverse=True)
else:
_logger.info(
'Analyse metric error before get target quantized model.')
......@@ -597,12 +600,12 @@ class AnalysisPTQ(object):
skip_list.append(rank_list.pop(0))
_logger.info('Skip Ops: {}'.format(skip_list))
executor = paddle.static.Executor(self.places)
post_training_quantization = self.create_ptq(executor, skip_list)
program = post_training_quantization.quantize()
ptq = self.create_ptq(executor, skip_list)
program = ptq.quantize()
_logger.info('Evaluating...')
quant_metric = self.eval_function(executor, program, self.feed_list,
self.fetch_list)
quant_metric = self.eval_function(executor, program, ptq._feed_list,
ptq._fetch_list)
_logger.info("Current eval metric: {}, the target metric: {}".
format(quant_metric, target_metric))
if quant_metric >= target_metric:
......@@ -611,7 +614,7 @@ class AnalysisPTQ(object):
_logger.info(
'The quantized model satisfies the target metric and is saved to {}'.
format(quantize_model_path))
post_training_quantization.save_quantized_model(
ptq.save_quantized_model(
quantize_model_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import pickle
import copy
import logging
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.framework import core
from paddle.fluid.framework import IrGraph
from ..common import get_logger, load_inference_model
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ["AnalysisQAT"]
class AnalysisQAT(object):
def __init__(self,
quant_model_dir,
float_model_dir,
model_filename=None,
params_filename=None,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
qat_metric=None,
eval_function=None,
data_loader=None,
save_dir='analysis_results',
resume=False):
'''
AnalysisQAT provides to analysis the sensitivity of each op in the model.
Args:
quant_model_dir(str): the path of INT8 model that quantized through QAT
float_model_dir(str): the path of FP32 model that is the base model of quant_model
model_filename(str, optional): the model file name of the model
params_filename(str, optional): the parameter file name of the model
quantizable_op_type(list of str, optional): the type of op that will be analyzed
qat_metric(float, optional): the metric of the quantized model, which will be calculated automatically if is None
eval_function(function): eval function, define by yourself to return the metric of the inference program, can be used to judge the metric of quantized model.
data_loader(Python Generator, Paddle.io.DataLoader, optional): the
Generator or Dataloader provides calibrate data, and it could
return a batch every time
save_dir(str, optional): the output dir that stores the analyzed information
resume(bool, optional): When break off while ananlyzing, could resume analysis program and load already analyzed information.
'''
if model_filename is None:
model_filename = 'model.pdmodel'
if params_filename is None:
params_filename = 'model.pdiparams'
self.quant_model_dir = quant_model_dir
self.float_model_dir = float_model_dir
self.model_filename = model_filename
self.params_filename = params_filename
self.quantizable_op_type = quantizable_op_type
self.qat_metric = qat_metric
self.eval_function = eval_function
self.data_loader = data_loader
self.save_dir = save_dir
self.checkpoint_name = os.path.join(save_dir, 'analysis_checkpoint.pkl')
self.nonquant_layer_metrics = {}
if not os.path.exists(self.save_dir):
os.mkdir(self.save_dir)
devices = paddle.device.get_device().split(':')[0]
self.places = paddle.device._convert_to_place(devices)
executor = paddle.static.Executor(self.places)
[program, self.feed_list, self.fetch_list] = load_inference_model(
self.quant_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
_logger.info('Loaded model from: {}'.format(quant_model_dir))
graph = IrGraph(core.Graph(program.desc), for_test=True)
# find all inputs for each quantizable op
self.inputs_of_quantized_op = []
sorted_ops = graph.topology_sort()
for op_node in sorted_ops:
op_name = op_node.name()
if op_name in quantizable_op_type:
input_names = op_node.op().input_arg_names()
for input_name in input_names:
if 'quantized' in input_name:
self.inputs_of_quantized_op.append(input_names)
break
if self.eval_function is None:
assert self.data_loader is not None, "DataLoader cannot be None if Eval Fuction is None."
_logger.info(
'The sensitivity will measured by cosine similarity of the outputs from float model and quantized model.'
)
if self.qat_metric is None and self.eval_function is not None:
_logger.info('Calculating the metric of QAT model...')
self.qat_metric = self.eval_function(
executor, program, self.feed_list, self.fetch_list) * 100
_logger.info('The metric of QAT model is {}'.format(
round(self.qat_metric, 4)))
executor.close()
if resume:
self.load_checkpoint()
def save_checkpoint(self):
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
with open(self.checkpoint_name, 'wb') as f:
pickle.dump(self.nonquant_layer_metrics, f)
_logger.info('Save checkpoint to {}.'.format(self.checkpoint_name))
def load_checkpoint(self):
if not os.path.exists(self.checkpoint_name):
_logger.info('Checkpoint path {} does not exist.'.format(
self.checkpoint_name))
return False
with open(self.checkpoint_name, 'rb') as f:
self.nonquant_layer_metrics = pickle.load(f)
_logger.info('Load checkpoint from {}.'.format(self.checkpoint_name))
return True
def get_weight_name(self, inputs_names):
# TODO(xc)
w_idx = 0 if 'w_0' in inputs_names[0] else 1
weight_name = inputs_names[w_idx].split('.quantized.dequantized')[0]
return weight_name
def get_new_in_out_map(
self,
input_list,
graph,
float_scope,
quant_scope, ):
input_rename_map = {}
output_rename_map = {}
removed_ops = []
for op_node in graph.all_op_nodes():
if op_node.id() in removed_ops:
continue
in_names = op_node.input_arg_names()
out_names = op_node.output_arg_names()
if len(out_names) == 1 and out_names[0] in input_list:
in_var = graph._find_node_by_name(op_node.inputs,
op_node.input('X')[0])
out_var = graph._find_node_by_name(op_node.outputs,
op_node.output('Y')[0])
if not in_var.persistable():
# act
for op in graph.all_op_nodes():
o_ns = op.output_arg_names()
if len(o_ns) == 1 and o_ns[0] == in_var.name():
in_var_1 = graph._find_node_by_name(
op.inputs, op.input('X')[0])
graph.safe_remove_nodes(op)
removed_ops.append(op.id())
input_rename_map[out_var.node] = in_var_1
else:
# weight
with paddle.static.scope_guard(float_scope):
float_name = in_var.name().replace('.quantized', '')
float_weight = np.array(
float_scope.find_var(float_name).get_tensor())
with paddle.static.scope_guard(quant_scope):
quant_scope.find_var(in_var.name()).get_tensor().set(
float_weight, self.places)
input_rename_map[out_var.node] = in_var
graph.safe_remove_nodes(op_node)
removed_ops.append(op_node.id())
output_rename_map[in_var.node] = out_var
return input_rename_map, output_rename_map, removed_ops
def relink_graph(self, graph, input_rename_map, output_rename_map,
removed_ops):
for op_node in graph.all_op_nodes():
if op_node.id() in removed_ops:
continue
for var in op_node.inputs:
if var.node in input_rename_map:
old_in = var
new_in = input_rename_map[var.node]
graph.update_input_link(old_in, new_in, op_node)
_logger.info(
f'relink {op_node.name()} \'s input node from {old_in.name()} to {new_in.name()}.'
)
for var in op_node.outputs:
if var.node in output_rename_map:
old_out = var
new_out = output_rename_map[var.node]
graph.update_input_link(old_out, new_out, op_node)
_logger.info(
f'relink {op_node.name()} \'s output node from {old_out.name()} to {new_out.name()}.'
)
return graph.to_program()
def fp_int_cosine_similarity(self, executor, float_program, quant_program,
float_scope, quant_scope):
cosine_similarity = []
for step, data in enumerate(self.data_loader()):
with paddle.static.scope_guard(float_scope):
float_preds = executor.run(program=float_program,
feed=data,
fetch_list=self.float_fetch_list,
return_numpy=False)
float_preds = float_preds[0]
with paddle.static.scope_guard(quant_scope):
quant_preds = executor.run(program=quant_program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False)
quant_preds = quant_preds[0]
paddle.disable_static()
float_preds = paddle.to_tensor(float_preds)
quant_preds = paddle.to_tensor(quant_preds)
cos_sim = F.cosine_similarity(float_preds, quant_preds).mean()
cos_sim = cos_sim.numpy()
cosine_similarity.append(cos_sim)
if step != 0 and (step % 10 == 0):
_logger.info("[step]: %d, cosine similarity: %.9f" %
(step, np.array(cosine_similarity).mean()))
paddle.enable_static()
return np.array(cosine_similarity).mean()
def metric_error_analyse(self):
executor = paddle.static.Executor(self.places)
float_scope = paddle.static.Scope()
quant_scope = paddle.static.Scope()
for idx, input_list in enumerate(self.inputs_of_quantized_op):
weight_name = self.get_weight_name(input_list)
if weight_name in self.nonquant_layer_metrics:
continue
_logger.info(
'Checking {}/{} quant model: without quant layer {}'.format(
idx + 1, len(self.inputs_of_quantized_op), weight_name))
with paddle.static.scope_guard(float_scope):
[float_program, self.float_feed_list,
self.float_fetch_list] = load_inference_model(
self.float_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
with paddle.static.scope_guard(quant_scope):
[program, self.feed_list,
self.fetch_list] = load_inference_model(
self.quant_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
program_copy = program.clone()
graph = IrGraph(core.Graph(program_copy.desc), for_test=True)
input_rename_map, output_rename_map, removed_ops = self.get_new_in_out_map(
input_list, graph, float_scope, quant_scope)
saved_program = self.relink_graph(graph, input_rename_map,
output_rename_map, removed_ops)
if self.eval_function is not None:
with paddle.static.scope_guard(quant_scope):
_logger.info('Skip quant {}, evaluating....'.format(
weight_name))
metric = self.eval_function(executor, saved_program,
self.feed_list,
self.fetch_list) * 100
self.nonquant_layer_metrics[
weight_name] = metric - self.qat_metric
_logger.info(
'When skip quant %s, the eval metric is %.4f, the sensitive metric is %.4f'
% (weight_name, metric, metric - self.qat_metric))
else:
metric = self.fp_int_cosine_similarity(executor, float_program,
saved_program,
float_scope, quant_scope)
self.nonquant_layer_metrics[weight_name] = 1 - metric
_logger.info(
'When skip quant %s, the cosine similarity is %.4f, the sensitive metric is %.4f'
% (weight_name, metric, 1 - metric))
self.save_checkpoint()
executor.close()
self.sensitivity_ranklist = sorted(
self.nonquant_layer_metrics,
key=self.nonquant_layer_metrics.get,
reverse=True)
_logger.info('Finished computing the sensitivity of the model.')
for name in self.sensitivity_ranklist:
_logger.info("Without quant layer name: {}, sensitive metric: {}".
format(name, self.nonquant_layer_metrics[name]))
analysis_file = os.path.join(self.save_dir, "analysis.txt")
with open(analysis_file, "w") as analysis_ret_f:
for name in self.sensitivity_ranklist:
analysis_ret_f.write(
"Without quant layer name: {}, sensitive metric: {}\n".
format(name, self.nonquant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file))
import os
import sys
import csv
import logging
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from ..common import get_logger
import paddle
import paddle.nn.functional as F
from paddle.static.quantization.utils import load_variable_data
_logger = get_logger(__name__, level=logging.INFO)
def collect_vars(scope, var_names):
all_vars = {}
for var_name in var_names:
var_tensor = load_variable_data(scope, var_name)
all_vars[var_name] = var_tensor
return all_vars
def plot_box_distribution(box_data, save_dir, save_name):
all_values = sum(list(box_data.values()), [])
max_value = np.max(all_values)
min_value = np.min(all_values)
pdf_path = os.path.join(save_dir, save_name)
labels = sorted(box_data.keys())
with PdfPages(pdf_path) as pdf:
for i in range(0, len(labels), 20):
r = i + 20 if i + 20 < len(labels) else len(labels)
dist = [box_data[n] for n in labels[i:r]]
plt.boxplot(
dist, labels=labels[i:r], showbox=True, patch_artist=True)
plt.xticks(rotation=90)
plt.tick_params(axis='x')
plt.ylim([min_value, max_value])
if 'act' in save_name:
plt.xlabel('Activation Name')
else:
plt.xlabel('Weight Name')
plt.ylabel("Box Distribution")
plt.tight_layout()
plt.show()
pdf.savefig()
plt.close()
_logger.info('Box plots is saved in {}'.format(pdf_path))
def plot_hist_distribution(hist_data, save_dir, save_name):
pdf_path = os.path.join(save_dir, save_name)
with PdfPages(pdf_path) as pdf:
for name in hist_data:
plt.hist(hist_data[name][0], bins=hist_data[name][1])
plt.xlabel(name)
plt.ylabel("Probability")
locs, _ = plt.yticks()
plt.yticks(locs, np.round(locs / len(hist_data[name][0]), 3))
if 'act' in save_name:
plt.title("Hist of Activation {}".format(name))
else:
plt.title("Hist of Weight {}".format(name))
plt.show()
pdf.savefig()
plt.close()
_logger.info('Histogram plot is saved in {}'.format(pdf_path))
def save_csv(data, save_dir, save_name, csv_columns):
save_path = os.path.join(save_dir, save_name)
with open(save_path, 'w') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
writer.writeheader()
for d in data:
writer.writerow(d)
_logger.info('Activation Statistic is saved in {}'.format(save_path))
def fp_quant_cosine_similarity(executor, data_loader, float_program,
quant_program, float_scope, quant_scope,
float_fetch_list, quant_fetch_list):
cosine_similarity = []
for step, data in enumerate(data_loader()):
with paddle.static.scope_guard(float_scope):
float_preds = executor.run(
program=float_program,
feed=data,
fetch_list=float_fetch_list,
return_numpy=False)
float_preds = float_preds[0]
with paddle.static.scope_guard(quant_scope):
quant_preds = executor.run(
program=quant_program,
feed=data,
fetch_list=quant_fetch_list,
return_numpy=False)
quant_preds = quant_preds[0]
paddle.disable_static()
float_preds = paddle.to_tensor(float_preds)
quant_preds = paddle.to_tensor(quant_preds)
cos_sim = F.cosine_similarity(float_preds, quant_preds).mean()
cos_sim = cos_sim.numpy()
cosine_similarity.append(cos_sim)
if step != 0 and (step % 10 == 0):
_logger.info("[step]: %d, cosine similarity: %.9f" %
(step, np.array(cosine_similarity).mean()))
paddle.enable_static()
return np.array(cosine_similarity).mean()
def get_new_in_out_map(input_name, graph, float_scope, quant_scope, place):
input_rename_map = {}
output_rename_map = {}
removed_ops = []
for op_node in graph.all_op_nodes():
if op_node.id() in removed_ops:
continue
in_names = op_node.input_arg_names()
out_names = op_node.output_arg_names()
if out_names[0] == input_name:
in_var = graph._find_node_by_name(op_node.inputs,
op_node.input('X')[0])
out_var = graph._find_node_by_name(op_node.outputs,
op_node.output('Y')[0])
if not in_var.persistable():
# act
for op in graph.all_op_nodes():
o_ns = op.output_arg_names()
if len(o_ns) == 1 and o_ns[0] == in_var.name():
in_var_1 = graph._find_node_by_name(
op.inputs, op.input('X')[0])
graph.safe_remove_nodes(op)
removed_ops.append(op.id())
input_rename_map[out_var.node] = in_var_1
else:
# weight
with paddle.static.scope_guard(float_scope):
float_name = in_var.name().replace('.quantized', '')
float_weight = np.array(
float_scope.find_var(float_name).get_tensor())
with paddle.static.scope_guard(quant_scope):
quant_scope.find_var(in_var.name()).get_tensor().set(
float_weight, place)
input_rename_map[out_var.node] = in_var
graph.safe_remove_nodes(op_node)
removed_ops.append(op_node.id())
output_rename_map[in_var.node] = out_var
return input_rename_map, output_rename_map, removed_ops
def relink_graph(graph, input_rename_map, output_rename_map, removed_ops):
for op_node in graph.all_op_nodes():
if op_node.id() in removed_ops:
continue
for var in op_node.inputs:
if var.node in input_rename_map:
old_in = var
new_in = input_rename_map[var.node]
graph.update_input_link(old_in, new_in, op_node)
_logger.info(
f'relink {op_node.name()} \'s input node from {old_in.name()} to {new_in.name()}.'
)
for var in op_node.outputs:
if var.node in output_rename_map:
old_out = var
new_out = output_rename_map[var.node]
graph.update_input_link(old_out, new_out, op_node)
_logger.info(
f'relink {op_node.name()} \'s output node from {old_out.name()} to {new_out.name()}.'
)
return graph.to_program()
......@@ -7,7 +7,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
paddle.enable_static()
......@@ -17,7 +17,8 @@ class ImageNetDataset(DatasetFolder):
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Resize(256),
transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
......@@ -51,12 +52,12 @@ class AnalysisPTQDemo(unittest.TestCase):
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
analyzer = AnalysisPTQ(
model_dir="./MobileNetV1_infer",
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
ptq_config={
quant_config={
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
......
......@@ -8,7 +8,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
paddle.enable_static()
......@@ -19,7 +19,8 @@ class ImageNetDataset(DatasetFolder):
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Resize(256),
transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
self.mode = mode
......@@ -52,9 +53,9 @@ class ImageNetDataset(DatasetFolder):
return len(self.samples)
class AnalysisPTQEvalFunction(unittest.TestCase):
class AnalysisEvalFunction(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(AnalysisPTQEvalFunction, self).__init__(*args, **kwargs)
super(AnalysisEvalFunction, self).__init__(*args, **kwargs)
if not os.path.exists('MobileNetV1_infer'):
os.system(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
......@@ -116,9 +117,10 @@ class AnalysisPTQEvalFunction(unittest.TestCase):
if len(test_feed_names) == 1:
image = np.array(image)
label = np.array(label).astype('int64')
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = exe.run(
compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
......@@ -135,12 +137,13 @@ class AnalysisPTQEvalFunction(unittest.TestCase):
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = np.array(image)
label = np.array(label).astype('int64')
result = exe.run(compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = exe.run(
compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 100 == 0:
......@@ -148,12 +151,12 @@ class AnalysisPTQEvalFunction(unittest.TestCase):
result = np.mean(np.array(results), axis=0)
return result[0]
analyzer = AnalysisPTQ(
model_dir="./MobileNetV1_infer",
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
ptq_config={
quant_config={
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
......@@ -164,7 +167,7 @@ class AnalysisPTQEvalFunction(unittest.TestCase):
data_loader=train_loader,
eval_function=eval_function)
analyzer.metric_error_analyse()
analyzer.get_target_quant_model(69.5)
analyzer.get_target_quant_model(0.695)
os.system('rm -rf MobileNetV1_analysis')
......
......@@ -8,7 +8,7 @@ from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddle.static.quantization import PostTrainingQuantization
from paddleslim.quant.analysis_qat import AnalysisQAT
from paddleslim.quant.analysis import Analysis
paddle.enable_static()
......@@ -19,7 +19,8 @@ class ImageNetDataset(DatasetFolder):
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Resize(256),
transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
......@@ -55,8 +56,8 @@ class AnalysisQATDemo(unittest.TestCase):
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
place = paddle.CUDAPlace(
0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
executor = paddle.static.Executor(place)
ptq_config = {
......@@ -83,12 +84,13 @@ class AnalysisQATDemo(unittest.TestCase):
model_filename='inference.pdmodel',
params_filename='inference.pdiparams')
analyzer = AnalysisQAT(
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
quant_model_dir="./MobileNetV1_quant",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="analysis_result",
quant_config=ptq_config,
data_loader=train_loader)
analyzer.metric_error_analyse()
os.system('rm -rf analysis_result')
......
......@@ -8,7 +8,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_qat import AnalysisQAT
from paddleslim.quant.analysis import Analysis
from paddle.static.quantization import PostTrainingQuantization
paddle.enable_static()
......@@ -21,7 +21,8 @@ class ImageNetDataset(DatasetFolder):
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Resize(256),
transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
self.mode = mode
......@@ -118,9 +119,10 @@ class AnalysisQATEvalFunction(unittest.TestCase):
if len(test_feed_names) == 1:
image = np.array(image)
label = np.array(label).astype('int64')
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = exe.run(
compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
......@@ -137,12 +139,13 @@ class AnalysisQATEvalFunction(unittest.TestCase):
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = np.array(image)
label = np.array(label).astype('int64')
result = exe.run(compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = exe.run(
compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 100 == 0:
......@@ -150,8 +153,8 @@ class AnalysisQATEvalFunction(unittest.TestCase):
result = np.mean(np.array(results), axis=0)
return result[0]
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
place = paddle.CUDAPlace(
0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
executor = paddle.static.Executor(place)
ptq_config = {
......@@ -178,12 +181,13 @@ class AnalysisQATEvalFunction(unittest.TestCase):
model_filename='inference.pdmodel',
params_filename='inference.pdiparams')
analyzer = AnalysisQAT(
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
quant_model_dir="./MobileNetV1_QAT",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
quant_config=ptq_config,
data_loader=train_loader,
eval_function=eval_function)
analyzer.metric_error_analyse()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册