提交 7ea5990c 编写于 作者: Z Zhen Wang 提交者: ceci3

update some details. test=develop

上级 bf807d69
...@@ -123,7 +123,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -123,7 +123,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name.endswith('.quantized.dequantized')) arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops) self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, quant_type, enable_ce=False): def linear_fc_quant(self, quant_type, for_ci=False):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -138,7 +138,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -138,7 +138,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
place=place, place=place,
activation_quantize_type=quant_type) activation_quantize_type=quant_type)
transform_pass.apply(graph) transform_pass.apply(graph)
if not enable_ce: if not for_ci:
marked_nodes = set() marked_nodes = set()
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -147,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -147,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
program = graph.to_program() program = graph.to_program()
self.check_program(transform_pass, program) self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not enable_ce: if not for_ci:
val_marked_nodes = set() val_marked_nodes = set()
for op in val_graph.all_op_nodes(): for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -155,12 +155,12 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -155,12 +155,12 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes) val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
def test_linear_fc_quant_abs_max(self): def test_linear_fc_quant_abs_max(self):
self.linear_fc_quant('abs_max', enable_ce=True) self.linear_fc_quant('abs_max', for_ci=True)
def test_linear_fc_quant_range_abs_max(self): def test_linear_fc_quant_range_abs_max(self):
self.linear_fc_quant('range_abs_max', enable_ce=True) self.linear_fc_quant('range_abs_max', for_ci=True)
def residual_block_quant(self, quant_type, enable_ce=False): def residual_block_quant(self, quant_type, for_ci=False):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -175,7 +175,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -175,7 +175,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
place=place, place=place,
activation_quantize_type=quant_type) activation_quantize_type=quant_type)
transform_pass.apply(graph) transform_pass.apply(graph)
if not enable_ce: if not for_ci:
marked_nodes = set() marked_nodes = set()
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -184,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -184,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
program = graph.to_program() program = graph.to_program()
self.check_program(transform_pass, program) self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not enable_ce: if not for_ci:
val_marked_nodes = set() val_marked_nodes = set()
for op in val_graph.all_op_nodes(): for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -192,14 +192,14 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -192,14 +192,14 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes) val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
def test_residual_block_abs_max(self): def test_residual_block_abs_max(self):
self.residual_block_quant('abs_max', enable_ce=True) self.residual_block_quant('abs_max', for_ci=True)
def test_residual_block_range_abs_max(self): def test_residual_block_range_abs_max(self):
self.residual_block_quant('range_abs_max', enable_ce=True) self.residual_block_quant('range_abs_max', for_ci=True)
class TestQuantizationFreezePass(unittest.TestCase): class TestQuantizationFreezePass(unittest.TestCase):
def freeze_graph(self, use_cuda, seed, quant_type, enable_ce=False): def freeze_graph(self, use_cuda, seed, quant_type, for_ci=False):
def build_program(main, startup, is_test): def build_program(main, startup, is_test):
main.random_seed = seed main.random_seed = seed
startup.random_seed = seed startup.random_seed = seed
...@@ -237,7 +237,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -237,7 +237,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
transform_pass.apply(test_graph) transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_' dev_name = '_gpu_' if use_cuda else '_cpu_'
if not enable_ce: if not for_ci:
marked_nodes = set() marked_nodes = set()
for op in main_graph.all_op_nodes(): for op in main_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -267,7 +267,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -267,7 +267,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
loss_v = exe.run(program=quantized_main_program, loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[loss]) fetch_list=[loss])
if not enable_ce: if not for_ci:
print('{}: {}'.format('loss' + dev_name + quant_type, print('{}: {}'.format('loss' + dev_name + quant_type,
loss_v)) loss_v))
...@@ -284,7 +284,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -284,7 +284,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Freeze graph for inference, but the weight of fc/conv is still float type. # Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(scope=scope, place=place) freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph) freeze_pass.apply(test_graph)
if not enable_ce: if not for_ci:
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_op_nodes(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -298,7 +298,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -298,7 +298,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed=feeder.feed(test_data), feed=feeder.feed(test_data),
fetch_list=[loss]) fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3) self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
if not enable_ce: if not for_ci:
print('{}: {}'.format('test_loss1' + dev_name + quant_type, print('{}: {}'.format('test_loss1' + dev_name + quant_type,
test_loss1)) test_loss1))
print('{}: {}'.format('test_loss2' + dev_name + quant_type, print('{}: {}'.format('test_loss2' + dev_name + quant_type,
...@@ -306,7 +306,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -306,7 +306,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Maybe failed, this is due to the calculation precision # Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant)) # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
if not enable_ce: if not for_ci:
print('{}: {}'.format('w_freeze' + dev_name + quant_type, print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze))) np.sum(w_freeze)))
print('{}: {}'.format('w_quant' + dev_name + quant_type, print('{}: {}'.format('w_quant' + dev_name + quant_type,
...@@ -315,7 +315,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -315,7 +315,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Convert parameter to 8-bit. # Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph) convert_int8_pass.apply(test_graph)
if not enable_ce: if not for_ci:
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_op_nodes(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -335,7 +335,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -335,7 +335,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor()) w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8) self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze)) self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
if not enable_ce: if not for_ci:
print('{}: {}'.format('w_8bit' + dev_name + quant_type, print('{}: {}'.format('w_8bit' + dev_name + quant_type,
np.sum(w_8bit))) np.sum(w_8bit)))
print('{}: {}'.format('w_freeze' + dev_name + quant_type, print('{}: {}'.format('w_freeze' + dev_name + quant_type,
...@@ -343,7 +343,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -343,7 +343,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
mobile_pass = TransformForMobilePass() mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph) mobile_pass.apply(test_graph)
if not enable_ce: if not for_ci:
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_op_nodes(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
...@@ -361,23 +361,22 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -361,23 +361,22 @@ class TestQuantizationFreezePass(unittest.TestCase):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph( self.freeze_graph(
True, seed=1, quant_type='abs_max', enable_ce=True) True, seed=1, quant_type='abs_max', for_ci=True)
def test_freeze_graph_cpu_dynamic(self): def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph( self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=True)
False, seed=2, quant_type='abs_max', enable_ce=True)
def test_freeze_graph_cuda_static(self): def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph( self.freeze_graph(
True, seed=1, quant_type='range_abs_max', enable_ce=True) True, seed=1, quant_type='range_abs_max', for_ci=True)
def test_freeze_graph_cpu_static(self): def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph( self.freeze_graph(
False, seed=2, quant_type='range_abs_max', enable_ce=True) False, seed=2, quant_type='range_abs_max', for_ci=True)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册