提交 0cb50bb9 编写于 作者: Z Zhen Wang

avoid ce fails on windows.

上级 f5a37518
......@@ -123,7 +123,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, quant_type):
def linear_fc_quant(self, quant_type, enable_ce=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -138,6 +138,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
place=place,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
if not enable_ce:
marked_nodes = set()
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
......@@ -146,6 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not enable_ce:
val_marked_nodes = set()
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
......@@ -153,14 +155,12 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
def test_linear_fc_quant_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.linear_fc_quant('abs_max')
self.linear_fc_quant('abs_max', enable_ce=True)
def test_linear_fc_quant_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.linear_fc_quant('range_abs_max')
self.linear_fc_quant('range_abs_max', enable_ce=True)
def residual_block_quant(self, quant_type):
def residual_block_quant(self, quant_type, enable_ce=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -175,6 +175,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
place=place,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
if not enable_ce:
marked_nodes = set()
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
......@@ -183,6 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not enable_ce:
val_marked_nodes = set()
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
......@@ -190,16 +192,14 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
def test_residual_block_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.residual_block_quant('abs_max')
self.residual_block_quant('abs_max', enable_ce=True)
def test_residual_block_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.residual_block_quant('range_abs_max')
self.residual_block_quant('range_abs_max', enable_ce=True)
class TestQuantizationFreezePass(unittest.TestCase):
def freeze_graph(self, use_cuda, seed, quant_type):
def freeze_graph(self, use_cuda, seed, quant_type, enable_ce=False):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
......@@ -237,6 +237,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
if not enable_ce:
marked_nodes = set()
for op in main_graph.all_op_nodes():
if op.name().find('quantize') > -1:
......@@ -266,7 +267,9 @@ class TestQuantizationFreezePass(unittest.TestCase):
loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data),
fetch_list=[loss])
print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
if not enable_ce:
print('{}: {}'.format('loss' + dev_name + quant_type,
loss_v))
test_data = next(test_reader())
with fluid.program_guard(quantized_test_program):
......@@ -281,6 +284,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
if not enable_ce:
marked_nodes = set()
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
......@@ -294,11 +298,15 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed=feeder.feed(test_data),
fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
if not enable_ce:
print('{}: {}'.format('test_loss1' + dev_name + quant_type,
test_loss1))
print('{}: {}'.format('test_loss2' + dev_name + quant_type,
test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
if not enable_ce:
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze)))
print('{}: {}'.format('w_quant' + dev_name + quant_type,
......@@ -307,11 +315,13 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph)
if not enable_ce:
marked_nodes = set()
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes)
test_graph.draw('.', 'test_int8' + dev_name + quant_type,
marked_nodes)
server_program_int8 = test_graph.to_program()
# Save the 8-bit parameter and model file.
with fluid.scope_guard(scope):
......@@ -325,12 +335,15 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
if not enable_ce:
print('{}: {}'.format('w_8bit' + dev_name + quant_type,
np.sum(w_8bit)))
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze)))
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph)
if not enable_ce:
marked_nodes = set()
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
......@@ -347,20 +360,24 @@ class TestQuantizationFreezePass(unittest.TestCase):
def test_freeze_graph_cuda_dynamic(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(True, seed=1, quant_type='abs_max')
self.freeze_graph(
True, seed=1, quant_type='abs_max', enable_ce=True)
def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='abs_max')
self.freeze_graph(
False, seed=2, quant_type='abs_max', enable_ce=True)
def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(True, seed=1, quant_type='range_abs_max')
self.freeze_graph(
True, seed=1, quant_type='range_abs_max', enable_ce=True)
def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='range_abs_max')
self.freeze_graph(
False, seed=2, quant_type='range_abs_max', enable_ce=True)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册