diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index c9614a1fb7770a7273e5f675380b635a1f8fd16c..c36cd1f74e6682050d230a176b815f1388619afd 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -435,6 +435,8 @@ class QuantizationTransformPass(object): if op.name() in self._quantizable_ops or \ op.name() in self._quantizable_grad_ops: _quant_preprocess(op) + # Insert mapping table to solve the problem in saving inference model. + graph.out_node_mapping_table = dict() # The process of _transform_forward and _transform_backward is needed in two for loops. # The loop for transforming the forward graph: for op in ops: @@ -853,6 +855,7 @@ class QuantizationTransformPass(object): shape=var_node.shape(), dtype='float32') out_node = func(in_node) + graph.out_node_mapping_table[out_node.name] = var_node.name() # loss shape must be 1 when minimize loss = mean(out_node) if not graph._for_test: @@ -1037,6 +1040,10 @@ class QuantizationFreezePass(object): op_name = op_node.name() if op_name in self._fake_quant_op_names: input_arg_name = op_node.input('X')[0] + if hasattr(graph, 'out_node_mapping_table'): + if input_arg_name in graph.out_node_mapping_table.keys(): + input_arg_name = graph.out_node_mapping_table[ + input_arg_name] if input_arg_name in persistable_vars: if self._weight_quantize_type == 'abs_max': param = self._load_var(input_arg_name) diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index ac4235d2e17936bd5b93fc85820b8f93361332c0..e85e8ae15bffd45184442046be0bfa8a192775ae 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -298,7 +298,6 @@ list(REMOVE_ITEM TEST_OPS #TODO(wanghaoshuang): Fix this unitest failed on GCC8. LIST(REMOVE_ITEM TEST_OPS test_auto_pruning) LIST(REMOVE_ITEM TEST_OPS test_filter_pruning) -LIST(REMOVE_ITEM TEST_OPS test_user_defined_quantization) foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() diff --git a/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py b/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py index 6f8d84a20a6372f967327c927590325d1c61dfbc..0c8f5cdd84cd486899cb892999b103cab72bfebb 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py +++ b/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py @@ -14,6 +14,7 @@ import os import unittest +import json import random import numpy as np import six @@ -109,6 +110,16 @@ class TestUserDefinedQuantization(unittest.TestCase): def get_optimizer(): return fluid.optimizer.MomentumOptimizer(0.0001, 0.9) + def load_dict(): + with open('mapping_table_for_saving_inference_model', 'r') as file: + data = file.read() + data = json.loads(data) + return data + + def save_dict(Dict): + with open('mapping_table_for_saving_inference_model', 'w') as file: + file.write(json.dumps(Dict)) + random.seed(0) np.random.seed(0) @@ -151,6 +162,7 @@ class TestUserDefinedQuantization(unittest.TestCase): executor=exe) test_transform_pass.apply(test_graph) + save_dict(test_graph.out_node_mapping_table) add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place) add_quant_dequant_pass.apply(main_graph) @@ -182,6 +194,21 @@ class TestUserDefinedQuantization(unittest.TestCase): feed=feeder.feed(data), fetch_list=[loss]) + out_scale_infer_pass = OutScaleForInferencePass(scope=scope) + out_scale_infer_pass.apply(test_graph) + + freeze_pass = QuantizationFreezePass( + scope=scope, + place=place, + weight_bits=8, + activation_bits=8, + weight_quantize_type=weight_quant_type) + + mapping_table = load_dict() + test_graph.out_node_mapping_table = mapping_table + if act_quantize_func == None and weight_quantize_func == None: + freeze_pass.apply(test_graph) + def test_act_preprocess_cuda(self): if fluid.core.is_compiled_with_cuda(): with fluid.unique_name.guard():