From 7781902e7901f086d5173bb6bb475b49bf64da38 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Tue, 19 May 2020 22:33:15 +0800 Subject: [PATCH] update post_quantization.py --- contrib/HumanSeg/utils/post_quantization.py | 197 +++++++++++++------- 1 file changed, 126 insertions(+), 71 deletions(-) diff --git a/contrib/HumanSeg/utils/post_quantization.py b/contrib/HumanSeg/utils/post_quantization.py index 00d61c80..f6d040e9 100644 --- a/contrib/HumanSeg/utils/post_quantization.py +++ b/contrib/HumanSeg/utils/post_quantization.py @@ -14,12 +14,14 @@ from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass -from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name +from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization +import utils.logging as logging import paddle.fluid as fluid import os - -import utils.logging as logging +import re +import numpy as np +import time class HumanSegPostTrainingQuantization(PostTrainingQuantization): @@ -42,7 +44,6 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization): fp32 model. It uses calibrate data to calculate the scale factor of quantized variables, and inserts fake quant/dequant op to obtain the quantized model. - Args: executor(fluid.Executor): The executor to load, run and save the quantized model. @@ -76,6 +77,21 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization): Returns: None ''' + self._support_activation_quantize_type = [ + 'range_abs_max', 'moving_average_abs_max', 'abs_max' + ] + self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max'] + self._support_algo_type = ['KL', 'abs_max', 'min_max'] + self._support_quantize_op_type = \ + list(set(QuantizationTransformPass._supported_quantizable_op_type + + AddQuantDequantPass._supported_quantizable_op_type)) + + # Check inputs + assert executor is not None, "The executor cannot be None." + assert batch_size > 0, "The batch_size should be greater than 0." + assert algo in self._support_algo_type, \ + "The algo should be KL, abs_max or min_max." + self._executor = executor self._dataset = dataset self._batch_size = batch_size @@ -84,18 +100,19 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization): self._algo = algo self._is_use_cache_file = is_use_cache_file self._cache_dir = cache_dir + self._activation_bits = 8 + self._weight_bits = 8 + self._activation_quantize_type = 'range_abs_max' + self._weight_quantize_type = 'channel_wise_abs_max' if self._is_use_cache_file and not os.path.exists(self._cache_dir): os.mkdir(self._cache_dir) - supported_quantizable_op_type = \ - QuantizationTransformPass._supported_quantizable_op_type + \ - AddQuantDequantPass._supported_quantizable_op_type if is_full_quantize: - self._quantizable_op_type = supported_quantizable_op_type + self._quantizable_op_type = self._support_quantize_op_type else: self._quantizable_op_type = quantizable_op_type for op_type in self._quantizable_op_type: - assert op_type in supported_quantizable_op_type + \ + assert op_type in self._support_quantize_op_type + \ AddQuantDequantPass._activation_type, \ op_type + " is not supported for quantization." @@ -105,53 +122,72 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization): self._fetch_list = list(outputs.values()) self._data_loader = None - self._op_real_in_out_name = _op_real_in_out_name + self._out_scale_op_list = _out_scale_op_list self._bit_length = 8 self._quantized_weight_var_name = set() self._quantized_act_var_name = set() self._sampling_data = {} - self._quantized_var_scale_factor = {} + self._quantized_var_kl_threshold = {} + self._quantized_var_min = {} + self._quantized_var_max = {} + self._quantized_var_abs_max = {} def quantize(self): ''' Quantize the fp32 model. Use calibrate data to calculate the scale factor of quantized variables, and inserts fake quant/dequant op to obtain the quantized model. - Args: None Returns: the program of quantized model. ''' - self._preprocess() - + self._load_model_data() + self._collect_target_varnames() + self._set_activation_persistable() + batch_ct = 0 + for data in self._data_loader(): + batch_ct += 1 + if self._batch_nums and batch_ct >= self._batch_nums: + break batch_id = 0 + logging.info("Start to run batch!") for data in self._data_loader(): + start = time.time() self._executor.run( program=self._program, feed=data, fetch_list=self._fetch_list, return_numpy=False) - self._sample_data(batch_id) - - if batch_id % 5 == 0: - logging.info("run batch: {}".format(batch_id)) + if self._algo == "KL": + self._sample_data(batch_id) + else: + self._sample_threshold() + end = time.time() + logging.debug( + '[Run batch data] Batch={}/{}, time_each_batch={} s.'.format( + str(batch_id + 1), str(batch_ct), str(end - start))) batch_id += 1 if self._batch_nums and batch_id >= self._batch_nums: break - logging.info("all run batch: ".format(batch_id)) - logging.info("calculate scale factor ...") - self._calculate_scale_factor() - logging.info("update the program ...") - self._update_program() - - self._save_output_scale() + logging.info("All run batch: ".format(batch_id)) + self._reset_activation_persistable() + logging.info("Calculate scale factor ...") + if self._algo == "KL": + self._calculate_kl_threshold() + logging.info("Update the program ...") + if self._algo in ["KL", "abs_max"]: + self._update_program() + else: + self._save_input_threhold() + logging.info("Save ...") + self._save_output_threshold() + logging.info("Finish quant!") return self._program def save_quantized_model(self, save_model_path): ''' Save the quantized model to the disk. - Args: save_model_path(str): The path to save the quantized model Returns: @@ -166,59 +202,78 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization): params_filename='__params__', main_program=self._program) - def _preprocess(self): + def _load_model_data(self): ''' - Load model and set data loader, collect the variable names for sampling, - and set activation variables to be persistable. + Set data loader. ''' feed_vars = [fluid.framework._get_var(var.name, self._program) \ for var in self._feed_list] - self._data_loader = fluid.io.DataLoader.from_generator( feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True) self._data_loader.set_sample_list_generator( self._dataset.generator(self._batch_size, drop_last=True), places=self._place) - # collect the variable names for sampling - persistable_var_names = [] - for var in self._program.list_vars(): - if var.persistable: - persistable_var_names.append(var.name) - - for op in self._program.global_block().ops: - op_type = op.type - if op_type in self._quantizable_op_type: - if op_type in ("conv2d", "depthwise_conv2d"): - self._quantized_act_var_name.add(op.input("Input")[0]) - self._quantized_weight_var_name.add(op.input("Filter")[0]) - self._quantized_act_var_name.add(op.output("Output")[0]) - elif op_type == "mul": - if self._is_input_all_not_persistable( - op, persistable_var_names): - op._set_attr("skip_quant", True) - logging.warning( - "Skip quant a mul op for two input variables are not persistable" - ) - else: - self._quantized_act_var_name.add(op.input("X")[0]) - self._quantized_weight_var_name.add(op.input("Y")[0]) - self._quantized_act_var_name.add(op.output("Out")[0]) - else: - # process other quantizable op type, the input must all not persistable - if self._is_input_all_not_persistable( - op, persistable_var_names): - input_output_name_list = self._op_real_in_out_name[ - op_type] - for input_name in input_output_name_list[0]: - for var_name in op.input(input_name): - self._quantized_act_var_name.add(var_name) - for output_name in input_output_name_list[1]: - for var_name in op.output(output_name): - self._quantized_act_var_name.add(var_name) + def _calculate_kl_threshold(self): + ''' + Calculate the KL threshold of quantized variables. + ''' + assert self._algo == "KL", "The algo should be KL to calculate kl threshold." + ct = 1 + # Abs_max threshold for weights + for var_name in self._quantized_weight_var_name: + start = time.time() + weight_data = self._sampling_data[var_name] + weight_threshold = None + if self._weight_quantize_type == "abs_max": + weight_threshold = np.max(np.abs(weight_data)) + elif self._weight_quantize_type == "channel_wise_abs_max": + weight_threshold = [] + for i in range(weight_data.shape[0]): + abs_max_value = np.max(np.abs(weight_data[i])) + weight_threshold.append(abs_max_value) + self._quantized_var_kl_threshold[var_name] = weight_threshold + end = time.time() + logging.debug( + '[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'. + format( + str(ct), str(len(self._quantized_weight_var_name)), + str(end - start))) + ct += 1 - # set activation variables to be persistable, so can obtain - # the tensor data in sample_data - for var in self._program.list_vars(): - if var.name in self._quantized_act_var_name: - var.persistable = True + ct = 1 + # KL threshold for activations + if self._is_use_cache_file: + for var_name in self._quantized_act_var_name: + start = time.time() + sampling_data = [] + filenames = [f for f in os.listdir(self._cache_dir) \ + if re.match(var_name + '_[0-9]+.npy', f)] + for filename in filenames: + file_path = os.path.join(self._cache_dir, filename) + sampling_data.append(np.load(file_path)) + os.remove(file_path) + sampling_data = np.concatenate(sampling_data) + self._quantized_var_kl_threshold[var_name] = \ + self._get_kl_scaling_factor(np.abs(sampling_data)) + end = time.time() + logging.debug( + '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.' + .format( + str(ct), str(len(self._quantized_act_var_name)), + str(end - start))) + ct += 1 + else: + for var_name in self._quantized_act_var_name: + start = time.time() + self._sampling_data[var_name] = np.concatenate( + self._sampling_data[var_name]) + self._quantized_var_kl_threshold[var_name] = \ + self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name])) + end = time.time() + logging.debug( + '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.' + .format( + str(ct), str(len(self._quantized_act_var_name)), + str(end - start))) + ct += 1 -- GitLab