Unverified · Commit 7ee31e72 authored by Zuza Gawrysiak, committed by GitHub

Rewrite quant2_int8_nlp_comparison test (#51995)

* Correct ernie int8 test to use new QAT process

* Add comment

* Fix code style

* Fix string formatting

* Fix cmake files

---------
Co-authored-by: wozna <joanna.wozna@intel.com>
Parent: b0dbf9fe
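For orientation before the diff: the test previously loaded the inference program, built an `IrGraph`, and applied `Quant2Int8MkldnnPass` by hand; after this commit it drives the model through the `paddle.inference` predictor API, enabling oneDNN INT8 via `Config.enable_mkldnn_int8`. Below is a minimal sketch of that flow, assuming a hypothetical model directory and dummy int64 inputs (the real test takes paths and the op set from CLI arguments and feeds Ernie data from `NLP_DATA_PATH`):

```python
import numpy as np
from paddle.inference import Config, create_predictor

# Hypothetical model directory and op set, for illustration only.
config = Config("./Ernie_qat/float")
config.disable_gpu()
config.switch_ir_optim(True)
config.enable_mkldnn()
config.enable_mkldnn_int8({"fused_matmul", "matmul", "matmul_v2", "slice"})

predictor = create_predictor(config)
for name in predictor.get_input_names():
    handle = predictor.get_input_handle(name)
    dummy = np.zeros((1, 128), dtype=np.int64)  # placeholder input ids/shape
    handle.reshape(dummy.shape)
    handle.copy_from_cpu(dummy)
predictor.run()
output_name = predictor.get_output_names()[0]
result = predictor.get_output_handle(output_name).copy_to_cpu()
```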
@@ -393,8 +393,7 @@ if(LINUX AND WITH_MKLDNN)
   set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float")
   download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE}
                             114f38804a3ef8c45e7259e68bbd838b)
-  set(QUANT2_ERNIE_OPS_TO_QUANTIZE
-      "fc,reshape2,transpose2,matmul,elementwise_add,slice")
+  set(QUANT2_ERNIE_OPS_TO_QUANTIZE "fused_matmul,matmul,matmul_v2,slice")
   inference_quant2_int8_nlp_test(
     test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float
     ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH}
...
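The op list above reaches the Python test as a comma-separated CLI argument and ultimately becomes the set handed to `config.enable_mkldnn_int8`. A plausible parsing helper (the name and wiring are assumptions, not the test's exact code):

```python
def parse_ops_to_quantize(arg: str) -> set:
    # "fused_matmul,matmul,matmul_v2,slice" -> {"fused_matmul", "matmul", ...}
    return {op.strip() for op in arg.split(',') if op.strip()}
```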
@@ -14,7 +14,6 @@
 import argparse
 import logging
-import os
 import sys
 import time
 import unittest
@@ -22,9 +21,8 @@ import unittest
 import numpy as np

 import paddle
-from paddle.fluid.framework import IrGraph
-from paddle.framework import core
-from paddle.static.quantization import Quant2Int8MkldnnPass
+from paddle import fluid
+from paddle.inference import Config, create_predictor

 paddle.enable_static()
@@ -134,8 +132,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
                     buffer_np = np.array(buffer_i).astype("int64")
                     buffer_np.shape = tuple(shape_np)
                     buffers.append(buffer_np)
-                label = labels_lines[i]
-                yield buffers[0], buffers[1], int(label)
+                yield buffers[0], buffers[1], int(labels_lines[i])

         return reader
@@ -143,12 +140,30 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         total = len(batch_output)
         assert total > 0, "The batch output is empty."
         correct = 0
-        for n, output in enumerate(batch_output[0]):
+        for n, output in enumerate(batch_output):
             max_idx = np.where(output == output.max())
-            if max_idx == labels[n]:
+            if max_idx[0] == labels[n]:
                 correct += 1
         return correct

+    def set_config(
+        self,
+        model_path,
+        target='quant',
+    ):
+        config = Config(model_path)
+        config.disable_gpu()
+        config.switch_specify_input_names(True)
+        config.switch_ir_optim(True)
+        config.switch_use_feed_fetch_ops(True)
+        config.enable_mkldnn()
+        if target == 'int8':
+            config.enable_mkldnn_int8(self._quantized_ops)
+            config.delete_pass(
+                "constant_folding_pass"
+            )  # same reason as in analyzer_ernie_int8_tester.cc
+        return config
+
     def _predict(
         self,
         test_reader=None,
@@ -156,119 +171,89 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         batch_size=1,
         batch_num=1,
         skip_batch_num=0,
-        target='quant',
+        target='fp32',
     ):
         assert target in ['quant', 'int8', 'fp32']
-        place = paddle.CPUPlace()
-        exe = paddle.static.Executor(place)
-        inference_scope = paddle.static.global_scope()
-        with paddle.static.scope_guard(inference_scope):
-            if os.path.exists(os.path.join(model_path, '__model__')):
-                [
-                    inference_program,
-                    feed_target_names,
-                    fetch_targets,
-                ] = paddle.fluid.io.load_inference_model(model_path, exe)
-            else:
-                [
-                    inference_program,
-                    feed_target_names,
-                    fetch_targets,
-                ] = paddle.static.load_inference_model(
-                    model_path,
-                    exe,
-                    model_filename='model',
-                    params_filename='params',
-                )
-
-            graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
-            if self._debug:
-                graph.draw('.', 'quant_orig', graph.all_op_nodes())
-            if target != 'quant':
-                quant_transform_pass = Quant2Int8MkldnnPass(
-                    self._quantized_ops,
-                    _op_ids_to_skip=self._op_ids_to_skip,
-                    _scope=inference_scope,
-                    _place=place,
-                    _core=core,
-                    _debug=self._debug,
-                )
-                if target == 'int8':
-                    graph = quant_transform_pass.apply(graph)
-                else:  # target == fp32
-                    graph = quant_transform_pass.prepare_and_optimize_fp32(
-                        graph
-                    )
-
-            inference_program = graph.to_program()
-
-            total_correct = 0
-            total_samples = 0
-            batch_times = []
-            ppses = []  # predictions per second
-            iters = 0
-            infer_start_time = time.time()
-            for data in test_reader():
-                if batch_num > 0 and iters >= batch_num:
-                    break
-                if iters == skip_batch_num:
-                    total_samples = 0
-                    infer_start_time = time.time()
-                input0 = np.array([x[0] for x in data]).astype('int64')
-                input1 = np.array([x[1] for x in data]).astype('int64')
-                labels = np.array([x[2] for x in data]).astype('int64')
-
-                start = time.time()
-                out = exe.run(
-                    inference_program,
-                    feed={
-                        feed_target_names[0]: input0,
-                        feed_target_names[1]: input1,
-                    },
-                    fetch_list=fetch_targets,
-                )
-                batch_time = (time.time() - start) * 1000  # in milliseconds
-                batch_times.append(batch_time)
-                batch_correct = self._get_batch_correct(out, labels)
-                batch_len = len(data)
-                total_samples += batch_len
-                total_correct += batch_correct
-                batch_acc = float(batch_correct) / float(batch_len)
-                pps = batch_len / batch_time * 1000
-                ppses.append(pps)
-                latency = batch_time / batch_len
-                iters += 1
-                appx = ' (warm-up)' if iters <= skip_batch_num else ''
-                _logger.info(
-                    'batch {}{}, acc: {:.4f}, latency: {:.4f} ms, predictions per sec: {:.2f}'.format(
-                        iters, appx, batch_acc, latency, pps
-                    )
-                )
-
-            # Postprocess benchmark data
-            infer_total_time = time.time() - infer_start_time
-            batch_latencies = batch_times[skip_batch_num:]
-            batch_latency_avg = np.average(batch_latencies)
-            latency_avg = batch_latency_avg / batch_size
-            ppses = ppses[skip_batch_num:]
-            pps_avg = np.average(ppses)
-            acc_avg = float(np.sum(total_correct)) / float(total_samples)
-            _logger.info(f'Total inference run time: {infer_total_time:.2f} s')
-
-            return acc_avg, pps_avg, latency_avg
+        print(f"target: {target}, model path: {model_path}")
+
+        config = self.set_config(
+            model_path,
+            target,
+        )
+        predictor = create_predictor(config)
+
+        input_names = predictor.get_input_names()
+        output_names = predictor.get_output_names()
+
+        total_correct = 0
+        total_samples = 0
+        batch_times = []
+        ppses = []  # predictions per second
+        iters = 0
+        infer_start_time = time.time()
+
+        for data in test_reader():
+            if batch_num > 0 and iters >= batch_num:
+                break
+            if iters == skip_batch_num:
+                total_samples = 0
+                infer_start_time = time.time()
+            # check data
+            inputs = []
+            inputs.append(np.array([x[0] for x in data]))
+            inputs.append(np.array([x[1] for x in data]))
+            labels = np.array([x[2] for x in data])
+
+            for i, name in enumerate(input_names):
+                input_tensor = predictor.get_input_handle(name)
+                input_tensor.reshape(inputs[i].shape)
+                input_tensor.copy_from_cpu(inputs[i].copy())
+
+            start = time.time()
+            predictor.run()
+            batch_time = (time.time() - start) * 1000  # in milliseconds
+
+            out = predictor.get_output_handle(output_names[0]).copy_to_cpu()
+            batch_times.append(batch_time)
+            batch_correct = self._get_batch_correct(out, labels)
+            batch_len = len(labels)
+            total_samples += batch_len
+            total_correct += batch_correct
+            batch_acc = float(batch_correct) / float(batch_len)
+            pps = batch_len / batch_time * 1000
+            ppses.append(pps)
+            latency = batch_time / batch_len
+            iters += 1
+            appx = ' (warm-up)' if iters <= skip_batch_num else ''
+            _logger.info(
+                f'batch {iters}{appx}, acc: {batch_acc:.4f}, latency: {latency:.4f} ms, predictions per sec: {pps:.2f}'
+            )
+
+        # Postprocess benchmark data
+        infer_total_time = time.time() - infer_start_time
+        batch_latencies = batch_times[skip_batch_num:]
+        batch_latency_avg = np.average(batch_latencies)
+        latency_avg = batch_latency_avg / batch_size
+        ppses = ppses[skip_batch_num:]
+        pps_avg = np.average(ppses)
+        acc_avg = float(np.sum(total_correct)) / float(total_samples)
+        _logger.info(f'Total inference run time: {infer_total_time:.2f} s')
+
+        return acc_avg, pps_avg, latency_avg
     def _print_performance(self, title, pps, lat):
         _logger.info(
-            '{}: avg predictions per sec: {:.2f}, avg latency: {:.4f} ms'.format(
-                title, pps, lat
-            )
+            f'{title}: avg predictions per sec: {pps:.2f}, avg latency: {lat:.4f} ms'
         )

     def _print_accuracy(self, title, acc):
         _logger.info(f'{title}: avg accuracy: {acc:.6f}')

-    def _summarize_performance(self, int8_pps, int8_lat, fp32_pps, fp32_lat):
+    def _summarize_performance(
+        self, quant_pps, quant_lat, int8_pps, int8_lat, fp32_pps, fp32_lat
+    ):
         _logger.info('--- Performance summary ---')
+        self._print_performance('QUANT', quant_pps, quant_lat)
         self._print_performance('INT8', int8_pps, int8_lat)
         if fp32_lat >= 0:
             self._print_performance('FP32', fp32_pps, fp32_lat)
@@ -282,9 +267,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
     def _compare_accuracy(self, threshold, quant_acc, int8_acc):
         _logger.info(
-            'Accepted accuracy drop threshold: {}. (condition: (Quant_acc - INT8_acc) <= threshold)'.format(
-                threshold
-            )
+            f'Accepted accuracy drop threshold: {threshold}. (condition: (Quant_acc - INT8_acc) <= threshold)'
         )
         # Random outputs give accuracy about 0.33, we assume valid accuracy to be at least 0.5
         assert quant_acc > 0.5
@@ -298,7 +281,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         return set(map(int, string.split(',')))

     def test_graph_transformation(self):
-        if not core.is_compiled_with_mkldnn():
+        if not fluid.core.is_compiled_with_mkldnn():
             return

         quant_model_path = test_case_args.quant_model
@@ -411,8 +394,10 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
             self._print_performance('FP32', fp32_pps, fp32_lat)
             self._print_accuracy('FP32', fp32_acc)

-        if {'int8', 'fp32'}.issubset(self._targets):
-            self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat)
+        if {'int8', 'quant', 'fp32'}.issubset(self._targets):
+            self._summarize_performance(
+                quant_pps, quant_lat, int8_pps, int8_lat, fp32_pps, fp32_lat
+            )
         if {'int8', 'quant'}.issubset(self._targets):
             self._summarize_accuracy(quant_acc, int8_acc, fp32_acc)
             self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc)
...
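The acceptance criterion is unchanged by the rewrite: the INT8 model may trail the Quant model by at most the given threshold, and a valid run must clearly beat the ~0.33 random baseline noted in the test. Restated standalone (the sample numbers are made up):

```python
def accuracy_check(quant_acc: float, int8_acc: float, threshold: float) -> None:
    # Random outputs score ~0.33, so valid accuracy must exceed 0.5 (per the test).
    assert quant_acc > 0.5
    # Accepted drop condition from the test: (Quant_acc - INT8_acc) <= threshold.
    assert quant_acc - int8_acc <= threshold

accuracy_check(quant_acc=0.80, int8_acc=0.79, threshold=0.02)  # passes
```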