diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index b61f4acaee57543d51cd7aadb8163d164999c274..5f9b94ea82fcacc950f806803e0d04ec41c18ca2 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -65,6 +65,56 @@ def _all_persistable_var_names(program):
     return persistable_var_names
 
 
+def _remove_unused_var_nodes(graph):
+    all_used_vars = set()
+    ops = graph.all_op_nodes()
+    for op_node in ops:
+        for input_node in op_node.inputs:
+            all_used_vars.add(input_node)
+        for output_node in op_node.outputs:
+            all_used_vars.add(output_node)
+
+    all_used_vars = {n.node for n in all_used_vars}
+    all_unused_vars = {
+        n
+        for n in filter(lambda node: node.node not in all_used_vars,
+                        graph.all_var_nodes())
+    }
+    graph.safe_remove_nodes(all_unused_vars)
+    return graph
+
+
+def _remove_ctrl_vars(graph):
+    remove_ctr_vars = set()
+    for node in graph.all_var_nodes():
+        if node.is_ctrl_var():
+            remove_ctr_vars.add(node)
+    graph.safe_remove_nodes(remove_ctr_vars)
+    return graph
+
+
+def _apply_pass(scope,
+                graph,
+                pass_name,
+                attrs=None,
+                attr_values=None,
+                debug=False):
+    ir_pass = core.get_pass(pass_name)
+    cpp_graph = graph.graph
+    if not cpp_graph.has('__param_scope__'):
+        cpp_graph.set_not_owned('__param_scope__', scope)
+    if attrs:
+        assert attr_values and len(attrs) == len(
+            attr_values), "Different number of pass attributes and their values."
+        for attr, value in zip(attrs, attr_values):
+            ir_pass.set(attr, value)
+    ir_pass.apply(cpp_graph)
+    if debug:
+        graph.draw('.', 'qat_fp32_{}'.format(pass_name), graph.all_op_nodes())
+    _remove_unused_var_nodes(graph)
+    return graph
+
+
 class PostTrainingQuantization(object):
     """
     Utilizing post training quantization methon to quantize the FP32 model,
@@ -89,6 +139,7 @@ class PostTrainingQuantization(object):
                  weight_bits=8,
                  activation_quantize_type='range_abs_max',
                  weight_quantize_type='channel_wise_abs_max',
+                 optimize_model=False,
                  is_use_cache_file=False,
                  cache_dir="./temp_post_training"):
         '''
@@ -145,6 +196,14 @@ class PostTrainingQuantization(object):
                 the fake ops in saving quantized model, and we save the scale obtained
                 by post training quantization in fake ops. Compared to 'abs_max',
                 the model accuracy is usually higher when it is 'channel_wise_abs_max'.
+            optimize_model(bool, optional): If optimize_model is set to True, some
+                passes are applied to the model before quantization; only the
+                `conv2d/depthwise_conv2d + bn` fuse pass is supported so far. Some
+                targets require the weights to be quantized per tensor, which means
+                the weight scale is the same for all channels. However, fusing
+                `conv2d/depthwise_conv2d + bn` makes the weight scales differ across
+                channels, so the pattern is fused before quantization to address
+                this problem. Default is False.
             is_use_cache_file(bool, optional): If set is_use_cache_file as False,
                 all temp data will be saved in memory. If set is_use_cache_file
                 as True, it will save temp data to disk. When the fp32 model is complex or
@@ -240,6 +299,7 @@ class PostTrainingQuantization(object):
         for op_type in self._quantizable_op_type:
             assert op_type in self._support_quantize_op_type, \
                 op_type + " is not supported for quantization."
+        self._optimize_model = optimize_model
         self._is_use_cache_file = is_use_cache_file
         self._cache_dir = cache_dir
         if self._is_use_cache_file and not os.path.exists(self._cache_dir):
@@ -344,6 +404,10 @@ class PostTrainingQuantization(object):
             executor=self._executor,
             model_filename=self._model_filename,
             params_filename=self._params_filename)
+
+        if self._optimize_model:
+            self._optimize_fp32_model()
+
         feed_vars = [framework._get_var(str(var_name), self._program) \
             for var_name in self._feed_list]
         self._data_loader = io.DataLoader.from_generator(
@@ -358,6 +422,16 @@ class PostTrainingQuantization(object):
         self._data_loader.set_batch_generator(
             self._batch_generator, places=self._place)
 
+    def _optimize_fp32_model(self):
+        '''
+        Fuse the `conv2d/depthwise_conv2d + bn` in FP32 model.
+        '''
+        _logger.info("Optimize FP32 model ...")
+        graph = IrGraph(core.Graph(self._program.desc), for_test=True)
+        graph = _remove_ctrl_vars(graph)
+        graph = _apply_pass(self._scope, graph, 'conv_bn_fuse_pass')
+        self._program = graph.to_program()
+
     def _collect_target_varnames(self):
         '''
         Collect the variable names for sampling, and set activation
diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
index 50085ed4a5b7aff66a9581e0d7f2415c9b46f631..864631ec27829e29aabb1a00a858cd0ce85e8389 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
@@ -239,7 +239,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                  quantizable_op_type,
                                  algo="KL",
                                  is_full_quantize=False,
-                                 is_use_cache_file=False):
+                                 is_use_cache_file=False,
+                                 is_optimize_model=False):
         try:
             os.system("mkdir " + self.int8_model)
         except Exception as e:
@@ -259,12 +260,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
             algo=algo,
             quantizable_op_type=quantizable_op_type,
             is_full_quantize=is_full_quantize,
+            optimize_model=is_optimize_model,
             is_use_cache_file=is_use_cache_file)
         ptq.quantize()
         ptq.save_quantized_model(self.int8_model)
 
     def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
-                 is_full_quantize, is_use_cache_file, diff_threshold):
+                 is_full_quantize, is_use_cache_file, is_optimize_model,
+                 diff_threshold):
         infer_iterations = self.infer_iterations
         batch_size = self.batch_size
         sample_iterations = self.sample_iterations
@@ -278,9 +281,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
 
         print("Start INT8 post training quantization for {0} on {1} images ...".
               format(model, sample_iterations * batch_size))
-        self.generate_quantized_model(model_cache_folder + "/model",
-                                      quantizable_op_type, algo,
-                                      is_full_quantize, is_use_cache_file)
+        self.generate_quantized_model(
+            model_cache_folder + "/model", quantizable_op_type, algo,
+            is_full_quantize, is_use_cache_file, is_optimize_model)
 
         print("Start INT8 inference for {0} on {1} images ...".format(
             model, infer_iterations * batch_size))
@@ -316,9 +319,11 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
         ]
         is_full_quantize = False
         is_use_cache_file = False
+        is_optimize_model = True
         diff_threshold = 0.025
         self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
-                      is_full_quantize, is_use_cache_file, diff_threshold)
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold)
 
 
 class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
@@ -335,10 +340,12 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
         ]
         is_full_quantize = False
         is_use_cache_file = False
+        is_optimize_model = False
         # The accuracy diff of post-traing quantization (abs_max) maybe bigger
         diff_threshold = 0.05
         self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
-                      is_full_quantize, is_use_cache_file, diff_threshold)
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
index 373a65018800a52d8d8de5373ad95dde21001614..a6c19b5e45a41ba8f30648befb44de5ad30d6fe8 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
@@ -28,9 +28,11 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
         quantizable_op_type = ["conv2d", "mul"]
         is_full_quantize = False
         is_use_cache_file = False
+        is_optimize_model = False
         diff_threshold = 0.025
         self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
-                      is_full_quantize, is_use_cache_file, diff_threshold)
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold)
 
 
 if __name__ == '__main__':
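
A minimal usage sketch of the new optimize_model switch (illustration only, not part of the patch). Constructor arguments that this diff does not show (model_dir, sample_generator, batch_size, batch_nums) are assumed from the surrounding PostTrainingQuantization API; the model paths and the input shape in the reader are hypothetical placeholders.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization


def sample_reader():
    # Placeholder calibration reader: yields single samples shaped like the
    # model's feed variable (here assumed to be a 3x224x224 image).
    for _ in range(160):
        yield [np.random.random((3, 224, 224)).astype("float32")]


exe = fluid.Executor(fluid.CPUPlace())
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./mobilenetv1_fp32",  # hypothetical FP32 model location
    sample_generator=sample_reader,
    batch_size=16,
    batch_nums=10,
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    optimize_model=True,  # fuse conv2d/depthwise_conv2d + bn before sampling
    is_use_cache_file=False)
ptq.quantize()
ptq.save_quantized_model("./mobilenetv1_int8")

With optimize_model=True, the conv_bn_fuse_pass runs on the FP32 graph before calibration data is sampled, so the channel-wise weight scales computed during quantization already reflect the folded batch-norm parameters.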