Unverified commit 1997011c authored by whs, committed by GitHub

Refine the checking method in qat unittest (#1758)

Parent 2de33a0a
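
This commit replaces hard-coded op-count arithmetic in the QAT unittest (checks like `op_nums_1 * 2 == quant_op_nums_1` or `convert_op_nums_1 + 25 == convert_quant_op_nums_1`) with a single generic counter, `count_op`, and asserts relationships between op counts instead. A condensed sketch of that helper for orientation (the full version appears in the diff below; the `sum` phrasing is my restatement, not the committed code):

```python
from typing import List

import paddle


def count_op(prog, ops: List[str]) -> int:
    """Count the operators in ``prog`` whose type is listed in ``ops``."""
    graph = paddle.fluid.framework.IrGraph(
        paddle.framework.core.Graph(prog.desc), for_test=False)
    return sum(1 for op in graph.all_op_nodes() if op.name() in ops)
```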
@@ -13,6 +13,7 @@
 # limitations under the License.
 import sys
+from typing import List
 sys.path.append("../")
 import unittest
 import paddle
@@ -43,8 +44,8 @@ class TestQuantAwareCase(StaticCase):
         main_prog = paddle.static.default_main_program()
         val_prog = paddle.static.default_main_program().clone(for_test=True)
-        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
-        ) else paddle.CPUPlace()
+        place = paddle.CUDAPlace(
+            0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
         exe = paddle.static.Executor(place)
         exe.run(paddle.static.default_startup_program())
@@ -104,67 +105,87 @@ class TestQuantAwareCase(StaticCase):
         train(main_prog)
         top1_1, top5_1 = test(main_prog)
+        ops_with_weights = [
+            'depthwise_conv2d',
+            'mul',
+            'conv2d',
+        ]
+        ops_without_weights = [
+            'relu',
+        ]
         config = {
             'weight_quantize_type': 'channel_wise_abs_max',
             'activation_quantize_type': 'moving_average_abs_max',
-            'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
+            'quantize_op_types': ops_with_weights + ops_without_weights,
         }
         quant_train_prog = quant_aware(main_prog, place, config, for_test=False)
         quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
-        op_nums_1, quant_op_nums_1 = self.get_op_number(quant_eval_prog)
-        # test quant_aware op numbers
-        self.assertTrue(op_nums_1 * 2 == quant_op_nums_1)
+        # Step1: check the quantizers count in qat graph
+        quantizers_count_in_qat = self.count_op(quant_eval_prog,
+                                                ['quantize_linear'])
+        ops_with_weights_count = self.count_op(quant_eval_prog,
+                                               ops_with_weights)
+        ops_without_weights_count = self.count_op(quant_eval_prog,
+                                                  ops_without_weights)
+        self.assertEqual(ops_with_weights_count * 2 + ops_without_weights_count,
+                         quantizers_count_in_qat)
         with paddle.static.program_guard(quant_eval_prog):
             paddle.static.save_inference_model("./models/mobilenet_qat", [
                 image, label
             ], [avg_cost, acc_top1, acc_top5], exe)
         train(quant_train_prog)
         convert_eval_prog = convert(quant_eval_prog, place, config)
         with paddle.static.program_guard(convert_eval_prog):
             paddle.static.save_inference_model("./models/mobilenet_onnx", [
                 image, label
             ], [avg_cost, acc_top1, acc_top5], exe)
         top1_2, top5_2 = test(convert_eval_prog)
         # values before quantization and after quantization should be close
         print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
         print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
-        convert_op_nums_1, convert_quant_op_nums_1 = self.get_convert_op_number(
-            convert_eval_prog)
-        # test convert op numbers
-        self.assertTrue(convert_op_nums_1 + 25 == convert_quant_op_nums_1)
+        # Step2: check the quantizers count in onnx graph
+        quantizers_count = self.count_op(convert_eval_prog, ['quantize_linear'])
+        observers_count = self.count_op(quant_eval_prog,
+                                        ['moving_average_abs_max_scale'])
+        self.assertEqual(quantizers_count, ops_with_weights_count +
+                         ops_without_weights_count + observers_count)
+        # Step3: check the quantization skipping
         config['not_quant_pattern'] = ['last_fc']
-        quant_prog_2 = quant_aware(
+        skip_quant_prog = quant_aware(
             main_prog, place, config=config, for_test=True)
-        op_nums_2, quant_op_nums_2 = self.get_op_number(quant_prog_2)
-        convert_prog_2 = convert(quant_prog_2, place, config=config)
-        convert_op_nums_2, convert_quant_op_nums_2 = self.get_convert_op_number(
-            convert_prog_2)
-        self.assertTrue(op_nums_1 == op_nums_2)
-        # test skip_quant
-        self.assertTrue(quant_op_nums_1 - 2 == quant_op_nums_2)
-        self.assertTrue(convert_quant_op_nums_1 == convert_quant_op_nums_2)
+        skip_quantizers_count_in_qat = self.count_op(skip_quant_prog,
+                                                     ['quantize_linear'])
+        skip_ops_with_weights_count = self.count_op(skip_quant_prog,
+                                                    ops_with_weights)
+        skip_ops_without_weights_count = self.count_op(skip_quant_prog,
+                                                       ops_without_weights)
+        self.assertEqual(skip_ops_without_weights_count,
+                         ops_without_weights_count)
+        self.assertEqual(skip_ops_with_weights_count, ops_with_weights_count)
+        self.assertEqual(skip_quantizers_count_in_qat + 2,
+                         quantizers_count_in_qat)
+        skip_quant_prog_onnx = convert(skip_quant_prog, place, config=config)
+        skip_quantizers_count_in_onnx = self.count_op(skip_quant_prog_onnx,
+                                                      ['quantize_linear'])
+        self.assertEqual(quantizers_count, skip_quantizers_count_in_onnx)

-    def get_op_number(self, prog):
+    def count_op(self, prog, ops: List[str]):
         graph = paddle.fluid.framework.IrGraph(
             paddle.framework.core.Graph(prog.desc), for_test=False)
-        quant_op_nums = 0
         op_nums = 0
         for op in graph.all_op_nodes():
-            if op.name() in ['conv2d', 'depthwise_conv2d', 'mul']:
-                op_nums += 1
-            elif op.name() == 'quantize_linear':
-                quant_op_nums += 1
-        return op_nums, quant_op_nums
-
-    def get_convert_op_number(self, prog):
-        graph = paddle.fluid.framework.IrGraph(
-            paddle.framework.core.Graph(prog.desc), for_test=True)
-        quant_op_nums = 0
-        op_nums = 0
-        dequant_num = 0
-        for op in graph.all_op_nodes():
-            if op.name() not in ['quantize_linear', 'dequantize_linear']:
+            if op.name() in ops:
                 op_nums += 1
-            elif op.name() == 'quantize_linear':
-                quant_op_nums += 1
-        return op_nums, quant_op_nums
+        return op_nums

 if __name__ == '__main__':
......
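
For reference, the invariants the three steps above encode: in the QAT graph, each quantized op with weights receives two `quantize_linear` quantizers (one on its weight, one on its input activation), while a weight-free op such as `relu` receives one, so skipping a single weighted op via `not_quant_pattern` removes exactly two quantizers, which is where the `+ 2` in Step3 comes from; after `convert`, the quantizer count equals one per quantized op plus one per `moving_average_abs_max_scale` observer (Step2). A hedged walk-through of the Step1 invariant, assuming `main_prog`, `place`, `quant_aware`, and a free-function `count_op` (as sketched near the top) are set up as in the test:

```python
# Hypothetical usage; every name here is assumed from the test above.
config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d', 'relu'],
}
qat_prog = quant_aware(main_prog, place, config, for_test=True)

weighted = count_op(qat_prog, ['depthwise_conv2d', 'mul', 'conv2d'])
weightless = count_op(qat_prog, ['relu'])
# Each weighted op carries a weight quantizer plus an activation quantizer;
# each weight-free op carries a single activation quantizer.
assert count_op(qat_prog, ['quantize_linear']) == 2 * weighted + weightless
```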