未验证 提交 1997011c 编写于 作者: W whs 提交者: GitHub

Refine the checking method in qat unittest (#1758)

上级 2de33a0a
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import sys import sys
from typing import List
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import paddle import paddle
...@@ -43,8 +44,8 @@ class TestQuantAwareCase(StaticCase): ...@@ -43,8 +44,8 @@ class TestQuantAwareCase(StaticCase):
main_prog = paddle.static.default_main_program() main_prog = paddle.static.default_main_program()
val_prog = paddle.static.default_main_program().clone(for_test=True) val_prog = paddle.static.default_main_program().clone(for_test=True)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( place = paddle.CUDAPlace(
) else paddle.CPUPlace() 0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
...@@ -104,67 +105,87 @@ class TestQuantAwareCase(StaticCase): ...@@ -104,67 +105,87 @@ class TestQuantAwareCase(StaticCase):
train(main_prog) train(main_prog)
top1_1, top5_1 = test(main_prog) top1_1, top5_1 = test(main_prog)
ops_with_weights = [
'depthwise_conv2d',
'mul',
'conv2d',
]
ops_without_weights = [
'relu',
]
config = { config = {
'weight_quantize_type': 'channel_wise_abs_max', 'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max', 'activation_quantize_type': 'moving_average_abs_max',
'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], 'quantize_op_types': ops_with_weights + ops_without_weights,
} }
quant_train_prog = quant_aware(main_prog, place, config, for_test=False) quant_train_prog = quant_aware(main_prog, place, config, for_test=False)
quant_eval_prog = quant_aware(val_prog, place, config, for_test=True) quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
op_nums_1, quant_op_nums_1 = self.get_op_number(quant_eval_prog)
# test quant_aware op numbers # Step1: check the quantizers count in qat graph
self.assertTrue(op_nums_1 * 2 == quant_op_nums_1) quantizers_count_in_qat = self.count_op(quant_eval_prog,
['quantize_linear'])
ops_with_weights_count = self.count_op(quant_eval_prog,
ops_with_weights)
ops_without_weights_count = self.count_op(quant_eval_prog,
ops_without_weights)
self.assertEqual(ops_with_weights_count * 2 + ops_without_weights_count,
quantizers_count_in_qat)
with paddle.static.program_guard(quant_eval_prog):
paddle.static.save_inference_model("./models/mobilenet_qat", [
image, label
], [avg_cost, acc_top1, acc_top5], exe)
train(quant_train_prog) train(quant_train_prog)
convert_eval_prog = convert(quant_eval_prog, place, config) convert_eval_prog = convert(quant_eval_prog, place, config)
with paddle.static.program_guard(convert_eval_prog):
paddle.static.save_inference_model("./models/mobilenet_onnx", [
image, label
], [avg_cost, acc_top1, acc_top5], exe)
top1_2, top5_2 = test(convert_eval_prog) top1_2, top5_2 = test(convert_eval_prog)
# values before quantization and after quantization should be close # values before quantization and after quantization should be close
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1)) print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2)) print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
convert_op_nums_1, convert_quant_op_nums_1 = self.get_convert_op_number( # Step2: check the quantizers count in onnx graph
convert_eval_prog) quantizers_count = self.count_op(convert_eval_prog, ['quantize_linear'])
# test convert op numbers observers_count = self.count_op(quant_eval_prog,
self.assertTrue(convert_op_nums_1 + 25 == convert_quant_op_nums_1) ['moving_average_abs_max_scale'])
self.assertEqual(quantizers_count, ops_with_weights_count +
ops_without_weights_count + observers_count)
# Step3: check the quantization skipping
config['not_quant_pattern'] = ['last_fc'] config['not_quant_pattern'] = ['last_fc']
quant_prog_2 = quant_aware( skip_quant_prog = quant_aware(
main_prog, place, config=config, for_test=True) main_prog, place, config=config, for_test=True)
op_nums_2, quant_op_nums_2 = self.get_op_number(quant_prog_2) skip_quantizers_count_in_qat = self.count_op(skip_quant_prog,
convert_prog_2 = convert(quant_prog_2, place, config=config) ['quantize_linear'])
convert_op_nums_2, convert_quant_op_nums_2 = self.get_convert_op_number( skip_ops_with_weights_count = self.count_op(skip_quant_prog,
convert_prog_2) ops_with_weights)
skip_ops_without_weights_count = self.count_op(skip_quant_prog,
self.assertTrue(op_nums_1 == op_nums_2) ops_without_weights)
# test skip_quant self.assertEqual(skip_ops_without_weights_count,
self.assertTrue(quant_op_nums_1 - 2 == quant_op_nums_2) ops_without_weights_count)
self.assertTrue(convert_quant_op_nums_1 == convert_quant_op_nums_2) self.assertEqual(skip_ops_with_weights_count, ops_with_weights_count)
self.assertEqual(skip_quantizers_count_in_qat + 2,
def get_op_number(self, prog): quantizers_count_in_qat)
skip_quant_prog_onnx = convert(skip_quant_prog, place, config=config)
skip_quantizers_count_in_onnx = self.count_op(skip_quant_prog_onnx,
['quantize_linear'])
self.assertEqual(quantizers_count, skip_quantizers_count_in_onnx)
def count_op(self, prog, ops: List[str]):
graph = paddle.fluid.framework.IrGraph( graph = paddle.fluid.framework.IrGraph(
paddle.framework.core.Graph(prog.desc), for_test=False) paddle.framework.core.Graph(prog.desc), for_test=False)
quant_op_nums = 0
op_nums = 0
for op in graph.all_op_nodes():
if op.name() in ['conv2d', 'depthwise_conv2d', 'mul']:
op_nums += 1
elif op.name() == 'quantize_linear':
quant_op_nums += 1
return op_nums, quant_op_nums
def get_convert_op_number(self, prog):
graph = paddle.fluid.framework.IrGraph(
paddle.framework.core.Graph(prog.desc), for_test=True)
quant_op_nums = 0
op_nums = 0 op_nums = 0
dequant_num = 0
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() not in ['quantize_linear', 'dequantize_linear']: if op.name() in ops:
op_nums += 1 op_nums += 1
elif op.name() == 'quantize_linear': return op_nums
quant_op_nums += 1
return op_nums, quant_op_nums
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册