From bd0a9fb7aa8cc368e0b2f20160d5d1bd75f9b8c2 Mon Sep 17 00:00:00 2001 From: Dang Qingqing Date: Thu, 27 Sep 2018 04:50:43 +0000 Subject: [PATCH] Update code, since one merged PR hidden the API. --- .../contrib/quantize/quantize_transpiler.py | 21 +++++++++++----- .../contrib/tests/test_quantize_transpiler.py | 24 +++++++------------ 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/contrib/quantize/quantize_transpiler.py b/python/paddle/fluid/contrib/quantize/quantize_transpiler.py index 3967652d32..032d0353ea 100644 --- a/python/paddle/fluid/contrib/quantize/quantize_transpiler.py +++ b/python/paddle/fluid/contrib/quantize/quantize_transpiler.py @@ -183,7 +183,7 @@ class QuantizeTranspiler(object): block, idx + 1, quant_var, scale_var, quant_bits) dequanted_vars[block_id][name] = dequant_var # rename the forward op inputs - op.rename_input(name, dequant_var.name) + op._rename_input(name, dequant_var.name) def _transpile_backward(block, op): block_id = block.idx @@ -191,7 +191,7 @@ class QuantizeTranspiler(object): for name in op.input_arg_names: if name in dequanted_vars[block_id]: dequant_var = dequanted_vars[block_id][name] - op.rename_input(name, dequant_var.name) + op._rename_input(name, dequant_var.name) no_dequanted_input_vars = False if no_dequanted_input_vars: raise ValueError("There is no dequanted inputs for op %s." % @@ -262,7 +262,7 @@ class QuantizeTranspiler(object): scale_var = None for name in op.input_arg_names: if name in op_in_rename_map[block_id]: - op.rename_input(name, op_in_rename_map[block_id][name]) + op._rename_input(name, op_in_rename_map[block_id][name]) scale_v = var_scale_map[block_id][_original_var_name(name)] if _original_var_name(name) in persistable_vars: @@ -312,7 +312,8 @@ class QuantizeTranspiler(object): # input of the followed ops for name in op.input_arg_names: if name in op_out_rename_map[block_id]: - op.rename_input(name, op_out_rename_map[block_id][name]) + op._rename_input(name, + op_out_rename_map[block_id][name]) if op_type in self.fake_quant_op_types: in_arg_name = op.input('X')[0] @@ -378,10 +379,11 @@ class QuantizeTranspiler(object): if name not in input_map: int8_var = convert_to_int8(var) input_map[name] = int8_var.name - op.rename_input(name, input_map[name]) + op._rename_input(name, input_map[name]) self._remove_unused_var(program) def _remove_unused_var(self, program): + all_remove_vars = [] for block in program.blocks: args = [] for op in block.ops: @@ -389,9 +391,16 @@ class QuantizeTranspiler(object): args += op.output_arg_names args = list(set(args)) var_names = block.vars.keys() + sub_block_remove_vars = [] for var in var_names: if var not in args: - block._remove_var(var) + sub_block_remove_vars.append(var) + all_remove_vars.append(sub_block_remove_vars) + + remove_vars = [list(set(v)) for v in all_remove_vars] + for i, block in enumerate(program.blocks): + for v in remove_vars[i]: + block._remove_var(v) def _insert_quant_abs_max_op(self, block, idx, var, quant_bits): """Insert fake_quantize_abs_max op. diff --git a/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py b/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py index 789fe33581..9af3a6c9fd 100644 --- a/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py +++ b/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py @@ -226,27 +226,19 @@ class TestQuantizeTranspiler(unittest.TestCase): with fluid.program_guard(test_program): test_data = next(test_reader()) - f_var = fluid.framework.get_var('conv2d_1.tmp_0', test_program) - w_var = fluid.framework.get_var('conv2d_1.w_0.quantized', - test_program) + w_var = fluid.framework._get_var('conv2d_1.w_0.quantized', + test_program) # Testing during training - test_loss1, f_v1, w_quant = exe.run( - program=test_program, - feed=feeder.feed(test_data), - fetch_list=[loss, f_var, w_var]) + test_loss1, w_quant = exe.run(program=test_program, + feed=feeder.feed(test_data), + fetch_list=[loss, w_var]) # Freeze program for inference, but the weight of fc/conv is still float type. quant_transpiler.freeze_program(test_program, place) - fv2 = fluid.framework.get_var('conv2d_1.tmp_0.dequantized', - test_program) - test_loss2, f_v2 = exe.run(program=test_program, - feed=feeder.feed(test_data), - fetch_list=[loss, fv2]) + test_loss2, = exe.run(program=test_program, + feed=feeder.feed(test_data), + fetch_list=[loss]) self.assertAlmostEqual(test_loss1, test_loss2, delta=1e-3) - self.assertTrue( - np.allclose( - f_v1, f_v2, rtol=1e-03, atol=1e-03), - "There is diff: " + str(f_v1) + "\n" + str(f_v2)) w_freeze = np.array(fluid.global_scope().find_var('conv2d_1.w_0') .get_tensor()) self.assertEqual(np.sum(w_freeze), np.sum(w_quant)) -- GitLab