提交 bd0a9fb7 编写于 作者: D Dang Qingqing

Update code, since one merged PR hidden the API.

上级 f7bd1761
......@@ -183,7 +183,7 @@ class QuantizeTranspiler(object):
block, idx + 1, quant_var, scale_var, quant_bits)
dequanted_vars[block_id][name] = dequant_var
# rename the forward op inputs
op.rename_input(name, dequant_var.name)
op._rename_input(name, dequant_var.name)
def _transpile_backward(block, op):
block_id = block.idx
......@@ -191,7 +191,7 @@ class QuantizeTranspiler(object):
for name in op.input_arg_names:
if name in dequanted_vars[block_id]:
dequant_var = dequanted_vars[block_id][name]
op.rename_input(name, dequant_var.name)
op._rename_input(name, dequant_var.name)
no_dequanted_input_vars = False
if no_dequanted_input_vars:
raise ValueError("There is no dequanted inputs for op %s." %
......@@ -262,7 +262,7 @@ class QuantizeTranspiler(object):
scale_var = None
for name in op.input_arg_names:
if name in op_in_rename_map[block_id]:
op.rename_input(name, op_in_rename_map[block_id][name])
op._rename_input(name, op_in_rename_map[block_id][name])
scale_v = var_scale_map[block_id][_original_var_name(name)]
if _original_var_name(name) in persistable_vars:
......@@ -312,7 +312,8 @@ class QuantizeTranspiler(object):
# input of the followed ops
for name in op.input_arg_names:
if name in op_out_rename_map[block_id]:
op.rename_input(name, op_out_rename_map[block_id][name])
op._rename_input(name,
op_out_rename_map[block_id][name])
if op_type in self.fake_quant_op_types:
in_arg_name = op.input('X')[0]
......@@ -378,10 +379,11 @@ class QuantizeTranspiler(object):
if name not in input_map:
int8_var = convert_to_int8(var)
input_map[name] = int8_var.name
op.rename_input(name, input_map[name])
op._rename_input(name, input_map[name])
self._remove_unused_var(program)
def _remove_unused_var(self, program):
all_remove_vars = []
for block in program.blocks:
args = []
for op in block.ops:
......@@ -389,9 +391,16 @@ class QuantizeTranspiler(object):
args += op.output_arg_names
args = list(set(args))
var_names = block.vars.keys()
sub_block_remove_vars = []
for var in var_names:
if var not in args:
block._remove_var(var)
sub_block_remove_vars.append(var)
all_remove_vars.append(sub_block_remove_vars)
remove_vars = [list(set(v)) for v in all_remove_vars]
for i, block in enumerate(program.blocks):
for v in remove_vars[i]:
block._remove_var(v)
def _insert_quant_abs_max_op(self, block, idx, var, quant_bits):
"""Insert fake_quantize_abs_max op.
......
......@@ -226,27 +226,19 @@ class TestQuantizeTranspiler(unittest.TestCase):
with fluid.program_guard(test_program):
test_data = next(test_reader())
f_var = fluid.framework.get_var('conv2d_1.tmp_0', test_program)
w_var = fluid.framework.get_var('conv2d_1.w_0.quantized',
test_program)
w_var = fluid.framework._get_var('conv2d_1.w_0.quantized',
test_program)
# Testing during training
test_loss1, f_v1, w_quant = exe.run(
program=test_program,
feed=feeder.feed(test_data),
fetch_list=[loss, f_var, w_var])
test_loss1, w_quant = exe.run(program=test_program,
feed=feeder.feed(test_data),
fetch_list=[loss, w_var])
# Freeze program for inference, but the weight of fc/conv is still float type.
quant_transpiler.freeze_program(test_program, place)
fv2 = fluid.framework.get_var('conv2d_1.tmp_0.dequantized',
test_program)
test_loss2, f_v2 = exe.run(program=test_program,
feed=feeder.feed(test_data),
fetch_list=[loss, fv2])
test_loss2, = exe.run(program=test_program,
feed=feeder.feed(test_data),
fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=1e-3)
self.assertTrue(
np.allclose(
f_v1, f_v2, rtol=1e-03, atol=1e-03),
"There is diff: " + str(f_v1) + "\n" + str(f_v2))
w_freeze = np.array(fluid.global_scope().find_var('conv2d_1.w_0')
.get_tensor())
self.assertEqual(np.sum(w_freeze), np.sum(w_quant))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册