From b83d27ac7552cb9e751914968555404c8c47877a Mon Sep 17 00:00:00 2001 From: handiz <35895648+ZhangHandi@users.noreply.github.com> Date: Mon, 5 Sep 2022 17:10:05 +0800 Subject: [PATCH] fix bug in PostTrainingProgram for certain cases (#45616) * fix bug in PostTrainingProgram for certain cases --- .../post_training_quantization.py | 44 ++++++++++--------- .../slim/quantization/quantization_pass.py | 7 +-- ..._training_quantization_program_resnet50.py | 6 ++- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index eace86e71ae..4e37ba05b68 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -26,6 +26,7 @@ except: from inspect import isgeneratorfunction from .... import io from .... import core +from .... import reader from .... import framework from .... import unique_name from ....executor import global_scope, Executor @@ -141,7 +142,6 @@ class PostTrainingQuantization(object): is_use_cache_file=False, skip_tensor_list=None, same_scale_tensor_list=None, - scale_trainable=False, cache_dir=None, scale_dict=None, return_graph=False): @@ -231,7 +231,6 @@ class PostTrainingQuantization(object): `conv2d/depthwise_conv2d + bn`, the weights scale for all channel will be different. In address this problem, fuse the pattern before quantization. Default False. - scale_trainable(bool, optional): whether scale can be train. is_use_cache_file(bool, optional): This param is deprecated. cache_dir(str, optional): This param is deprecated. Returns: @@ -296,7 +295,7 @@ class PostTrainingQuantization(object): batch_generator, data_loader]), "The sample_generator, batch_generator " \ "and data_loader cannot be None in the same time." if data_loader is not None: - assert isinstance(data_loader, (io.DataLoader, type(isgeneratorfunction))), \ + assert isinstance(data_loader, (io.DataLoader, type(isgeneratorfunction), reader.GeneratorLoader)), \ "data_loader only accepts `paddle.io.DataLoader` or Generator instance." assert batch_size > 0, "The batch_size should be greater than 0." assert algo in self._support_algo_type, \ @@ -366,9 +365,11 @@ class PostTrainingQuantization(object): self._quantized_threshold = {} self._same_scale_tensor_list = same_scale_tensor_list self._freeze_model = freeze_model - self._scale_trainable = scale_trainable self._scale_dict = scale_dict self._return_graph = return_graph + self.FLAG = False + if self._program is not None: + self.FLAG = True def quantize(self): ''' @@ -440,7 +441,8 @@ class PostTrainingQuantization(object): self._update_program() # save out_threshold for quantized ops. 
- self._save_output_threshold() + if not self.FLAG: + self._save_output_threshold() if any(op_type in self._quantizable_op_type for op_type in self._dynamic_quantize_op_type): @@ -1001,8 +1003,7 @@ class PostTrainingQuantization(object): activation_bits=self._activation_bits, activation_quantize_type=self._activation_quantize_type, weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types, - is_test=not self._scale_trainable) + quantizable_op_type=major_quantizable_op_types) else: transform_pass = QuantizationTransformPassV2( scope=self._scope, @@ -1011,8 +1012,7 @@ class PostTrainingQuantization(object): activation_bits=self._activation_bits, activation_quantize_type=self._activation_quantize_type, weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types, - is_test=not self._scale_trainable) + quantizable_op_type=major_quantizable_op_types) for sub_graph in graph.all_sub_graphs(): # Insert fake_quant/fake_dequantize op must in test graph, so @@ -1029,15 +1029,13 @@ class PostTrainingQuantization(object): add_quant_dequant_pass = AddQuantDequantPass( scope=self._scope, place=self._place, - quantizable_op_type=minor_quantizable_op_types, - is_test=not self._scale_trainable) + quantizable_op_type=minor_quantizable_op_types) else: add_quant_dequant_pass = AddQuantDequantPassV2( scope=self._scope, place=self._place, quantizable_op_type=minor_quantizable_op_types, - is_full_quantized=self._is_full_quantize, - is_test=not self._scale_trainable) + is_full_quantized=self._is_full_quantize) for sub_graph in graph.all_sub_graphs(): sub_graph._for_test = True @@ -1055,11 +1053,11 @@ class PostTrainingQuantization(object): max_scale = None tmp_tensor_list = [] for tensor_name in tensor_list: - if tensor_name not in scale_dict.keys(): - continue if '#' in tensor_name: real_tensor_name, opera, scalar = tensor_name.split( '#') + if real_tensor_name not in scale_dict.keys(): + continue if opera == '*': scale_dict[real_tensor_name] = float( scale_dict[real_tensor_name]) * float( @@ -1072,16 +1070,18 @@ class PostTrainingQuantization(object): real_tensor_name] if max_scale is None else max( max_scale, scale_dict[real_tensor_name]) else: + if tensor_name not in scale_dict.keys(): + continue max_scale = scale_dict[ tensor_name] if max_scale is None else max( max_scale, scale_dict[tensor_name]) for tensor_name in tensor_list: - if tensor_name not in scale_dict.keys(): - continue if '#' in tensor_name: real_tensor_name, opera, scalar = tensor_name.split( '#') + if real_tensor_name not in scale_dict.keys(): + continue if opera == '*': scale_dict[ real_tensor_name] = max_scale / float( @@ -1091,6 +1091,8 @@ class PostTrainingQuantization(object): real_tensor_name] = max_scale * float( scalar) else: + if tensor_name not in scale_dict.keys(): + continue scale_dict[tensor_name] = max_scale self._scale_dict = scale_dict @@ -1265,7 +1267,6 @@ class PostTrainingQuantizationProgram(PostTrainingQuantization): is_use_cache_file=False, skip_tensor_list=None, same_scale_tensor_list=None, - scale_trainable=False, cache_dir=None, scale_dict=None, return_graph=True): @@ -1276,9 +1277,12 @@ class PostTrainingQuantizationProgram(PostTrainingQuantization): activation_bits, weight_bits, activation_quantize_type, weight_quantize_type, onnx_format, freeze_model, optimize_model, is_use_cache_file, skip_tensor_list, - same_scale_tensor_list, scale_trainable, cache_dir, - scale_dict, return_graph) + same_scale_tensor_list, cache_dir, scale_dict, + 
return_graph) + self.FLAG = False self._program = program + if self._program is not None: + self.FLAG = True assert feed_list is not None, \ "Feed list should not be None." assert fetch_list is not None, \ diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 6fdd84e3491..f8d950aa5e0 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1470,9 +1470,10 @@ class OutScaleForTrainingPass(object): data_type = 'float64' if in_node.dtype() \ == core.VarDesc.VarType.FP64 else 'float32' try: - scale_node = graph._find_node_by_name( + graph._find_node_by_name( graph.all_var_nodes(), self._scale_name(in_node.name())) + continue except: scale_node = graph.create_persistable_node( name=self._scale_name(in_node.name()), @@ -1487,8 +1488,8 @@ class OutScaleForTrainingPass(object): scale_value = np.ones([1], dtype=data_type) else: scale_value = np.ones([1], dtype=data_type) - _init_var_node(scale_node, scale_value, self._scope, - self._place) + _init_var_node(scale_node, scale_value, self._scope, + self._place) ins = {'X': in_node} outs = {'OutScale': scale_node} diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_program_resnet50.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_program_resnet50.py index b6af3cc449a..5854d40529d 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_program_resnet50.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_program_resnet50.py @@ -186,7 +186,11 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization): ], ['batch_norm_27.tmp_2', 'batch_norm_26.tmp_2'], [ 'test_scale_name_not_in_scale_dict1', - 'test_scale_name_not_in_scale_dict1' + 'test_scale_name_not_in_scale_dict2' + ], + [ + 'test_scale_name_not_in_scale_dict1#/#1', + 'test_scale_name_not_in_scale_dict2#/#1' ]] ptq = PostTrainingQuantizationProgram( executor=exe, -- GitLab
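
What the new FLAG bookkeeping amounts to: the base class raises FLAG when it
already holds a program, the PostTrainingQuantizationProgram subclass
re-derives it after assigning the caller's program (super().__init__ runs
before self._program is set), and quantize() consults it so that
_save_output_threshold() is skipped for a user-supplied program. A
self-contained sketch of just that control flow follows; the *Sketch class
names are illustrative stand-ins and the method bodies are stubs condensed
from the patch, not the real API.

    class PostTrainingQuantizationSketch:

        def __init__(self, program=None):
            self._program = program
            # FLAG marks "caller supplied a ready-made program"; in that
            # case quantize() skips saving out_thresholds.
            self.FLAG = self._program is not None

        def quantize(self):
            if not self.FLAG:
                self._save_output_threshold()

        def _save_output_threshold(self):
            print('saving out_threshold attributes')


    class PostTrainingQuantizationProgramSketch(PostTrainingQuantizationSketch):

        def __init__(self, program):
            super().__init__()          # base sees no program yet
            self._program = program     # assigned after super(), as in the patch
            self.FLAG = self._program is not None


    PostTrainingQuantizationSketch().quantize()                 # prints the message
    PostTrainingQuantizationProgramSketch(object()).quantize()  # skipped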
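
The same_scale_tensor_list hunks fix an ordering bug. Entries in a group may
carry a 'name#op#scalar' annotation ('*' or '/'), but the old guard
`if tensor_name not in scale_dict` tested the full annotated string, which is
not a scale_dict key, so annotated entries were silently skipped; the fix
splits the annotation first and checks the real tensor name (the test file
adds a '...#/#1' group to cover exactly this path). Below is a minimal
standalone sketch of the two-pass unification over a plain dict; the function
name and sample tensor names are illustrative:

    def unify_scales(group, scale_dict):
        """Force every tensor in `group` onto one shared (maximal) scale."""
        max_scale = None
        # Pass 1: apply each annotation, then track the group's max scale.
        for name in group:
            real_name, op, scalar = (name.split('#') if '#' in name
                                     else (name, None, None))
            if real_name not in scale_dict:   # check the REAL name, post-split
                continue
            if op == '*':
                scale_dict[real_name] = float(scale_dict[real_name]) * float(scalar)
            elif op == '/':
                scale_dict[real_name] = float(scale_dict[real_name]) / float(scalar)
            s = scale_dict[real_name]
            max_scale = s if max_scale is None else max(max_scale, s)
        # Pass 2: write the shared scale back, inverting each annotation.
        for name in group:
            real_name, op, scalar = (name.split('#') if '#' in name
                                     else (name, None, None))
            if real_name not in scale_dict:
                continue
            if op == '*':
                scale_dict[real_name] = max_scale / float(scalar)
            elif op == '/':
                scale_dict[real_name] = max_scale * float(scalar)
            else:
                scale_dict[real_name] = max_scale
        return scale_dict

    scales = {'a': 0.5, 'b': 2.0}
    # 'b#/#2': unify using b/2, then store the shared scale back times 2.
    print(unify_scales(['a', 'b#/#2', 'missing'], scales))
    # {'a': 1.0, 'b': 2.0}; before the fix the annotated 'b#/#2' entry
    # was skipped, and a missing tensor name no longer raises.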
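
The quantization_pass.py hunk makes OutScaleForTrainingPass idempotent:
previously a scale node found by graph._find_node_by_name still fell through
to _init_var_node, resetting an existing scale variable to ones; now a found
node triggers `continue`, and initialization moves inside the create branch so
it runs exactly once. A tiny runnable sketch of that find-or-create-or-skip
pattern, with a plain dict standing in for the IrGraph (names illustrative):

    import numpy as np

    def ensure_out_scale(graph_vars, var_name):
        """Create and initialize `<var>@scale` once; skip it on reruns."""
        scale_name = var_name + '@scale'
        if scale_name in graph_vars:
            # Patched behavior: the node exists, so neither recreate nor
            # re-initialize it -- just move on to the next input.
            return False
        graph_vars[scale_name] = np.ones([1], dtype='float32')  # create + init
        return True

    g = {}
    assert ensure_out_scale(g, 'conv1.tmp_0') is True    # first pass: created
    assert ensure_out_scale(g, 'conv1.tmp_0') is False   # rerun: skipped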