提交 0cb50bb9 编写于 作者: Z Zhen Wang

avoid ce fails on windows.

上级 f5a37518
...@@ -123,7 +123,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -123,7 +123,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name.endswith('.quantized.dequantized')) arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops) self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, quant_type): def linear_fc_quant(self, quant_type, enable_ce=False):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -138,6 +138,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -138,6 +138,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
place=place, place=place,
activation_quantize_type=quant_type) activation_quantize_type=quant_type)
transform_pass.apply(graph) transform_pass.apply(graph)
if not enable_ce:
marked_nodes = set() marked_nodes = set()
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -146,6 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -146,6 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
program = graph.to_program() program = graph.to_program()
self.check_program(transform_pass, program) self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not enable_ce:
val_marked_nodes = set() val_marked_nodes = set()
for op in val_graph.all_op_nodes(): for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -153,14 +155,12 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -153,14 +155,12 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes) val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
def test_linear_fc_quant_abs_max(self): def test_linear_fc_quant_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max' self.linear_fc_quant('abs_max', enable_ce=True)
self.linear_fc_quant('abs_max')
def test_linear_fc_quant_range_abs_max(self): def test_linear_fc_quant_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max' self.linear_fc_quant('range_abs_max', enable_ce=True)
self.linear_fc_quant('range_abs_max')
def residual_block_quant(self, quant_type): def residual_block_quant(self, quant_type, enable_ce=False):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -175,6 +175,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -175,6 +175,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
place=place, place=place,
activation_quantize_type=quant_type) activation_quantize_type=quant_type)
transform_pass.apply(graph) transform_pass.apply(graph)
if not enable_ce:
marked_nodes = set() marked_nodes = set()
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -183,6 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -183,6 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
program = graph.to_program() program = graph.to_program()
self.check_program(transform_pass, program) self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not enable_ce:
val_marked_nodes = set() val_marked_nodes = set()
for op in val_graph.all_op_nodes(): for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -190,16 +192,14 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -190,16 +192,14 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes) val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
def test_residual_block_abs_max(self): def test_residual_block_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max' self.residual_block_quant('abs_max', enable_ce=True)
self.residual_block_quant('abs_max')
def test_residual_block_range_abs_max(self): def test_residual_block_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max' self.residual_block_quant('range_abs_max', enable_ce=True)
self.residual_block_quant('range_abs_max')
class TestQuantizationFreezePass(unittest.TestCase): class TestQuantizationFreezePass(unittest.TestCase):
def freeze_graph(self, use_cuda, seed, quant_type): def freeze_graph(self, use_cuda, seed, quant_type, enable_ce=False):
def build_program(main, startup, is_test): def build_program(main, startup, is_test):
main.random_seed = seed main.random_seed = seed
startup.random_seed = seed startup.random_seed = seed
...@@ -237,6 +237,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -237,6 +237,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
transform_pass.apply(test_graph) transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_' dev_name = '_gpu_' if use_cuda else '_cpu_'
if not enable_ce:
marked_nodes = set() marked_nodes = set()
for op in main_graph.all_op_nodes(): for op in main_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -266,7 +267,9 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -266,7 +267,9 @@ class TestQuantizationFreezePass(unittest.TestCase):
loss_v = exe.run(program=quantized_main_program, loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[loss]) fetch_list=[loss])
print('{}: {}'.format('loss' + dev_name + quant_type, loss_v)) if not enable_ce:
print('{}: {}'.format('loss' + dev_name + quant_type,
loss_v))
test_data = next(test_reader()) test_data = next(test_reader())
with fluid.program_guard(quantized_test_program): with fluid.program_guard(quantized_test_program):
...@@ -281,6 +284,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -281,6 +284,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Freeze graph for inference, but the weight of fc/conv is still float type. # Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(scope=scope, place=place) freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph) freeze_pass.apply(test_graph)
if not enable_ce:
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_op_nodes(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -294,11 +298,15 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -294,11 +298,15 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed=feeder.feed(test_data), feed=feeder.feed(test_data),
fetch_list=[loss]) fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3) self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1)) if not enable_ce:
print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2)) print('{}: {}'.format('test_loss1' + dev_name + quant_type,
test_loss1))
print('{}: {}'.format('test_loss2' + dev_name + quant_type,
test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Maybe failed, this is due to the calculation precision # Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant)) # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
if not enable_ce:
print('{}: {}'.format('w_freeze' + dev_name + quant_type, print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze))) np.sum(w_freeze)))
print('{}: {}'.format('w_quant' + dev_name + quant_type, print('{}: {}'.format('w_quant' + dev_name + quant_type,
...@@ -307,11 +315,13 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -307,11 +315,13 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Convert parameter to 8-bit. # Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph) convert_int8_pass.apply(test_graph)
if not enable_ce:
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_op_nodes(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
marked_nodes.add(op) marked_nodes.add(op)
test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes) test_graph.draw('.', 'test_int8' + dev_name + quant_type,
marked_nodes)
server_program_int8 = test_graph.to_program() server_program_int8 = test_graph.to_program()
# Save the 8-bit parameter and model file. # Save the 8-bit parameter and model file.
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
...@@ -325,12 +335,15 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -325,12 +335,15 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor()) w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8) self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze)) self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit))) if not enable_ce:
print('{}: {}'.format('w_8bit' + dev_name + quant_type,
np.sum(w_8bit)))
print('{}: {}'.format('w_freeze' + dev_name + quant_type, print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze))) np.sum(w_freeze)))
mobile_pass = TransformForMobilePass() mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph) mobile_pass.apply(test_graph)
if not enable_ce:
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_op_nodes(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -347,20 +360,24 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -347,20 +360,24 @@ class TestQuantizationFreezePass(unittest.TestCase):
def test_freeze_graph_cuda_dynamic(self): def test_freeze_graph_cuda_dynamic(self):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph(True, seed=1, quant_type='abs_max') self.freeze_graph(
True, seed=1, quant_type='abs_max', enable_ce=True)
def test_freeze_graph_cpu_dynamic(self): def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='abs_max') self.freeze_graph(
False, seed=2, quant_type='abs_max', enable_ce=True)
def test_freeze_graph_cuda_static(self): def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph(True, seed=1, quant_type='range_abs_max') self.freeze_graph(
True, seed=1, quant_type='range_abs_max', enable_ce=True)
def test_freeze_graph_cpu_static(self): def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='range_abs_max') self.freeze_graph(
False, seed=2, quant_type='range_abs_max', enable_ce=True)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册