未验证 提交 3b2ed2cf 编写于 作者: C Chang Xu 提交者: GitHub

Merge AnalysisPTQ & AnalysisQAT to Analysis (#1692)

上级 2bb09da6
# PTQ(Post Training Quantization)量化分析工具详细教程
# 量化分析工具详细教程
## 1. 量化分析工具功能
1. 统计分析(statistical_analyse):
......@@ -13,17 +13,18 @@
- 输入预期精度,直接产出符合预期精度的量化模型。
## 2. paddleslim.quant.AnalysisPTQ 可传入参数解析
## 2. paddleslim.quant.Analysis 可传入参数解析
| **参数名** | **参数释义** |
|-----------------------------|-----------------------------------------|
| model_dir | 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可 |
| float_model_dir | 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可 |
| quant_model_dir | 默认为None,传入的量化模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可; 若不传入,分析工具将使用PTQ进行量化并分析|
| model_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| eval_function | 若需要验证精度,需要传入自定义的验证函数;若不传入,精度误差分析将根据Cosine Similarity计算得出 |
| data_loader | 模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader |
| save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`|
| resume | 是否加载中间分析文件,默认为False|
| ptq_config | 可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post) |
| quant_config | 可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post) |
......@@ -45,7 +46,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
......@@ -72,12 +73,12 @@ image = paddle.static.data(
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
analyzer = AnalysisPTQ(
model_dir="./MobileNetV1_infer",
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
ptq_config={
quant_config={
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
......@@ -124,22 +125,17 @@ analyzer.statistical_analyse()
```shell
analyzer.metric_error_analyse()
```
调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。
若不传入quant_model_dir,并且调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。
若传入quant_model_dir,并且调用该接口,会遍历量化模型中的每一层,去掉量化节点并计算当前层不量化的模型精度。调用该接口时,需要输入Eval Function。会产出所有去掉一层量化的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。具体使用可参考[GPT量化训练敏感度分析DEMO](../../../../example/quantization_analysis/GPT/README.md)
**直接产出符合预期精度的目标量化模型**
```shell
analyzer.get_target_quant_model(target_metric=70.0)
analyzer.get_target_quant_model(target_metric=0.70)
```
## 4. 根据分析结果执行离线量化
执行完量化分析工具后,可根据 `analysis.txt` 中的精度排序,在量化中去掉效果较差的层,具体操作为:在调用 `paddleslim.quant.quant_post_static` 时加入参数 `skip_tensor_list`,将需要去掉的层传入即可。
## FAQ:
- 与QAT(Quantization-Aware Training)量化分析工具的区别:与QAT量化分析工具不同的是,PTQ量化分析工具则是加载待量化的原模型,对模型所有层依次进行量化,每次量化一层,进行验证获取精度误差分析。而QAT量化分析工具加载量化训练后的量化模型,遍历所有量化的层,依次去掉量化层,加载Float模型的参数,并进行验证获取精度误差分析。
- PTQ量化分析工具设计的原因:PTQ量化分析工具依次量化模型中的每一层,而不是依次去掉量化层是由于PTQ本身的高效性。依次量化一层进行验证,查看对模型精度的损失十分直观。
- 量化分析工具为什么要区分PTQ和QAT:实验证明PTQ和QAT后的量化模型的敏感层并不完全一致,将两种算法分开,敏感度分析结果更加准确。
# QAT(Quantization-Aware Training)量化分析工具详细教程
## 1. 量化分析工具功能
精度误差分析(metric_error_analyse):
- 遍历量化训练后模型的每层,去掉量化节点并计算当前层不量化的模型精度。该功能可以定位具体某层导致的量化损失。
## 2. paddleslim.quant.AnalysisQAT 可传入参数解析
| **参数名** | **参数释义** |
|-----------------------------|-----------------------------------------|
| quant_model_dir | 必须传入的量化后的模型文件路径 |
| float_model_dir | 必须传入的量化前的模型文件路径 |
| model_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称 |
| params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称 |
| quantizable_op_type | 需分析的量化的op类型,默认为`conv2d`, `depthwise_conv2d`, `mul` |
| qat_metric | 量化模型的精度,可不传入,默认为None,不传入时会自动计算 |
| eval_function | 若需要验证精度,需要传入自定义的验证函数;若不传入,精度误差分析将根据Cosine Similarity计算得出 |
| data_loader | 模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader |
| save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`|
| resume | 是否加载中间分析文件,默认为False|
## 3. 量化分析工具的使用
**创建量化分析工具**
```shell
# 下载Inference模型
wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar
tar -xf MobileNetV1_infer.tar
wget -q https://paddle-slim-models.bj.bcebos.com/act/MobileNetV1_QAT.tar
tar -xf MobileNetV1_QAT.tar
# 下载demo数据集
wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz
tar -xf ILSVRC2012_data_demo.tar.gz
```
```shell
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddle.quantization import PostTrainingQuantization
from paddleslim.quant.analysis_qat import AnalysisQAT
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
def __init__(self, path, image_size=224):
super(ImageNetDataset, self).__init__(path)
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
def __getitem__(self, idx):
img_path, _ = self.samples[idx]
return self.transform(Image.open(img_path).convert('RGB'))
def __len__(self):
return len(self.samples)
train_dataset = ImageNetDataset(
"./ILSVRC2012_data_demo/ILSVRC2012/train/")
image = paddle.static.data(
name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
analyzer = AnalysisQAT(
float_model_dir="./MobileNetV1_infer",
quant_model_dir="./MobileNetV1_QAT",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
data_loader=train_loader)
```
**精度误差分析**
```shell
analyzer.metric_error_analyse()
```
调用该接口,会遍历量化模型中的每一层,去掉量化节点并计算当前层不量化的模型精度。调用该接口时,需要输入Eval Function。会产出所有去掉一层量化的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。具体使用可参考[GPT量化训练敏感度分析DEMO](../../../../example/quantization_analysis/GPT/README.md)
## FAQ:
- 与PTQ(Post Training Quantization)量化分析工具的区别:与PTQ量化分析工具不同的是,QAT量化分析工具加载量化训练后的量化模型,遍历所有量化的层,依次去掉量化层,加载Float模型的参数,并进行验证获取精度误差分析。而PTQ量化分析工具则是加载待量化的原模型,对模型所有层依次进行量化,每次量化一层,进行验证获取精度误差分析。
- QAT量化分析工具设计的原因:QAT量化分析工具依次去掉量化层,而不是依次量化一层是由于QAT需要训练的特性。遍历每层进行量化训练再验证精度比较耗时,直接加载量化训练后的量化模型,依次去掉量化层更高效。
- 量化分析工具为什么要区分PTQ和QAT:实验证明PTQ和QAT后的量化模型的敏感层并不完全一致,将两种算法分开,敏感度分析结果更加准确。
......@@ -23,7 +23,7 @@ from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from keypoint_utils import keypoint_post_process
from post_process import PPYOLOEPostProcess
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
def argsparser():
......@@ -87,10 +87,11 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
elif isinstance(config['input_list'], dict):
if k in config['input_list'].keys():
data_input[config['input_list'][k]] = np.array(v)
outs = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
outs = exe.run(
compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
if 'arch' in config and config['arch'] == 'keypoint':
res = keypoint_post_process(data, data_input, exe,
......@@ -115,8 +116,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric.log()
map_res = metric.get_results()
metric.reset()
map_key = 'keypoint' if 'arch' in config and config[
'arch'] == 'keypoint' else 'bbox'
map_key = 'keypoint' if 'arch' in config and config['arch'] == 'keypoint' else 'bbox'
return map_res[map_key][0]
......@@ -127,9 +127,8 @@ def main():
ptq_config = config['PTQ']
# val dataset is sufficient for PTQ
data_loader = create('EvalReader')(config['EvalDataset'],
config['worker_num'],
return_list=True)
data_loader = create('EvalReader')(
config['EvalDataset'], config['worker_num'], return_list=True)
ptq_data_loader = reader_wrapper(data_loader, config['input_list'])
# fast_val_anno_path, such as annotation path of several pictures can accelerate analysis
......@@ -139,10 +138,11 @@ def main():
global val_loader
_eval_batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=config['EvalReader']['batch_size'])
val_loader = create('EvalReader')(dataset,
config['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
val_loader = create('EvalReader')(
dataset,
config['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
global metric
if config['metric'] == 'COCO':
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
......@@ -161,14 +161,14 @@ def main():
else:
raise ValueError("metric currently only supports COCO and VOC.")
analyzer = AnalysisPTQ(
model_dir=config["model_dir"],
analyzer = Analysis(
float_model_dir=config["model_dir"],
model_filename=config["model_filename"],
params_filename=config["params_filename"],
eval_function=eval_function,
data_loader=ptq_data_loader,
save_dir=config['save_dir'],
ptq_config=ptq_config,
quant_config=ptq_config,
resume=True, )
analyzer.statistical_analyse()
......
......@@ -21,7 +21,7 @@ from tqdm import tqdm
from post_process import YOLOPostProcess, coco_metric
from dataset import COCOValDataset, COCOTrainDataset
from paddleslim.common import load_config, load_onnx_model
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
def argsparser():
......@@ -41,7 +41,8 @@ def argsparser():
'--resume',
type=bool,
default=False,
help="When break off while ananlyzing, could resume analysis program and load already analyzed information."
help=
"When break off while ananlyzing, could resume analysis program and load already analyzed information."
)
return parser
......@@ -54,10 +55,11 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
ncols=80) as t:
for data in val_loader:
data_all = {k: np.array(v) for k, v in data.items()}
outs = exe.run(compiled_test_program,
feed={test_feed_names[0]: data_all['image']},
fetch_list=test_fetch_list,
return_numpy=False)
outs = exe.run(
compiled_test_program,
feed={test_feed_names[0]: data_all['image']},
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
......@@ -103,15 +105,15 @@ def main():
load_onnx_model(config["model_dir"])
inference_model_path = config["model_dir"].rstrip().rstrip(
'.onnx') + '_infer'
analyzer = AnalysisPTQ(
model_dir=inference_model_path,
analyzer = Analysis(
float_model_dir=inference_model_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams',
eval_function=eval_function,
data_loader=data_loader,
save_dir=config['save_dir'],
resume=FLAGS.resume,
ptq_config=ptq_config)
quant_config=ptq_config)
analyzer.statistical_analyse()
analyzer.metric_error_analyse()
......
......@@ -21,7 +21,7 @@ import time
import paddle
from paddleslim.common import load_config as load_slim_config
from paddleslim.quant.analysis_qat import AnalysisQAT
from paddleslim.quant.analysis import Analysis
from ppfleetx.data import build_dataloader
from ppfleetx.distributed.apis import env
from utils import parse_config
......@@ -164,17 +164,15 @@ def main():
global eval_loader
eval_loader = eval_reader_wrapper(valid_data_loader)
analyzer = AnalysisQAT(
analyzer = Analysis(
quant_model_dir=global_config["quant_model_dir"],
float_model_dir=global_config["float_model_dir"],
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
quantizable_op_type=global_config['quantizable_op_type'],
qat_metric=global_config['qat_metric']
if 'qat_metric' in global_config else None,
eval_function=eval_function,
data_loader=eval_loader,
save_dir=FLAGS.save_dir,
quant_config=all_config['quant_config'],
resume=global_config['resume'], )
analyzer.metric_error_analyse()
......
......@@ -5,11 +5,16 @@ Global:
float_model_dir: ./GPT_345M_Baseline
model_filename: model.pdmodel
params_filename: model.pdiparams
quantizable_op_type: ["mul", "matmul", "matmul_v2"]
resume: False
reader_config: ./configs/gpt_reader.yaml
cloze_eval: True # True for LAMBADA Dataset; False for WikiText
quant_config:
quantizable_op_type: ["mul", "matmul", "matmul_v2"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
batch_size: 8
batch_nums: 10
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import pickle
import copy
import logging
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.framework import core
from paddle.fluid.framework import IrGraph
from ..common import get_logger, load_inference_model
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ["AnalysisQAT"]
class AnalysisQAT(object):
def __init__(self,
quant_model_dir,
float_model_dir,
model_filename=None,
params_filename=None,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
qat_metric=None,
eval_function=None,
data_loader=None,
save_dir='analysis_results',
resume=False):
'''
AnalysisQAT provides to analysis the sensitivity of each op in the model.
Args:
quant_model_dir(str): the path of INT8 model that quantized through QAT
float_model_dir(str): the path of FP32 model that is the base model of quant_model
model_filename(str, optional): the model file name of the model
params_filename(str, optional): the parameter file name of the model
quantizable_op_type(list of str, optional): the type of op that will be analyzed
qat_metric(float, optional): the metric of the quantized model, which will be calculated automatically if is None
eval_function(function): eval function, define by yourself to return the metric of the inference program, can be used to judge the metric of quantized model.
data_loader(Python Generator, Paddle.io.DataLoader, optional): the
Generator or Dataloader provides calibrate data, and it could
return a batch every time
save_dir(str, optional): the output dir that stores the analyzed information
resume(bool, optional): When break off while ananlyzing, could resume analysis program and load already analyzed information.
'''
if model_filename is None:
model_filename = 'model.pdmodel'
if params_filename is None:
params_filename = 'model.pdiparams'
self.quant_model_dir = quant_model_dir
self.float_model_dir = float_model_dir
self.model_filename = model_filename
self.params_filename = params_filename
self.quantizable_op_type = quantizable_op_type
self.qat_metric = qat_metric
self.eval_function = eval_function
self.data_loader = data_loader
self.save_dir = save_dir
self.checkpoint_name = os.path.join(save_dir, 'analysis_checkpoint.pkl')
self.nonquant_layer_metrics = {}
if not os.path.exists(self.save_dir):
os.mkdir(self.save_dir)
devices = paddle.device.get_device().split(':')[0]
self.places = paddle.device._convert_to_place(devices)
executor = paddle.static.Executor(self.places)
[program, self.feed_list, self.fetch_list] = load_inference_model(
self.quant_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
_logger.info('Loaded model from: {}'.format(quant_model_dir))
graph = IrGraph(core.Graph(program.desc), for_test=True)
# find all inputs for each quantizable op
self.inputs_of_quantized_op = []
sorted_ops = graph.topology_sort()
for op_node in sorted_ops:
op_name = op_node.name()
if op_name in quantizable_op_type:
input_names = op_node.op().input_arg_names()
for input_name in input_names:
if 'quantized' in input_name:
self.inputs_of_quantized_op.append(input_names)
break
if self.eval_function is None:
assert self.data_loader is not None, "DataLoader cannot be None if Eval Fuction is None."
_logger.info(
'The sensitivity will measured by cosine similarity of the outputs from float model and quantized model.'
)
if self.qat_metric is None and self.eval_function is not None:
_logger.info('Calculating the metric of QAT model...')
self.qat_metric = self.eval_function(
executor, program, self.feed_list, self.fetch_list) * 100
_logger.info('The metric of QAT model is {}'.format(
round(self.qat_metric, 4)))
executor.close()
if resume:
self.load_checkpoint()
def save_checkpoint(self):
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
with open(self.checkpoint_name, 'wb') as f:
pickle.dump(self.nonquant_layer_metrics, f)
_logger.info('Save checkpoint to {}.'.format(self.checkpoint_name))
def load_checkpoint(self):
if not os.path.exists(self.checkpoint_name):
_logger.info('Checkpoint path {} does not exist.'.format(
self.checkpoint_name))
return False
with open(self.checkpoint_name, 'rb') as f:
self.nonquant_layer_metrics = pickle.load(f)
_logger.info('Load checkpoint from {}.'.format(self.checkpoint_name))
return True
def get_weight_name(self, inputs_names):
# TODO(xc)
w_idx = 0 if 'w_0' in inputs_names[0] else 1
weight_name = inputs_names[w_idx].split('.quantized.dequantized')[0]
return weight_name
def get_new_in_out_map(
self,
input_list,
graph,
float_scope,
quant_scope, ):
input_rename_map = {}
output_rename_map = {}
removed_ops = []
for op_node in graph.all_op_nodes():
if op_node.id() in removed_ops:
continue
in_names = op_node.input_arg_names()
out_names = op_node.output_arg_names()
if len(out_names) == 1 and out_names[0] in input_list:
in_var = graph._find_node_by_name(op_node.inputs,
op_node.input('X')[0])
out_var = graph._find_node_by_name(op_node.outputs,
op_node.output('Y')[0])
if not in_var.persistable():
# act
for op in graph.all_op_nodes():
o_ns = op.output_arg_names()
if len(o_ns) == 1 and o_ns[0] == in_var.name():
in_var_1 = graph._find_node_by_name(
op.inputs, op.input('X')[0])
graph.safe_remove_nodes(op)
removed_ops.append(op.id())
input_rename_map[out_var.node] = in_var_1
else:
# weight
with paddle.static.scope_guard(float_scope):
float_name = in_var.name().replace('.quantized', '')
float_weight = np.array(
float_scope.find_var(float_name).get_tensor())
with paddle.static.scope_guard(quant_scope):
quant_scope.find_var(in_var.name()).get_tensor().set(
float_weight, self.places)
input_rename_map[out_var.node] = in_var
graph.safe_remove_nodes(op_node)
removed_ops.append(op_node.id())
output_rename_map[in_var.node] = out_var
return input_rename_map, output_rename_map, removed_ops
def relink_graph(self, graph, input_rename_map, output_rename_map,
removed_ops):
for op_node in graph.all_op_nodes():
if op_node.id() in removed_ops:
continue
for var in op_node.inputs:
if var.node in input_rename_map:
old_in = var
new_in = input_rename_map[var.node]
graph.update_input_link(old_in, new_in, op_node)
_logger.info(
f'relink {op_node.name()} \'s input node from {old_in.name()} to {new_in.name()}.'
)
for var in op_node.outputs:
if var.node in output_rename_map:
old_out = var
new_out = output_rename_map[var.node]
graph.update_input_link(old_out, new_out, op_node)
_logger.info(
f'relink {op_node.name()} \'s output node from {old_out.name()} to {new_out.name()}.'
)
return graph.to_program()
def fp_int_cosine_similarity(self, executor, float_program, quant_program,
float_scope, quant_scope):
cosine_similarity = []
for step, data in enumerate(self.data_loader()):
with paddle.static.scope_guard(float_scope):
float_preds = executor.run(program=float_program,
feed=data,
fetch_list=self.float_fetch_list,
return_numpy=False)
float_preds = float_preds[0]
with paddle.static.scope_guard(quant_scope):
quant_preds = executor.run(program=quant_program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False)
quant_preds = quant_preds[0]
paddle.disable_static()
float_preds = paddle.to_tensor(float_preds)
quant_preds = paddle.to_tensor(quant_preds)
cos_sim = F.cosine_similarity(float_preds, quant_preds).mean()
cos_sim = cos_sim.numpy()
cosine_similarity.append(cos_sim)
if step != 0 and (step % 10 == 0):
_logger.info("[step]: %d, cosine similarity: %.9f" %
(step, np.array(cosine_similarity).mean()))
paddle.enable_static()
return np.array(cosine_similarity).mean()
def metric_error_analyse(self):
executor = paddle.static.Executor(self.places)
float_scope = paddle.static.Scope()
quant_scope = paddle.static.Scope()
for idx, input_list in enumerate(self.inputs_of_quantized_op):
weight_name = self.get_weight_name(input_list)
if weight_name in self.nonquant_layer_metrics:
continue
_logger.info(
'Checking {}/{} quant model: without quant layer {}'.format(
idx + 1, len(self.inputs_of_quantized_op), weight_name))
with paddle.static.scope_guard(float_scope):
[float_program, self.float_feed_list,
self.float_fetch_list] = load_inference_model(
self.float_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
with paddle.static.scope_guard(quant_scope):
[program, self.feed_list,
self.fetch_list] = load_inference_model(
self.quant_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
program_copy = program.clone()
graph = IrGraph(core.Graph(program_copy.desc), for_test=True)
input_rename_map, output_rename_map, removed_ops = self.get_new_in_out_map(
input_list, graph, float_scope, quant_scope)
saved_program = self.relink_graph(graph, input_rename_map,
output_rename_map, removed_ops)
if self.eval_function is not None:
with paddle.static.scope_guard(quant_scope):
_logger.info('Skip quant {}, evaluating....'.format(
weight_name))
metric = self.eval_function(executor, saved_program,
self.feed_list,
self.fetch_list) * 100
self.nonquant_layer_metrics[
weight_name] = metric - self.qat_metric
_logger.info(
'When skip quant %s, the eval metric is %.4f, the sensitive metric is %.4f'
% (weight_name, metric, metric - self.qat_metric))
else:
metric = self.fp_int_cosine_similarity(executor, float_program,
saved_program,
float_scope, quant_scope)
self.nonquant_layer_metrics[weight_name] = 1 - metric
_logger.info(
'When skip quant %s, the cosine similarity is %.4f, the sensitive metric is %.4f'
% (weight_name, metric, 1 - metric))
self.save_checkpoint()
executor.close()
self.sensitivity_ranklist = sorted(
self.nonquant_layer_metrics,
key=self.nonquant_layer_metrics.get,
reverse=True)
_logger.info('Finished computing the sensitivity of the model.')
for name in self.sensitivity_ranklist:
_logger.info("Without quant layer name: {}, sensitive metric: {}".
format(name, self.nonquant_layer_metrics[name]))
analysis_file = os.path.join(self.save_dir, "analysis.txt")
with open(analysis_file, "w") as analysis_ret_f:
for name in self.sensitivity_ranklist:
analysis_ret_f.write(
"Without quant layer name: {}, sensitive metric: {}\n".
format(name, self.nonquant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file))
import os
import sys
import csv
import logging
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from ..common import get_logger
import paddle
import paddle.nn.functional as F
from paddle.static.quantization.utils import load_variable_data
_logger = get_logger(__name__, level=logging.INFO)
def collect_vars(scope, var_names):
all_vars = {}
for var_name in var_names:
var_tensor = load_variable_data(scope, var_name)
all_vars[var_name] = var_tensor
return all_vars
def plot_box_distribution(box_data, save_dir, save_name):
all_values = sum(list(box_data.values()), [])
max_value = np.max(all_values)
min_value = np.min(all_values)
pdf_path = os.path.join(save_dir, save_name)
labels = sorted(box_data.keys())
with PdfPages(pdf_path) as pdf:
for i in range(0, len(labels), 20):
r = i + 20 if i + 20 < len(labels) else len(labels)
dist = [box_data[n] for n in labels[i:r]]
plt.boxplot(
dist, labels=labels[i:r], showbox=True, patch_artist=True)
plt.xticks(rotation=90)
plt.tick_params(axis='x')
plt.ylim([min_value, max_value])
if 'act' in save_name:
plt.xlabel('Activation Name')
else:
plt.xlabel('Weight Name')
plt.ylabel("Box Distribution")
plt.tight_layout()
plt.show()
pdf.savefig()
plt.close()
_logger.info('Box plots is saved in {}'.format(pdf_path))
def plot_hist_distribution(hist_data, save_dir, save_name):
pdf_path = os.path.join(save_dir, save_name)
with PdfPages(pdf_path) as pdf:
for name in hist_data:
plt.hist(hist_data[name][0], bins=hist_data[name][1])
plt.xlabel(name)
plt.ylabel("Probability")
locs, _ = plt.yticks()
plt.yticks(locs, np.round(locs / len(hist_data[name][0]), 3))
if 'act' in save_name:
plt.title("Hist of Activation {}".format(name))
else:
plt.title("Hist of Weight {}".format(name))
plt.show()
pdf.savefig()
plt.close()
_logger.info('Histogram plot is saved in {}'.format(pdf_path))
def save_csv(data, save_dir, save_name, csv_columns):
save_path = os.path.join(save_dir, save_name)
with open(save_path, 'w') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
writer.writeheader()
for d in data:
writer.writerow(d)
_logger.info('Activation Statistic is saved in {}'.format(save_path))
def fp_quant_cosine_similarity(executor, data_loader, float_program,
quant_program, float_scope, quant_scope,
float_fetch_list, quant_fetch_list):
cosine_similarity = []
for step, data in enumerate(data_loader()):
with paddle.static.scope_guard(float_scope):
float_preds = executor.run(
program=float_program,
feed=data,
fetch_list=float_fetch_list,
return_numpy=False)
float_preds = float_preds[0]
with paddle.static.scope_guard(quant_scope):
quant_preds = executor.run(
program=quant_program,
feed=data,
fetch_list=quant_fetch_list,
return_numpy=False)
quant_preds = quant_preds[0]
paddle.disable_static()
float_preds = paddle.to_tensor(float_preds)
quant_preds = paddle.to_tensor(quant_preds)
cos_sim = F.cosine_similarity(float_preds, quant_preds).mean()
cos_sim = cos_sim.numpy()
cosine_similarity.append(cos_sim)
if step != 0 and (step % 10 == 0):
_logger.info("[step]: %d, cosine similarity: %.9f" %
(step, np.array(cosine_similarity).mean()))
paddle.enable_static()
return np.array(cosine_similarity).mean()
def get_new_in_out_map(input_name, graph, float_scope, quant_scope, place):
input_rename_map = {}
output_rename_map = {}
removed_ops = []
for op_node in graph.all_op_nodes():
if op_node.id() in removed_ops:
continue
in_names = op_node.input_arg_names()
out_names = op_node.output_arg_names()
if out_names[0] == input_name:
in_var = graph._find_node_by_name(op_node.inputs,
op_node.input('X')[0])
out_var = graph._find_node_by_name(op_node.outputs,
op_node.output('Y')[0])
if not in_var.persistable():
# act
for op in graph.all_op_nodes():
o_ns = op.output_arg_names()
if len(o_ns) == 1 and o_ns[0] == in_var.name():
in_var_1 = graph._find_node_by_name(
op.inputs, op.input('X')[0])
graph.safe_remove_nodes(op)
removed_ops.append(op.id())
input_rename_map[out_var.node] = in_var_1
else:
# weight
with paddle.static.scope_guard(float_scope):
float_name = in_var.name().replace('.quantized', '')
float_weight = np.array(
float_scope.find_var(float_name).get_tensor())
with paddle.static.scope_guard(quant_scope):
quant_scope.find_var(in_var.name()).get_tensor().set(
float_weight, place)
input_rename_map[out_var.node] = in_var
graph.safe_remove_nodes(op_node)
removed_ops.append(op_node.id())
output_rename_map[in_var.node] = out_var
return input_rename_map, output_rename_map, removed_ops
def relink_graph(graph, input_rename_map, output_rename_map, removed_ops):
for op_node in graph.all_op_nodes():
if op_node.id() in removed_ops:
continue
for var in op_node.inputs:
if var.node in input_rename_map:
old_in = var
new_in = input_rename_map[var.node]
graph.update_input_link(old_in, new_in, op_node)
_logger.info(
f'relink {op_node.name()} \'s input node from {old_in.name()} to {new_in.name()}.'
)
for var in op_node.outputs:
if var.node in output_rename_map:
old_out = var
new_out = output_rename_map[var.node]
graph.update_input_link(old_out, new_out, op_node)
_logger.info(
f'relink {op_node.name()} \'s output node from {old_out.name()} to {new_out.name()}.'
)
return graph.to_program()
......@@ -7,7 +7,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
paddle.enable_static()
......@@ -17,7 +17,8 @@ class ImageNetDataset(DatasetFolder):
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Resize(256),
transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
......@@ -51,12 +52,12 @@ class AnalysisPTQDemo(unittest.TestCase):
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
analyzer = AnalysisPTQ(
model_dir="./MobileNetV1_infer",
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
ptq_config={
quant_config={
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
......
......@@ -8,7 +8,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
paddle.enable_static()
......@@ -19,7 +19,8 @@ class ImageNetDataset(DatasetFolder):
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Resize(256),
transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
self.mode = mode
......@@ -52,9 +53,9 @@ class ImageNetDataset(DatasetFolder):
return len(self.samples)
class AnalysisPTQEvalFunction(unittest.TestCase):
class AnalysisEvalFunction(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(AnalysisPTQEvalFunction, self).__init__(*args, **kwargs)
super(AnalysisEvalFunction, self).__init__(*args, **kwargs)
if not os.path.exists('MobileNetV1_infer'):
os.system(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
......@@ -116,9 +117,10 @@ class AnalysisPTQEvalFunction(unittest.TestCase):
if len(test_feed_names) == 1:
image = np.array(image)
label = np.array(label).astype('int64')
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = exe.run(
compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
......@@ -135,12 +137,13 @@ class AnalysisPTQEvalFunction(unittest.TestCase):
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = np.array(image)
label = np.array(label).astype('int64')
result = exe.run(compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = exe.run(
compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 100 == 0:
......@@ -148,12 +151,12 @@ class AnalysisPTQEvalFunction(unittest.TestCase):
result = np.mean(np.array(results), axis=0)
return result[0]
analyzer = AnalysisPTQ(
model_dir="./MobileNetV1_infer",
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
ptq_config={
quant_config={
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
......@@ -164,7 +167,7 @@ class AnalysisPTQEvalFunction(unittest.TestCase):
data_loader=train_loader,
eval_function=eval_function)
analyzer.metric_error_analyse()
analyzer.get_target_quant_model(69.5)
analyzer.get_target_quant_model(0.695)
os.system('rm -rf MobileNetV1_analysis')
......
......@@ -8,7 +8,7 @@ from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddle.static.quantization import PostTrainingQuantization
from paddleslim.quant.analysis_qat import AnalysisQAT
from paddleslim.quant.analysis import Analysis
paddle.enable_static()
......@@ -19,7 +19,8 @@ class ImageNetDataset(DatasetFolder):
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Resize(256),
transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
......@@ -55,8 +56,8 @@ class AnalysisQATDemo(unittest.TestCase):
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
place = paddle.CUDAPlace(
0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
executor = paddle.static.Executor(place)
ptq_config = {
......@@ -83,12 +84,13 @@ class AnalysisQATDemo(unittest.TestCase):
model_filename='inference.pdmodel',
params_filename='inference.pdiparams')
analyzer = AnalysisQAT(
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
quant_model_dir="./MobileNetV1_quant",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="analysis_result",
quant_config=ptq_config,
data_loader=train_loader)
analyzer.metric_error_analyse()
os.system('rm -rf analysis_result')
......
......@@ -8,7 +8,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_qat import AnalysisQAT
from paddleslim.quant.analysis import Analysis
from paddle.static.quantization import PostTrainingQuantization
paddle.enable_static()
......@@ -21,7 +21,8 @@ class ImageNetDataset(DatasetFolder):
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Resize(256),
transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
self.mode = mode
......@@ -118,9 +119,10 @@ class AnalysisQATEvalFunction(unittest.TestCase):
if len(test_feed_names) == 1:
image = np.array(image)
label = np.array(label).astype('int64')
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = exe.run(
compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
......@@ -137,12 +139,13 @@ class AnalysisQATEvalFunction(unittest.TestCase):
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = np.array(image)
label = np.array(label).astype('int64')
result = exe.run(compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = exe.run(
compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 100 == 0:
......@@ -150,8 +153,8 @@ class AnalysisQATEvalFunction(unittest.TestCase):
result = np.mean(np.array(results), axis=0)
return result[0]
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
place = paddle.CUDAPlace(
0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
executor = paddle.static.Executor(place)
ptq_config = {
......@@ -178,12 +181,13 @@ class AnalysisQATEvalFunction(unittest.TestCase):
model_filename='inference.pdmodel',
params_filename='inference.pdiparams')
analyzer = AnalysisQAT(
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
quant_model_dir="./MobileNetV1_QAT",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
quant_config=ptq_config,
data_loader=train_loader,
eval_function=eval_function)
analyzer.metric_error_analyse()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册