From fad86fc1fbbdc73bb6718a83b90d5a8e6181161a Mon Sep 17 00:00:00 2001
From: Bai Yifan
Date: Mon, 3 Aug 2020 17:49:14 +0800
Subject: [PATCH] Fix preprocess_func quant_aware convert issue (#401)

---
 paddleslim/quant/quanter.py | 93 ++++++++++++++++++++++++-------------
 1 file changed, 61 insertions(+), 32 deletions(-)

diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py
index 462f1c75..95ee7a04 100755
--- a/paddleslim/quant/quanter.py
+++ b/paddleslim/quant/quanter.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import copy
+import json
 import logging
 
 import paddle
@@ -33,8 +35,7 @@ from ..common import get_logger
 _logger = get_logger(__name__, level=logging.INFO)
 
 WEIGHT_QUANTIZATION_TYPES = [
-    'abs_max', 'channel_wise_abs_max', 'range_abs_max',
-    'moving_average_abs_max'
+    'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
 ]
 WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max']
 
@@ -55,6 +56,8 @@ TENSORRT_OP_TYPES = [
     'leaky_relu'
 ]
 
+VARS_MAPPING_TABLE = './mapping_table_for_saving_inference_model'
+
 _quant_config_default = {
     # weight quantize type, default is 'channel_wise_abs_max'
     'weight_quantize_type': 'channel_wise_abs_max',
@@ -81,6 +84,18 @@ _quant_config_default = {
 }
 
 
+def load_dict():
+    with open(VARS_MAPPING_TABLE, 'r') as file:
+        data = file.read()
+        data = json.loads(data)
+        return data
+
+
+def save_dict(table):
+    with open(VARS_MAPPING_TABLE, 'w') as file:
+        file.write(json.dumps(table))
+
+
 def _parse_configs(user_config):
     """
     check if user's configs are valid.
@@ -267,6 +282,15 @@ def quant_aware(program,
             scope=scope, place=place, moving_rate=config['moving_rate'])
         out_scale_training_pass.apply(main_graph)
 
+    if (weight_preprocess_func is not None or
+            act_preprocess_func is not None) and not for_test:
+        _logger.info(
+            "When a preprocess_func is used in quant_aware, Need to save a mapping table to match variable names in the convert phase."
+        )
+        _logger.info("The mapping table is saved as '{}'.".format(
+            VARS_MAPPING_TABLE))
+        save_dict(main_graph.out_node_mapping_table)
+
     if for_test:
         quant_program = main_graph.to_program()
     else:
@@ -274,27 +298,28 @@ def quant_aware(program,
     return quant_program
 
 
-def quant_post_static(executor,
-                      model_dir,
-                      quantize_model_path,
-                      batch_generator=None,
-                      sample_generator=None,
-                      model_filename=None,
-                      params_filename=None,
-                      save_model_filename='__model__',
-                      save_params_filename='__params__',
-                      batch_size=16,
-                      batch_nums=None,
-                      scope=None,
-                      algo='KL',
-                      quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-                      is_full_quantize=False,
-                      weight_bits=8,
-                      activation_bits=8,
-                      activation_quantize_type='range_abs_max',
-                      weight_quantize_type='channel_wise_abs_max',
-                      is_use_cache_file=False,
-                      cache_dir="./temp_post_training"):
+def quant_post_static(
+        executor,
+        model_dir,
+        quantize_model_path,
+        batch_generator=None,
+        sample_generator=None,
+        model_filename=None,
+        params_filename=None,
+        save_model_filename='__model__',
+        save_params_filename='__params__',
+        batch_size=16,
+        batch_nums=None,
+        scope=None,
+        algo='KL',
+        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+        is_full_quantize=False,
+        weight_bits=8,
+        activation_bits=8,
+        activation_quantize_type='range_abs_max',
+        weight_quantize_type='channel_wise_abs_max',
+        is_use_cache_file=False,
+        cache_dir="./temp_post_training"):
     """
     The function utilizes static post training quantization method to
     quantize the fp32 model. It uses calibrate data to calculate the
@@ -381,6 +406,7 @@ def quant_post_static(executor,
         model_filename=save_model_filename,
         params_filename=save_params_filename)
 
+
 # We have changed the quant_post to quant_post_static.
 # For compatibility, we keep quant_post api for now, and it will be
 # deprecated in the future.
@@ -438,6 +464,9 @@ def convert(program, place, config=None, scope=None, save_int8=False):
         activation_bits=config['activation_bits'],
         weight_quantize_type=config['weight_quantize_type'])
 
+    if os.path.exists(VARS_MAPPING_TABLE):
+        test_graph.out_node_mapping_table = load_dict()
+
     freeze_pass.apply(test_graph)
     freezed_program = test_graph.to_program()
 
@@ -451,14 +480,14 @@ def convert(program, place, config=None, scope=None, save_int8=False):
 
 
 def quant_post_dynamic(model_dir,
-                           save_model_dir,
-                           model_filename=None,
-                           params_filename=None,
-                           save_model_filename=None,
-                           save_params_filename=None,
-                           quantizable_op_type=["conv2d", "mul"],
-                           weight_bits=8,
-                           generate_test_model=False):
+                       save_model_dir,
+                       model_filename=None,
+                       params_filename=None,
+                       save_model_filename=None,
+                       save_params_filename=None,
+                       quantizable_op_type=["conv2d", "mul"],
+                       weight_bits=8,
+                       generate_test_model=False):
     '''
     The function utilizes static post training quantization method to
     quantize the fp32 model. In details, it quantizes the weight of some
@@ -517,4 +546,4 @@ def quant_post_dynamic(model_dir,
 # We have changed the quant_post_only_weight to quant_post_dynamic.
 # For compatibility, we keep quant_post_only_weight api for now,
 # and it will be deprecated in the future.
-quant_post_only_weight = quant_post_dynamic
\ No newline at end of file
+quant_post_only_weight = quant_post_dynamic
-- 
GitLab
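
Note on the change, with a small illustrative sketch. When quant_aware() is run for training with a weight_preprocess_func or act_preprocess_func, the quantization pass changes output variable names, and convert() later needs the original-to-new mapping to match variable names when freezing the inference graph. The patch therefore persists main_graph.out_node_mapping_table as JSON at ./mapping_table_for_saving_inference_model via save_dict(), and convert() reloads it with load_dict() when that file exists. The snippet below only mirrors that save/load round trip with plain json; the mapping entries are hypothetical examples, not output of the real pass, and this is not the PaddleSlim training pipeline itself.

import json
import os

# Path used by the patch (VARS_MAPPING_TABLE in quanter.py).
VARS_MAPPING_TABLE = './mapping_table_for_saving_inference_model'

# Hypothetical example of an out-node mapping table: original output
# variable name -> name produced by the preprocess-aware quantization pass.
# The real dict comes from main_graph.out_node_mapping_table in quant_aware().
example_table = {
    'fc_0.tmp_1': 'fc_0.tmp_1.quantized.dequantized',
    'conv2d_0.tmp_0': 'conv2d_0.tmp_0.quantized.dequantized',
}

# What save_dict() does inside quant_aware() when a preprocess_func is set
# and for_test is False: dump the table as JSON to VARS_MAPPING_TABLE.
with open(VARS_MAPPING_TABLE, 'w') as f:
    f.write(json.dumps(example_table))

# What convert() now does before the freeze pass: if the table file exists,
# read it back (load_dict()) and attach it to test_graph.out_node_mapping_table.
if os.path.exists(VARS_MAPPING_TABLE):
    with open(VARS_MAPPING_TABLE, 'r') as f:
        restored = json.loads(f.read())
    assert restored == example_table

os.remove(VARS_MAPPING_TABLE)  # clean up the demo file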