From 7ee31e72ed050ce5628f397fcd1d603f9aa7501a Mon Sep 17 00:00:00 2001 From: Zuza Gawrysiak Date: Tue, 4 Apr 2023 15:28:09 +0200 Subject: [PATCH] Rewrite quant2_int8_nlp_comparison test (#51995) * Correct ernie int8 test to use new QAT process * Add comment * Fix code style * Fix string formatting * Fix cmake files --------- Co-authored-by: wozna --- test/quantization/CMakeLists.txt | 3 +- .../quant2_int8_nlp_comparison.py | 217 ++++++++---------- 2 files changed, 102 insertions(+), 118 deletions(-) diff --git a/test/quantization/CMakeLists.txt b/test/quantization/CMakeLists.txt index 492b4c60d92..eb17fac27e0 100644 --- a/test/quantization/CMakeLists.txt +++ b/test/quantization/CMakeLists.txt @@ -393,8 +393,7 @@ if(LINUX AND WITH_MKLDNN) set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float") download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE} 114f38804a3ef8c45e7259e68bbd838b) - set(QUANT2_ERNIE_OPS_TO_QUANTIZE - "fc,reshape2,transpose2,matmul,elementwise_add,slice") + set(QUANT2_ERNIE_OPS_TO_QUANTIZE "fused_matmul,matmul,matmul_v2,slice") inference_quant2_int8_nlp_test( test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} diff --git a/test/quantization/quant2_int8_nlp_comparison.py b/test/quantization/quant2_int8_nlp_comparison.py index 459046d1667..e0ee817e92c 100644 --- a/test/quantization/quant2_int8_nlp_comparison.py +++ b/test/quantization/quant2_int8_nlp_comparison.py @@ -14,7 +14,6 @@ import argparse import logging -import os import sys import time import unittest @@ -22,9 +21,8 @@ import unittest import numpy as np import paddle -from paddle.fluid.framework import IrGraph -from paddle.framework import core -from paddle.static.quantization import Quant2Int8MkldnnPass +from paddle import fluid +from paddle.inference import Config, create_predictor paddle.enable_static() @@ -134,8 +132,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): buffer_np = np.array(buffer_i).astype("int64") buffer_np.shape = tuple(shape_np) buffers.append(buffer_np) - label = labels_lines[i] - yield buffers[0], buffers[1], int(label) + yield buffers[0], buffers[1], int(labels_lines[i]) return reader @@ -143,12 +140,30 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): total = len(batch_output) assert total > 0, "The batch output is empty." correct = 0 - for n, output in enumerate(batch_output[0]): + for n, output in enumerate(batch_output): max_idx = np.where(output == output.max()) - if max_idx == labels[n]: + if max_idx[0] == labels[n]: correct += 1 return correct + def set_config( + self, + model_path, + target='quant', + ): + config = Config(model_path) + config.disable_gpu() + config.switch_specify_input_names(True) + config.switch_ir_optim(True) + config.switch_use_feed_fetch_ops(True) + config.enable_mkldnn() + if target == 'int8': + config.enable_mkldnn_int8(self._quantized_ops) + config.delete_pass( + "constant_folding_pass" + ) # same reason as in analyzer_ernie_int8_tester.cc + return config + def _predict( self, test_reader=None, @@ -156,119 +171,89 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): batch_size=1, batch_num=1, skip_batch_num=0, - target='quant', + target='fp32', ): assert target in ['quant', 'int8', 'fp32'] - place = paddle.CPUPlace() - exe = paddle.static.Executor(place) - inference_scope = paddle.static.global_scope() - with paddle.static.scope_guard(inference_scope): - if os.path.exists(os.path.join(model_path, '__model__')): - [ - inference_program, - feed_target_names, - fetch_targets, - ] = paddle.fluid.io.load_inference_model(model_path, exe) - else: - [ - inference_program, - feed_target_names, - fetch_targets, - ] = paddle.static.load_inference_model( - model_path, - exe, - model_filename='model', - params_filename='params', - ) - - graph = IrGraph(core.Graph(inference_program.desc), for_test=True) - if self._debug: - graph.draw('.', 'quant_orig', graph.all_op_nodes()) - if target != 'quant': - quant_transform_pass = Quant2Int8MkldnnPass( - self._quantized_ops, - _op_ids_to_skip=self._op_ids_to_skip, - _scope=inference_scope, - _place=place, - _core=core, - _debug=self._debug, - ) - if target == 'int8': - graph = quant_transform_pass.apply(graph) - else: # target == fp32 - graph = quant_transform_pass.prepare_and_optimize_fp32( - graph - ) - - inference_program = graph.to_program() - - total_correct = 0 - total_samples = 0 - batch_times = [] - ppses = [] # predictions per second - iters = 0 - infer_start_time = time.time() - for data in test_reader(): - if batch_num > 0 and iters >= batch_num: - break - if iters == skip_batch_num: - total_samples = 0 - infer_start_time = time.time() - input0 = np.array([x[0] for x in data]).astype('int64') - input1 = np.array([x[1] for x in data]).astype('int64') - labels = np.array([x[2] for x in data]).astype('int64') - - start = time.time() - out = exe.run( - inference_program, - feed={ - feed_target_names[0]: input0, - feed_target_names[1]: input1, - }, - fetch_list=fetch_targets, - ) - batch_time = (time.time() - start) * 1000 # in miliseconds - batch_times.append(batch_time) - batch_correct = self._get_batch_correct(out, labels) - batch_len = len(data) - total_samples += batch_len - total_correct += batch_correct - batch_acc = float(batch_correct) / float(batch_len) - pps = batch_len / batch_time * 1000 - ppses.append(pps) - latency = batch_time / batch_len - iters += 1 - appx = ' (warm-up)' if iters <= skip_batch_num else '' - _logger.info( - 'batch {}{}, acc: {:.4f}, latency: {:.4f} ms, predictions per sec: {:.2f}'.format( - iters, appx, batch_acc, latency, pps - ) - ) - - # Postprocess benchmark data - infer_total_time = time.time() - infer_start_time - batch_latencies = batch_times[skip_batch_num:] - batch_latency_avg = np.average(batch_latencies) - latency_avg = batch_latency_avg / batch_size - ppses = ppses[skip_batch_num:] - pps_avg = np.average(ppses) - acc_avg = float(np.sum(total_correct)) / float(total_samples) - _logger.info(f'Total inference run time: {infer_total_time:.2f} s') - - return acc_avg, pps_avg, latency_avg + print(f"target: {target}, model path: {model_path}") + + config = self.set_config( + model_path, + target, + ) + predictor = create_predictor(config) + + input_names = predictor.get_input_names() + output_names = predictor.get_output_names() + + total_correct = 0 + total_samples = 0 + batch_times = [] + ppses = [] # predictions per second + iters = 0 + infer_start_time = time.time() + + for data in test_reader(): + if batch_num > 0 and iters >= batch_num: + break + if iters == skip_batch_num: + total_samples = 0 + infer_start_time = time.time() + # check data + inputs = [] + inputs.append(np.array([x[0] for x in data])) + inputs.append(np.array([x[1] for x in data])) + labels = np.array([x[2] for x in data]) + + for i, name in enumerate(input_names): + input_tensor = predictor.get_input_handle(name) + input_tensor.reshape(inputs[i].shape) + input_tensor.copy_from_cpu(inputs[i].copy()) + + start = time.time() + predictor.run() + batch_time = (time.time() - start) * 1000 # in miliseconds + + out = [] + out = predictor.get_output_handle(output_names[0]).copy_to_cpu() + batch_times.append(batch_time) + batch_correct = self._get_batch_correct(out, labels) + batch_len = len(labels) + total_samples += batch_len + total_correct += batch_correct + batch_acc = float(batch_correct) / float(batch_len) + pps = batch_len / batch_time * 1000 + ppses.append(pps) + latency = batch_time / batch_len + iters += 1 + appx = ' (warm-up)' if iters <= skip_batch_num else '' + _logger.info( + f'batch {iters}{appx}, acc: {batch_acc:.4f}, latency: {latency:.4f} ms, predictions per sec: {pps:.2f}' + ) + # Postprocess benchmark data + infer_total_time = time.time() - infer_start_time + batch_latencies = batch_times[skip_batch_num:] + batch_latency_avg = np.average(batch_latencies) + latency_avg = batch_latency_avg / batch_size + ppses = ppses[skip_batch_num:] + pps_avg = np.average(ppses) + acc_avg = float(np.sum(total_correct)) / float(total_samples) + _logger.info(f'Total inference run time: {infer_total_time:.2f} s') + + return acc_avg, pps_avg, latency_avg def _print_performance(self, title, pps, lat): _logger.info( - '{}: avg predictions per sec: {:.2f}, avg latency: {:.4f} ms'.format( - title, pps, lat - ) + f'{title}: avg predictions per sec: {pps:.2f}, avg latency: {lat:.4f} ms' ) def _print_accuracy(self, title, acc): _logger.info(f'{title}: avg accuracy: {acc:.6f}') - def _summarize_performance(self, int8_pps, int8_lat, fp32_pps, fp32_lat): + def _summarize_performance( + self, quant_pps, quant_lat, int8_pps, int8_lat, fp32_pps, fp32_lat + ): _logger.info('--- Performance summary ---') + self._print_performance('QUANT', quant_pps, quant_lat) self._print_performance('INT8', int8_pps, int8_lat) if fp32_lat >= 0: self._print_performance('FP32', fp32_pps, fp32_lat) @@ -282,9 +267,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): def _compare_accuracy(self, threshold, quant_acc, int8_acc): _logger.info( - 'Accepted accuracy drop threshold: {}. (condition: (Quant_acc - INT8_acc) <= threshold)'.format( - threshold - ) + f'Accepted accuracy drop threshold: {threshold}. (condition: (Quant_acc - INT8_acc) <= threshold)' ) # Random outputs give accuracy about 0.33, we assume valid accuracy to be at least 0.5 assert quant_acc > 0.5 @@ -298,7 +281,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): return set(map(int, string.split(','))) def test_graph_transformation(self): - if not core.is_compiled_with_mkldnn(): + if not fluid.core.is_compiled_with_mkldnn(): return quant_model_path = test_case_args.quant_model @@ -411,8 +394,10 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): self._print_performance('FP32', fp32_pps, fp32_lat) self._print_accuracy('FP32', fp32_acc) - if {'int8', 'fp32'}.issubset(self._targets): - self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat) + if {'int8', 'quant', 'fp32'}.issubset(self._targets): + self._summarize_performance( + quant_pps, quant_lat, int8_pps, int8_lat, fp32_pps, fp32_lat + ) if {'int8', 'quant'}.issubset(self._targets): self._summarize_accuracy(quant_acc, int8_acc, fp32_acc) self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc) -- GitLab