Commit bd0a9fb7 authored by Dang Qingqing

Update code, since a merged PR hid the API.

Parent f7bd1761
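The change itself is mechanical: an earlier merged PR made several framework helpers private by prefixing them with an underscore (rename_input became _rename_input, get_var became _get_var), so every call site in the transpiler and its test is switched to the new names below. As a rough sketch only (not part of this commit; the helper rename_op_input is invented here for illustration), a call site could be insulated from such a rename with a small compatibility shim:

    def rename_op_input(op, old_name, new_name):
        # Prefer the private helper from the merged PR; fall back to the old
        # public name so the shim also runs against earlier Fluid releases.
        rename = getattr(op, '_rename_input', None) or getattr(op, 'rename_input')
        rename(old_name, new_name)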
@@ -183,7 +183,7 @@ class QuantizeTranspiler(object):
                         block, idx + 1, quant_var, scale_var, quant_bits)
                     dequanted_vars[block_id][name] = dequant_var
                 # rename the forward op inputs
-                op.rename_input(name, dequant_var.name)
+                op._rename_input(name, dequant_var.name)
 
         def _transpile_backward(block, op):
             block_id = block.idx
@@ -191,7 +191,7 @@ class QuantizeTranspiler(object):
             for name in op.input_arg_names:
                 if name in dequanted_vars[block_id]:
                     dequant_var = dequanted_vars[block_id][name]
-                    op.rename_input(name, dequant_var.name)
+                    op._rename_input(name, dequant_var.name)
                     no_dequanted_input_vars = False
             if no_dequanted_input_vars:
                 raise ValueError("There is no dequanted inputs for op %s." %
@@ -262,7 +262,7 @@ class QuantizeTranspiler(object):
             scale_var = None
             for name in op.input_arg_names:
                 if name in op_in_rename_map[block_id]:
-                    op.rename_input(name, op_in_rename_map[block_id][name])
+                    op._rename_input(name, op_in_rename_map[block_id][name])
 
                 scale_v = var_scale_map[block_id][_original_var_name(name)]
                 if _original_var_name(name) in persistable_vars:
@@ -312,7 +312,8 @@ class QuantizeTranspiler(object):
             # input of the followed ops
             for name in op.input_arg_names:
                 if name in op_out_rename_map[block_id]:
-                    op.rename_input(name, op_out_rename_map[block_id][name])
+                    op._rename_input(name,
+                                     op_out_rename_map[block_id][name])
 
             if op_type in self.fake_quant_op_types:
                 in_arg_name = op.input('X')[0]
@@ -378,10 +379,11 @@ class QuantizeTranspiler(object):
                     if name not in input_map:
                         int8_var = convert_to_int8(var)
                         input_map[name] = int8_var.name
-                    op.rename_input(name, input_map[name])
+                    op._rename_input(name, input_map[name])
         self._remove_unused_var(program)
 
     def _remove_unused_var(self, program):
+        all_remove_vars = []
         for block in program.blocks:
             args = []
             for op in block.ops:
@@ -389,9 +391,16 @@ class QuantizeTranspiler(object):
                 args += op.output_arg_names
             args = list(set(args))
             var_names = block.vars.keys()
+            sub_block_remove_vars = []
             for var in var_names:
                 if var not in args:
-                    block._remove_var(var)
+                    sub_block_remove_vars.append(var)
+            all_remove_vars.append(sub_block_remove_vars)
+
+        remove_vars = [list(set(v)) for v in all_remove_vars]
+        for i, block in enumerate(program.blocks):
+            for v in remove_vars[i]:
+                block._remove_var(v)
 
     def _insert_quant_abs_max_op(self, block, idx, var, quant_bits):
         """Insert fake_quantize_abs_max op.
...
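Besides the rename_input updates, the _remove_unused_var hunk above also reworks how unused variables are dropped: rather than calling block._remove_var(var) while iterating over block.vars.keys(), it first collects removal candidates for every block and only then deletes them. A minimal sketch of that collect-then-remove pattern, with plain dictionaries standing in for Paddle blocks (the names below are illustrative, not Paddle APIs):

    def remove_unused(blocks, used_names):
        # Pass 1: record removal candidates per block without mutating anything.
        to_remove = [[name for name in block if name not in used_names]
                     for block in blocks]
        # Pass 2: apply the removals, so no block is modified while iterated.
        for block, names in zip(blocks, to_remove):
            for name in set(names):
                del block[name]

    blocks = [{'conv_w': 0, 'tmp_0': 1}, {'fc_w': 2, 'tmp_1': 3}]
    remove_unused(blocks, used_names={'conv_w', 'fc_w'})
    assert blocks == [{'conv_w': 0}, {'fc_w': 2}]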
@@ -226,27 +226,19 @@ class TestQuantizeTranspiler(unittest.TestCase):
         with fluid.program_guard(test_program):
             test_data = next(test_reader())
-            f_var = fluid.framework.get_var('conv2d_1.tmp_0', test_program)
-            w_var = fluid.framework.get_var('conv2d_1.w_0.quantized',
-                                            test_program)
+            w_var = fluid.framework._get_var('conv2d_1.w_0.quantized',
+                                             test_program)
             # Testing during training
-            test_loss1, f_v1, w_quant = exe.run(
-                program=test_program,
-                feed=feeder.feed(test_data),
-                fetch_list=[loss, f_var, w_var])
+            test_loss1, w_quant = exe.run(program=test_program,
+                                          feed=feeder.feed(test_data),
+                                          fetch_list=[loss, w_var])
 
             # Freeze program for inference, but the weight of fc/conv is still float type.
             quant_transpiler.freeze_program(test_program, place)
-            fv2 = fluid.framework.get_var('conv2d_1.tmp_0.dequantized',
-                                          test_program)
-            test_loss2, f_v2 = exe.run(program=test_program,
-                                       feed=feeder.feed(test_data),
-                                       fetch_list=[loss, fv2])
+            test_loss2, = exe.run(program=test_program,
+                                  feed=feeder.feed(test_data),
+                                  fetch_list=[loss])
             self.assertAlmostEqual(test_loss1, test_loss2, delta=1e-3)
-            self.assertTrue(
-                np.allclose(
-                    f_v1, f_v2, rtol=1e-03, atol=1e-03),
-                "There is diff: " + str(f_v1) + "\n" + str(f_v2))
             w_freeze = np.array(fluid.global_scope().find_var('conv2d_1.w_0')
                                 .get_tensor())
             self.assertEqual(np.sum(w_freeze), np.sum(w_quant))
...