diff --git a/python/paddle/fluid/contrib/quantize/quantize_transpiler.py b/python/paddle/fluid/contrib/quantize/quantize_transpiler.py
index 3967652d324190dd5395b845c65d464588b4d310..032d0353ea6d80c4356ea9a9886ea59c48feec7a 100644
--- a/python/paddle/fluid/contrib/quantize/quantize_transpiler.py
+++ b/python/paddle/fluid/contrib/quantize/quantize_transpiler.py
@@ -183,7 +183,7 @@ class QuantizeTranspiler(object):
                         block, idx + 1, quant_var, scale_var, quant_bits)
                     dequanted_vars[block_id][name] = dequant_var
                 # rename the forward op inputs
-                op.rename_input(name, dequant_var.name)
+                op._rename_input(name, dequant_var.name)
 
         def _transpile_backward(block, op):
             block_id = block.idx
@@ -191,7 +191,7 @@ class QuantizeTranspiler(object):
             for name in op.input_arg_names:
                 if name in dequanted_vars[block_id]:
                     dequant_var = dequanted_vars[block_id][name]
-                    op.rename_input(name, dequant_var.name)
+                    op._rename_input(name, dequant_var.name)
                     no_dequanted_input_vars = False
             if no_dequanted_input_vars:
                 raise ValueError("There is no dequanted inputs for op %s." %
@@ -262,7 +262,7 @@ class QuantizeTranspiler(object):
             scale_var = None
             for name in op.input_arg_names:
                 if name in op_in_rename_map[block_id]:
-                    op.rename_input(name, op_in_rename_map[block_id][name])
+                    op._rename_input(name, op_in_rename_map[block_id][name])
 
                 scale_v = var_scale_map[block_id][_original_var_name(name)]
                 if _original_var_name(name) in persistable_vars:
@@ -312,7 +312,8 @@ class QuantizeTranspiler(object):
                 # input of the followed ops
                 for name in op.input_arg_names:
                     if name in op_out_rename_map[block_id]:
-                        op.rename_input(name, op_out_rename_map[block_id][name])
+                        op._rename_input(name,
+                                         op_out_rename_map[block_id][name])
 
                 if op_type in self.fake_quant_op_types:
                     in_arg_name = op.input('X')[0]
@@ -378,10 +379,11 @@ class QuantizeTranspiler(object):
                     if name not in input_map:
                         int8_var = convert_to_int8(var)
                         input_map[name] = int8_var.name
-                    op.rename_input(name, input_map[name])
+                    op._rename_input(name, input_map[name])
         self._remove_unused_var(program)
 
     def _remove_unused_var(self, program):
+        all_remove_vars = []
         for block in program.blocks:
             args = []
             for op in block.ops:
@@ -389,9 +391,16 @@ class QuantizeTranspiler(object):
                 args += op.output_arg_names
             args = list(set(args))
             var_names = block.vars.keys()
+            sub_block_remove_vars = []
             for var in var_names:
                 if var not in args:
-                    block._remove_var(var)
+                    sub_block_remove_vars.append(var)
+            all_remove_vars.append(sub_block_remove_vars)
+
+        remove_vars = [list(set(v)) for v in all_remove_vars]
+        for i, block in enumerate(program.blocks):
+            for v in remove_vars[i]:
+                block._remove_var(v)
 
     def _insert_quant_abs_max_op(self, block, idx, var, quant_bits):
         """Insert fake_quantize_abs_max op.
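Aside from the mechanical `rename_input` → `_rename_input` renames, the substantive change above is in `_remove_unused_var`: instead of calling `block._remove_var(var)` while iterating over `block.vars.keys()`, it first collects the names of unused variables per block and only removes them in a second pass. Below is a minimal standalone sketch of that collect-then-remove pattern; it is not part of the patch, and `block_vars` / `used_args` are hypothetical stand-ins for `block.vars` and the collected op argument names.

```python
# Sketch only (not part of the patch): a plain dict stands in for block.vars.
def remove_unused(block_vars, used_args):
    # First pass: record names only. Deleting entries while iterating over a
    # dict's keys() view raises RuntimeError on Python 3, which is the kind
    # of mutation-during-iteration hazard the two-pass rewrite avoids.
    to_remove = [name for name in block_vars if name not in used_args]
    # Second pass: iteration is finished, so removal is safe.
    for name in set(to_remove):
        del block_vars[name]

block_vars = {'conv2d_1.w_0': 0, 'tmp_1': 1, 'fc_0.b_0': 2}
remove_unused(block_vars, used_args={'conv2d_1.w_0', 'fc_0.b_0'})
assert 'tmp_1' not in block_vars
```

The patch applies the same idea one level up as well: every block's `sub_block_remove_vars` list is gathered into `all_remove_vars` before any `block._remove_var` call is made.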
diff --git a/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py b/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py
index 789fe3358184e30110ba362e9e5539ff9be1e171..9af3a6c9fda121d411a8a19f3928238be84fe8a6 100644
--- a/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py
+++ b/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py
@@ -226,27 +226,19 @@ class TestQuantizeTranspiler(unittest.TestCase):
         with fluid.program_guard(test_program):
             test_data = next(test_reader())
-            f_var = fluid.framework.get_var('conv2d_1.tmp_0', test_program)
-            w_var = fluid.framework.get_var('conv2d_1.w_0.quantized',
-                                            test_program)
+            w_var = fluid.framework._get_var('conv2d_1.w_0.quantized',
+                                             test_program)
             # Testing during training
-            test_loss1, f_v1, w_quant = exe.run(
-                program=test_program,
-                feed=feeder.feed(test_data),
-                fetch_list=[loss, f_var, w_var])
+            test_loss1, w_quant = exe.run(program=test_program,
+                                          feed=feeder.feed(test_data),
+                                          fetch_list=[loss, w_var])
 
             # Freeze program for inference, but the weight of fc/conv is still float type.
             quant_transpiler.freeze_program(test_program, place)
-            fv2 = fluid.framework.get_var('conv2d_1.tmp_0.dequantized',
-                                          test_program)
-            test_loss2, f_v2 = exe.run(program=test_program,
-                                       feed=feeder.feed(test_data),
-                                       fetch_list=[loss, fv2])
+            test_loss2, = exe.run(program=test_program,
+                                  feed=feeder.feed(test_data),
+                                  fetch_list=[loss])
             self.assertAlmostEqual(test_loss1, test_loss2, delta=1e-3)
-            self.assertTrue(
-                np.allclose(
-                    f_v1, f_v2, rtol=1e-03, atol=1e-03),
-                "There is diff: " + str(f_v1) + "\n" + str(f_v2))
             w_freeze = np.array(fluid.global_scope().find_var('conv2d_1.w_0')
                                 .get_tensor())
             self.assertEqual(np.sum(w_freeze), np.sum(w_quant))
 
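One detail of the rewritten test worth noting: once the `conv2d_1.tmp_0` fetches are dropped, `exe.run` returns a single-element list, and the patch unpacks it with a trailing comma (`test_loss2, = exe.run(...)`). A minimal sketch of that unpacking idiom, with `fake_run` as a hypothetical stand-in for `fluid.Executor.run`:

```python
# Sketch only: fake_run stands in for Executor.run, which returns one
# result per entry in fetch_list.
def fake_run(fetch_list):
    return [0.5 for _ in fetch_list]

# The trailing comma unpacks the one-element list; without it, test_loss2
# would be bound to the list itself rather than to the fetched value.
test_loss2, = fake_run(fetch_list=['loss'])
assert test_loss2 == 0.5
```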