未验证 提交 43c5156b 编写于 作者: G Guanghua Yu 提交者: GitHub

update quant analysis (#1455)

上级 75d006bf
...@@ -37,6 +37,12 @@ def argsparser(): ...@@ -37,6 +37,12 @@ def argsparser():
type=str, type=str,
default='gpu', default='gpu',
help="which device used to compress.") help="which device used to compress.")
parser.add_argument(
'--resume',
type=bool,
default=False,
help="When break off while ananlyzing, could resume analysis program and load already analyzed information."
)
return parser return parser
...@@ -104,6 +110,7 @@ def main(): ...@@ -104,6 +110,7 @@ def main():
eval_function=eval_function, eval_function=eval_function,
data_loader=data_loader, data_loader=data_loader,
save_dir=config['save_dir'], save_dir=config['save_dir'],
resume=FLAGS.resume,
ptq_config=ptq_config) ptq_config=ptq_config)
# plot the boxplot of activations of quantizable weights # plot the boxplot of activations of quantizable weights
......
arch: YOLOv7 arch: YOLOv7
model_dir: ./yolov7.onnx model_dir: ./yolov7.onnx
save_dir: ./analysis_results save_dir: ./analysis_results
dataset_dir: /dataset/coco/ dataset_dir: dataset/coco/
val_image_dir: val2017 val_image_dir: val2017
val_anno_path: annotations/instances_val2017.json val_anno_path: annotations/instances_val2017.json
# Small Dataset to accelerate analysis # Small Dataset to accelerate analysis
...@@ -15,5 +15,6 @@ PTQ: ...@@ -15,5 +15,6 @@ PTQ:
weight_quantize_type: 'abs_max' weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max' activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False is_full_quantize: False
onnx_format: False
batch_size: 10 batch_size: 10
batch_nums: 10 batch_nums: 10
\ No newline at end of file
...@@ -197,12 +197,15 @@ def coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list): ...@@ -197,12 +197,15 @@ def coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list):
with open(output, 'w') as f: with open(output, 'w') as f:
json.dump(results, f) json.dump(results, f)
coco_dt = coco_gt.loadRes(output) try:
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox') coco_dt = coco_gt.loadRes(output)
coco_eval.evaluate() coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
coco_eval.accumulate() coco_eval.evaluate()
coco_eval.summarize() coco_eval.accumulate()
return coco_eval.stats coco_eval.summarize()
return coco_eval.stats
except:
return [0.]
def _get_det_res(bboxes, bbox_nums, image_id, label_to_cat_id_map): def _get_det_res(bboxes, bbox_nums, image_id, label_to_cat_id_map):
......
...@@ -21,6 +21,7 @@ import matplotlib.pyplot as plt ...@@ -21,6 +21,7 @@ import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages from matplotlib.backends.backend_pdf import PdfPages
import numpy as np import numpy as np
import random import random
import tempfile
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
...@@ -46,8 +47,8 @@ class AnalysisQuant(object): ...@@ -46,8 +47,8 @@ class AnalysisQuant(object):
eval_function=None, eval_function=None,
data_loader=None, data_loader=None,
save_dir='analysis_results', save_dir='analysis_results',
checkpoint_name='analysis_checkpoint.pkl',
num_histogram_plots=10, num_histogram_plots=10,
resume=False,
ptq_config=None): ptq_config=None):
""" """
AnalysisQuant analyzes the sensitivity of each op in the model. AnalysisQuant analyzes the sensitivity of each op in the model.
...@@ -61,7 +62,7 @@ class AnalysisQuant(object): ...@@ -61,7 +62,7 @@ class AnalysisQuant(object):
Generator or Dataloader provides calibrate data, and it could Generator or Dataloader provides calibrate data, and it could
return a batch every time return a batch every time
save_dir(str, optional): the output dir that stores the analyzed information save_dir(str, optional): the output dir that stores the analyzed information
checkpoint_name(str, optional): the name of the checkpoint file that saves the analyzed information, so that progress is not lost if the analysis is interrupted resume(bool, optional): if the analysis was interrupted, resume it by loading the already analyzed information.
ptq_config(dict, optional): the args that can initialize PostTrainingQuantization ptq_config(dict, optional): the args that can initialize PostTrainingQuantization
""" """
...@@ -76,15 +77,24 @@ class AnalysisQuant(object): ...@@ -76,15 +77,24 @@ class AnalysisQuant(object):
self.save_dir = save_dir self.save_dir = save_dir
self.eval_function = eval_function self.eval_function = eval_function
self.quant_layer_names = [] self.quant_layer_names = []
self.checkpoint_name = os.path.join(save_dir, checkpoint_name) self.checkpoint_name = os.path.join(save_dir, 'analysis_checkpoint.pkl')
self.quant_layer_metrics = {} self.quant_layer_metrics = {}
self.num_histogram_plots = num_histogram_plots self.num_histogram_plots = num_histogram_plots
self.ptq_config = ptq_config self.ptq_config = ptq_config
self.batch_nums = ptq_config[ self.batch_nums = ptq_config[
'batch_nums'] if 'batch_nums' in ptq_config else 10 'batch_nums'] if 'batch_nums' in ptq_config else 10
self.is_full_quantize = ptq_config[
'is_full_quantize'] if 'is_full_quantize' in ptq_config else False
self.onnx_format = ptq_config[
'onnx_format'] if 'onnx_format' in ptq_config else False
if not os.path.exists(self.save_dir): if not os.path.exists(self.save_dir):
os.mkdir(self.save_dir) os.mkdir(self.save_dir)
if self.onnx_format:
self.temp_root_path = tempfile.TemporaryDirectory(dir=self.save_dir)
self.temp_save_path = os.path.join(self.temp_root_path.name, "ptq")
if not os.path.exists(self.temp_save_path):
os.makedirs(self.temp_save_path)
devices = paddle.device.get_device().split(':')[0] devices = paddle.device.get_device().split(':')[0]
self.places = paddle.device._convert_to_place(devices) self.places = paddle.device._convert_to_place(devices)
...@@ -117,8 +127,19 @@ class AnalysisQuant(object): ...@@ -117,8 +127,19 @@ class AnalysisQuant(object):
params_filename=self.params_filename, params_filename=self.params_filename,
skip_tensor_list=None, skip_tensor_list=None,
algo='avg', #fastest algo='avg', #fastest
onnx_format=self.onnx_format,
**self.ptq_config) **self.ptq_config)
program = post_training_quantization.quantize() program = post_training_quantization.quantize()
if self.onnx_format:
post_training_quantization.save_quantized_model(
self.temp_save_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
program, _, _ = load_inference_model(
self.temp_save_path,
executor,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
self.quant_metric = self.eval_function(executor, program, self.quant_metric = self.eval_function(executor, program,
self.feed_list, self.fetch_list) self.feed_list, self.fetch_list)
_logger.info('After quantized, the accuracy of the model is: {}'.format( _logger.info('After quantized, the accuracy of the model is: {}'.format(
...@@ -127,11 +148,14 @@ class AnalysisQuant(object): ...@@ -127,11 +148,14 @@ class AnalysisQuant(object):
# get quantized weight and act var name # get quantized weight and act var name
self.quantized_weight_var_name = post_training_quantization._quantized_weight_var_name self.quantized_weight_var_name = post_training_quantization._quantized_weight_var_name
self.quantized_act_var_name = post_training_quantization._quantized_act_var_name self.quantized_act_var_name = post_training_quantization._quantized_act_var_name
self.support_quant_val_name_list = self.quantized_weight_var_name if not self.is_full_quantize else list(
self.quantized_act_var_name)
executor.close() executor.close()
# load tobe_analyized_layer from checkpoint # load tobe_analyized_layer from checkpoint
self.load_checkpoint() if resume:
self.tobe_analyized_layer = self.quantized_weight_var_name - set( self.load_checkpoint()
self.tobe_analyized_layer = set(self.support_quant_val_name_list) - set(
list(self.quant_layer_metrics.keys())) list(self.quant_layer_metrics.keys()))
self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer)) self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer))
...@@ -194,8 +218,6 @@ class AnalysisQuant(object): ...@@ -194,8 +218,6 @@ class AnalysisQuant(object):
scope = global_scope() scope = global_scope()
graph = IrGraph(core.Graph(program.desc), for_test=False)
persistable_var_names = [] persistable_var_names = []
for var in program.list_vars(): for var in program.list_vars():
if var.persistable: if var.persistable:
...@@ -249,7 +271,7 @@ class AnalysisQuant(object): ...@@ -249,7 +271,7 @@ class AnalysisQuant(object):
for i, layer_name in enumerate(self.tobe_analyized_layer): for i, layer_name in enumerate(self.tobe_analyized_layer):
_logger.info('Checking {}/{} quant model: quant layer {}'.format( _logger.info('Checking {}/{} quant model: quant layer {}'.format(
i + 1, len(self.tobe_analyized_layer), layer_name)) i + 1, len(self.tobe_analyized_layer), layer_name))
skip_list = copy.copy(list(self.quantized_weight_var_name)) skip_list = copy.copy(list(self.support_quant_val_name_list))
skip_list.remove(layer_name) skip_list.remove(layer_name)
executor = paddle.static.Executor(self.places) executor = paddle.static.Executor(self.places)
...@@ -260,20 +282,33 @@ class AnalysisQuant(object): ...@@ -260,20 +282,33 @@ class AnalysisQuant(object):
model_filename=self.model_filename, model_filename=self.model_filename,
params_filename=self.params_filename, params_filename=self.params_filename,
skip_tensor_list=skip_list, skip_tensor_list=skip_list,
onnx_format=self.onnx_format,
algo='avg', #fastest algo='avg', #fastest
**self.ptq_config) **self.ptq_config)
program = post_training_quantization.quantize() program = post_training_quantization.quantize()
_logger.info('Evaluating...') _logger.info('Evaluating...')
if self.onnx_format:
post_training_quantization.save_quantized_model(
self.temp_save_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
program, _, _ = load_inference_model(
self.temp_save_path,
executor,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
quant_metric = self.eval_function(executor, program, self.feed_list, quant_metric = self.eval_function(executor, program, self.feed_list,
self.fetch_list) self.fetch_list)
executor.close() executor.close()
_logger.info( _logger.info(
"Quantized layer name: {}, eval metric: {}, the loss caused by this layer: {}". "Quantized layer name: {}, eval metric: {}, the loss caused by this layer: {}".
format(layer_name, quant_metric, self.base_metric - format(layer_name,
quant_metric)) round(quant_metric, 4),
round(self.base_metric - quant_metric, 4)))
self.quant_layer_metrics[layer_name] = quant_metric self.quant_layer_metrics[layer_name] = quant_metric
self.save_checkpoint() self.save_checkpoint()
if self.onnx_format:
self.temp_root_path.cleanup()
def get_weight_act_map(self, program, weight_names, persistable_var_names): def get_weight_act_map(self, program, weight_names, persistable_var_names):
act_names = {} act_names = {}
...@@ -408,6 +443,7 @@ class AnalysisQuant(object): ...@@ -408,6 +443,7 @@ class AnalysisQuant(object):
model_dir=self.model_dir, model_dir=self.model_dir,
model_filename=self.model_filename, model_filename=self.model_filename,
params_filename=self.params_filename, params_filename=self.params_filename,
onnx_format=self.onnx_format,
skip_tensor_list=skip_list, skip_tensor_list=skip_list,
**self.ptq_config) **self.ptq_config)
program = post_training_quantization.quantize() program = post_training_quantization.quantize()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册