Unverified commit 7ee31e72, authored by Zuza Gawrysiak, committed by GitHub

Rewrite quant2_int8_nlp_comparison test (#51995)

* Correct ernie int8 test to use new QAT process

* Add comment

* Fix code style

* Fix string formatting

* Fix cmake files

---------
Co-authored-by: wozna <joanna.wozna@intel.com>
Parent: b0dbf9fe
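In short: the test no longer rebuilds the graph with IrGraph and Quant2Int8MkldnnPass; it drives the paddle.inference API instead. A new set_config() helper builds a Config per target ('quant', 'int8', 'fp32'), create_predictor() runs the model, and tensors move through get_input_handle()/copy_from_cpu() and get_output_handle()/copy_to_cpu(). The CMake ops-to-quantize list changes to "fused_matmul,matmul,matmul_v2,slice", and the performance summary now reports QUANT alongside INT8 and FP32.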
@@ -393,8 +393,7 @@ if(LINUX AND WITH_MKLDNN)
   set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float")
   download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE}
                             114f38804a3ef8c45e7259e68bbd838b)
-  set(QUANT2_ERNIE_OPS_TO_QUANTIZE
-      "fc,reshape2,transpose2,matmul,elementwise_add,slice")
+  set(QUANT2_ERNIE_OPS_TO_QUANTIZE "fused_matmul,matmul,matmul_v2,slice")
   inference_quant2_int8_nlp_test(
     test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float
     ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH}
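(The trimmed QUANT2_ERNIE_OPS_TO_QUANTIZE list is what the Python test below ultimately receives as self._quantized_ops and hands to config.enable_mkldnn_int8().)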
...
@@ -14,7 +14,6 @@
 import argparse
 import logging
-import os
 import sys
 import time
 import unittest
@@ -22,9 +21,8 @@ import unittest
 import numpy as np

 import paddle
-from paddle.fluid.framework import IrGraph
-from paddle.framework import core
-from paddle.static.quantization import Quant2Int8MkldnnPass
+from paddle import fluid
+from paddle.inference import Config, create_predictor

 paddle.enable_static()
@@ -134,8 +132,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
                     buffer_np = np.array(buffer_i).astype("int64")
                     buffer_np.shape = tuple(shape_np)
                     buffers.append(buffer_np)
-                label = labels_lines[i]
-                yield buffers[0], buffers[1], int(label)
+                yield buffers[0], buffers[1], int(labels_lines[i])

         return reader
@@ -143,12 +140,30 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         total = len(batch_output)
         assert total > 0, "The batch output is empty."
         correct = 0
-        for n, output in enumerate(batch_output[0]):
+        for n, output in enumerate(batch_output):
             max_idx = np.where(output == output.max())
-            if max_idx == labels[n]:
+            if max_idx[0] == labels[n]:
                 correct += 1
         return correct

+    def set_config(
+        self,
+        model_path,
+        target='quant',
+    ):
+        config = Config(model_path)
+        config.disable_gpu()
+        config.switch_specify_input_names(True)
+        config.switch_ir_optim(True)
+        config.switch_use_feed_fetch_ops(True)
+        config.enable_mkldnn()
+        if target == 'int8':
+            config.enable_mkldnn_int8(self._quantized_ops)
+        config.delete_pass(
+            "constant_folding_pass"
+        )  # same reason as in analyzer_ernie_int8_tester.cc
+        return config
+
     def _predict(
         self,
         test_reader=None,
@@ -156,51 +171,19 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         batch_size=1,
         batch_num=1,
         skip_batch_num=0,
-        target='quant',
+        target='fp32',
     ):
         assert target in ['quant', 'int8', 'fp32']
-        place = paddle.CPUPlace()
-        exe = paddle.static.Executor(place)
-        inference_scope = paddle.static.global_scope()
-        with paddle.static.scope_guard(inference_scope):
-            if os.path.exists(os.path.join(model_path, '__model__')):
-                [
-                    inference_program,
-                    feed_target_names,
-                    fetch_targets,
-                ] = paddle.fluid.io.load_inference_model(model_path, exe)
-            else:
-                [
-                    inference_program,
-                    feed_target_names,
-                    fetch_targets,
-                ] = paddle.static.load_inference_model(
-                    model_path,
-                    exe,
-                    model_filename='model',
-                    params_filename='params',
-                )
-
-            graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
-            if self._debug:
-                graph.draw('.', 'quant_orig', graph.all_op_nodes())
-            if target != 'quant':
-                quant_transform_pass = Quant2Int8MkldnnPass(
-                    self._quantized_ops,
-                    _op_ids_to_skip=self._op_ids_to_skip,
-                    _scope=inference_scope,
-                    _place=place,
-                    _core=core,
-                    _debug=self._debug,
-                )
-                if target == 'int8':
-                    graph = quant_transform_pass.apply(graph)
-                else:  # target == fp32
-                    graph = quant_transform_pass.prepare_and_optimize_fp32(
-                        graph
-                    )
-
-            inference_program = graph.to_program()
+        print(f"target: {target}, model path: {model_path}")
+
+        config = self.set_config(
+            model_path,
+            target,
+        )
+        predictor = create_predictor(config)
+
+        input_names = predictor.get_input_names()
+        output_names = predictor.get_output_names()

         total_correct = 0
         total_samples = 0
@@ -208,29 +191,33 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         ppses = []  # predictions per second
         iters = 0
         infer_start_time = time.time()

         for data in test_reader():
             if batch_num > 0 and iters >= batch_num:
                 break
             if iters == skip_batch_num:
                 total_samples = 0
                 infer_start_time = time.time()
-            input0 = np.array([x[0] for x in data]).astype('int64')
-            input1 = np.array([x[1] for x in data]).astype('int64')
-            labels = np.array([x[2] for x in data]).astype('int64')
+            # check data
+            inputs = []
+            inputs.append(np.array([x[0] for x in data]))
+            inputs.append(np.array([x[1] for x in data]))
+            labels = np.array([x[2] for x in data])
+
+            for i, name in enumerate(input_names):
+                input_tensor = predictor.get_input_handle(name)
+                input_tensor.reshape(inputs[i].shape)
+                input_tensor.copy_from_cpu(inputs[i].copy())
+
             start = time.time()
-            out = exe.run(
-                inference_program,
-                feed={
-                    feed_target_names[0]: input0,
-                    feed_target_names[1]: input1,
-                },
-                fetch_list=fetch_targets,
-            )
+            predictor.run()
             batch_time = (time.time() - start) * 1000  # in milliseconds
+
+            out = []
+            out = predictor.get_output_handle(output_names[0]).copy_to_cpu()
             batch_times.append(batch_time)
             batch_correct = self._get_batch_correct(out, labels)
-            batch_len = len(data)
+            batch_len = len(labels)
             total_samples += batch_len
             total_correct += batch_correct
             batch_acc = float(batch_correct) / float(batch_len)
@@ -240,11 +227,8 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
             iters += 1
             appx = ' (warm-up)' if iters <= skip_batch_num else ''
             _logger.info(
-                'batch {}{}, acc: {:.4f}, latency: {:.4f} ms, predictions per sec: {:.2f}'.format(
-                    iters, appx, batch_acc, latency, pps
-                )
+                f'batch {iters}{appx}, acc: {batch_acc:.4f}, latency: {latency:.4f} ms, predictions per sec: {pps:.2f}'
             )

         # Postprocess benchmark data
         infer_total_time = time.time() - infer_start_time
         batch_latencies = batch_times[skip_batch_num:]
@@ -259,16 +243,17 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
     def _print_performance(self, title, pps, lat):
         _logger.info(
-            '{}: avg predictions per sec: {:.2f}, avg latency: {:.4f} ms'.format(
-                title, pps, lat
-            )
+            f'{title}: avg predictions per sec: {pps:.2f}, avg latency: {lat:.4f} ms'
         )

     def _print_accuracy(self, title, acc):
         _logger.info(f'{title}: avg accuracy: {acc:.6f}')

-    def _summarize_performance(self, int8_pps, int8_lat, fp32_pps, fp32_lat):
+    def _summarize_performance(
+        self, quant_pps, quant_lat, int8_pps, int8_lat, fp32_pps, fp32_lat
+    ):
         _logger.info('--- Performance summary ---')
+        self._print_performance('QUANT', quant_pps, quant_lat)
         self._print_performance('INT8', int8_pps, int8_lat)
         if fp32_lat >= 0:
             self._print_performance('FP32', fp32_pps, fp32_lat)
@@ -282,9 +267,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
     def _compare_accuracy(self, threshold, quant_acc, int8_acc):
         _logger.info(
-            'Accepted accuracy drop threshold: {}. (condition: (Quant_acc - INT8_acc) <= threshold)'.format(
-                threshold
-            )
+            f'Accepted accuracy drop threshold: {threshold}. (condition: (Quant_acc - INT8_acc) <= threshold)'
         )
         # Random outputs give accuracy about 0.33, we assume valid accuracy to be at least 0.5
         assert quant_acc > 0.5
@@ -298,7 +281,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         return set(map(int, string.split(',')))

     def test_graph_transformation(self):
-        if not core.is_compiled_with_mkldnn():
+        if not fluid.core.is_compiled_with_mkldnn():
             return

         quant_model_path = test_case_args.quant_model
@@ -411,8 +394,10 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
             self._print_performance('FP32', fp32_pps, fp32_lat)
             self._print_accuracy('FP32', fp32_acc)

-        if {'int8', 'fp32'}.issubset(self._targets):
-            self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat)
+        if {'int8', 'quant', 'fp32'}.issubset(self._targets):
+            self._summarize_performance(
+                quant_pps, quant_lat, int8_pps, int8_lat, fp32_pps, fp32_lat
+            )
         if {'int8', 'quant'}.issubset(self._targets):
             self._summarize_accuracy(quant_acc, int8_acc, fp32_acc)
             self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc)
...
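For reference, the predictor-based flow the rewritten test exercises can be reproduced standalone along the following lines. This is a minimal sketch, not part of the commit: the model directory and input shape are hypothetical placeholders, the quantized-ops set just echoes the CMake list above, and the Config calls mirror set_config() from the diff.

# Minimal sketch of the new paddle.inference flow (not part of the commit).
# MODEL_DIR, INPUT_SHAPE, and QUANTIZED_OPS are hypothetical placeholders.
import numpy as np

from paddle.inference import Config, create_predictor

MODEL_DIR = "/path/to/Ernie_qat/float"  # hypothetical model directory
QUANTIZED_OPS = {"fused_matmul", "matmul", "matmul_v2", "slice"}
INPUT_SHAPE = (1, 128, 1)  # illustrative shape for the two int64 id inputs


def make_config(model_dir, target="quant"):
    # Mirrors set_config() from the rewritten test: CPU + oneDNN, INT8 on demand.
    config = Config(model_dir)
    config.disable_gpu()
    config.switch_specify_input_names(True)
    config.switch_ir_optim(True)
    config.switch_use_feed_fetch_ops(True)
    config.enable_mkldnn()
    if target == "int8":
        config.enable_mkldnn_int8(QUANTIZED_OPS)
    # same reason as in analyzer_ernie_int8_tester.cc
    config.delete_pass("constant_folding_pass")
    return config


def run_once(target="int8"):
    predictor = create_predictor(make_config(MODEL_DIR, target))

    # Feed zeros into every input just to exercise the graph end to end.
    for name in predictor.get_input_names():
        data = np.zeros(INPUT_SHAPE, dtype="int64")
        handle = predictor.get_input_handle(name)
        handle.reshape(data.shape)
        handle.copy_from_cpu(data)

    predictor.run()
    out_name = predictor.get_output_names()[0]
    return predictor.get_output_handle(out_name).copy_to_cpu()


if __name__ == "__main__":
    print(run_once("int8").shape)

Calling run_once with target='quant' and target='int8' against the same model directory is exactly the comparison the test automates, adding per-batch accuracy and latency bookkeeping on top.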