未验证 提交 43c5156b 编写于 作者: G Guanghua Yu 提交者: GitHub

update quant analysis (#1455)

上级 75d006bf
......@@ -37,6 +37,12 @@ def argsparser():
type=str,
default='gpu',
help="which device used to compress.")
parser.add_argument(
'--resume',
type=bool,
default=False,
help="When break off while ananlyzing, could resume analysis program and load already analyzed information."
)
return parser
......@@ -104,6 +110,7 @@ def main():
eval_function=eval_function,
data_loader=data_loader,
save_dir=config['save_dir'],
resume=FLAGS.resume,
ptq_config=ptq_config)
# plot the boxplot of activations of quantizable weights
......
arch: YOLOv7
model_dir: ./yolov7.onnx
save_dir: ./analysis_results
dataset_dir: /dataset/coco/
dataset_dir: dataset/coco/
val_image_dir: val2017
val_anno_path: annotations/instances_val2017.json
# Small Dataset to accelerate analysis
......@@ -15,5 +15,6 @@ PTQ:
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
onnx_format: False
batch_size: 10
batch_nums: 10
\ No newline at end of file
......@@ -197,12 +197,15 @@ def coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list):
with open(output, 'w') as f:
json.dump(results, f)
try:
coco_dt = coco_gt.loadRes(output)
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval.stats
except:
return [0.]
def _get_det_res(bboxes, bbox_nums, image_id, label_to_cat_id_map):
......
......@@ -21,6 +21,7 @@ import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import random
import tempfile
import paddle
from paddle.fluid import core
from paddle.fluid import framework
......@@ -46,8 +47,8 @@ class AnalysisQuant(object):
eval_function=None,
data_loader=None,
save_dir='analysis_results',
checkpoint_name='analysis_checkpoint.pkl',
num_histogram_plots=10,
resume=False,
ptq_config=None):
"""
AnalysisQuant provides to analysis the sensitivity of each op in the model.
......@@ -61,7 +62,7 @@ class AnalysisQuant(object):
Generator or Dataloader provides calibrate data, and it could
return a batch every time
save_dir(str, optional): the output dir that stores the analyzed information
checkpoint_name(str, optional): the name of checkpoint file that saves analyzed information and avoids break off while ananlyzing
resume(bool, optional): When break off while ananlyzing, could resume analysis program and load already analyzed information.
ptq_config(dict, optional): the args that can initialize PostTrainingQuantization
"""
......@@ -76,15 +77,24 @@ class AnalysisQuant(object):
self.save_dir = save_dir
self.eval_function = eval_function
self.quant_layer_names = []
self.checkpoint_name = os.path.join(save_dir, checkpoint_name)
self.checkpoint_name = os.path.join(save_dir, 'analysis_checkpoint.pkl')
self.quant_layer_metrics = {}
self.num_histogram_plots = num_histogram_plots
self.ptq_config = ptq_config
self.batch_nums = ptq_config[
'batch_nums'] if 'batch_nums' in ptq_config else 10
self.is_full_quantize = ptq_config[
'is_full_quantize'] if 'is_full_quantize' in ptq_config else False
self.onnx_format = ptq_config[
'onnx_format'] if 'onnx_format' in ptq_config else False
if not os.path.exists(self.save_dir):
os.mkdir(self.save_dir)
if self.onnx_format:
self.temp_root_path = tempfile.TemporaryDirectory(dir=self.save_dir)
self.temp_save_path = os.path.join(self.temp_root_path.name, "ptq")
if not os.path.exists(self.temp_save_path):
os.makedirs(self.temp_save_path)
devices = paddle.device.get_device().split(':')[0]
self.places = paddle.device._convert_to_place(devices)
......@@ -117,8 +127,19 @@ class AnalysisQuant(object):
params_filename=self.params_filename,
skip_tensor_list=None,
algo='avg', #fastest
onnx_format=self.onnx_format,
**self.ptq_config)
program = post_training_quantization.quantize()
if self.onnx_format:
post_training_quantization.save_quantized_model(
self.temp_save_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
program, _, _ = load_inference_model(
self.temp_save_path,
executor,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
self.quant_metric = self.eval_function(executor, program,
self.feed_list, self.fetch_list)
_logger.info('After quantized, the accuracy of the model is: {}'.format(
......@@ -127,11 +148,14 @@ class AnalysisQuant(object):
# get quantized weight and act var name
self.quantized_weight_var_name = post_training_quantization._quantized_weight_var_name
self.quantized_act_var_name = post_training_quantization._quantized_act_var_name
self.support_quant_val_name_list = self.quantized_weight_var_name if not self.is_full_quantize else list(
self.quantized_act_var_name)
executor.close()
# load tobe_analyized_layer from checkpoint
if resume:
self.load_checkpoint()
self.tobe_analyized_layer = self.quantized_weight_var_name - set(
self.tobe_analyized_layer = set(self.support_quant_val_name_list) - set(
list(self.quant_layer_metrics.keys()))
self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer))
......@@ -194,8 +218,6 @@ class AnalysisQuant(object):
scope = global_scope()
graph = IrGraph(core.Graph(program.desc), for_test=False)
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
......@@ -249,7 +271,7 @@ class AnalysisQuant(object):
for i, layer_name in enumerate(self.tobe_analyized_layer):
_logger.info('Checking {}/{} quant model: quant layer {}'.format(
i + 1, len(self.tobe_analyized_layer), layer_name))
skip_list = copy.copy(list(self.quantized_weight_var_name))
skip_list = copy.copy(list(self.support_quant_val_name_list))
skip_list.remove(layer_name)
executor = paddle.static.Executor(self.places)
......@@ -260,20 +282,33 @@ class AnalysisQuant(object):
model_filename=self.model_filename,
params_filename=self.params_filename,
skip_tensor_list=skip_list,
onnx_format=self.onnx_format,
algo='avg', #fastest
**self.ptq_config)
program = post_training_quantization.quantize()
_logger.info('Evaluating...')
if self.onnx_format:
post_training_quantization.save_quantized_model(
self.temp_save_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
program, _, _ = load_inference_model(
self.temp_save_path,
executor,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
quant_metric = self.eval_function(executor, program, self.feed_list,
self.fetch_list)
executor.close()
_logger.info(
"Quantized layer name: {}, eval metric: {}, the loss caused by this layer: {}".
format(layer_name, quant_metric, self.base_metric -
quant_metric))
format(layer_name,
round(quant_metric, 4),
round(self.base_metric - quant_metric, 4)))
self.quant_layer_metrics[layer_name] = quant_metric
self.save_checkpoint()
if self.onnx_format:
self.temp_root_path.cleanup()
def get_weight_act_map(self, program, weight_names, persistable_var_names):
act_names = {}
......@@ -408,6 +443,7 @@ class AnalysisQuant(object):
model_dir=self.model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
onnx_format=self.onnx_format,
skip_tensor_list=skip_list,
**self.ptq_config)
program = post_training_quantization.quantize()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册