Unverified commit 7ee31e72, authored by Zuza Gawrysiak and committed by GitHub

Rewrite quant2_int8_nlp_comparison test (#51995)

* Correct ernie int8 test to use new QAT process

* Add comment

* Fix code style

* Fix string formatting

* Fix cmake files

---------
Co-authored-by: wozna <joanna.wozna@intel.com>
Parent: b0dbf9fe
@@ -393,8 +393,7 @@ if(LINUX AND WITH_MKLDNN)
   set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float")
   download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE}
                             114f38804a3ef8c45e7259e68bbd838b)
-  set(QUANT2_ERNIE_OPS_TO_QUANTIZE
-      "fc,reshape2,transpose2,matmul,elementwise_add,slice")
+  set(QUANT2_ERNIE_OPS_TO_QUANTIZE "fused_matmul,matmul,matmul_v2,slice")
   inference_quant2_int8_nlp_test(
     test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float
     ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH}
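Note: the ops-to-quantize list changes because the rewritten test quantizes the fused/v2 matmul kernels directly instead of the old fc/reshape2/transpose2 set. The test script receives this value as a comma-separated string and turns it into a set of op names, mirroring the set(...split(',')) helper visible later in this diff. A minimal sketch of that parsing (variable names here are illustrative, not from the PR):

    # Sketch: convert the CMake-supplied CSV string into the set of op
    # names that Config.enable_mkldnn_int8() expects.
    ops_csv = "fused_matmul,matmul,matmul_v2,slice"
    quantized_ops = set(ops_csv.split(','))
    print(quantized_ops)  # {'fused_matmul', 'matmul', 'matmul_v2', 'slice'}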
@@ -14,7 +14,6 @@
 import argparse
 import logging
 import os
-import sys
 import time
 import unittest
@@ -22,9 +21,8 @@ import unittest
 import numpy as np
 import paddle
-from paddle.fluid.framework import IrGraph
-from paddle.framework import core
-from paddle.static.quantization import Quant2Int8MkldnnPass
+from paddle import fluid
+from paddle.inference import Config, create_predictor

 paddle.enable_static()
@@ -134,8 +132,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
                         buffer_np = np.array(buffer_i).astype("int64")
                         buffer_np.shape = tuple(shape_np)
                         buffers.append(buffer_np)
-                    label = labels_lines[i]
-                    yield buffers[0], buffers[1], int(label)
+                    yield buffers[0], buffers[1], int(labels_lines[i])

         return reader
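Note: each sample the reader yields is a pair of int64 input buffers (token ids and segment ids) plus an integer label; batching happens later in _predict. A stand-in reader with made-up shapes, useful for dry-running the benchmark loop without the NLP data archive, might look like:

    import numpy as np

    def make_fake_reader(num_samples=8, seq_len=128):
        # Yields (input_ids, segment_ids, label) triples shaped like the
        # real data reader; shapes and label range are illustrative only.
        def reader():
            for i in range(num_samples):
                ids = np.zeros((seq_len, 1), dtype='int64')
                segments = np.zeros((seq_len, 1), dtype='int64')
                yield ids, segments, i % 3
        return reader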
@@ -143,12 +140,30 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         total = len(batch_output)
         assert total > 0, "The batch output is empty."
         correct = 0
-        for n, output in enumerate(batch_output[0]):
+        for n, output in enumerate(batch_output):
             max_idx = np.where(output == output.max())
-            if max_idx == labels[n]:
+            if max_idx[0] == labels[n]:
                 correct += 1
         return correct
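Note: the two fixed lines reflect that the predictor now returns a single batch array rather than a list of fetch results, and that np.where returns a tuple of index arrays, so max_idx[0] extracts the actual indices instead of comparing the whole tuple against the label. A quick illustration of the second point:

    import numpy as np

    output = np.array([0.1, 0.7, 0.2])
    max_idx = np.where(output == output.max())
    print(max_idx)           # (array([1]),) -- a tuple of index arrays
    print(max_idx[0] == 1)   # [ True] -- truthy when the prediction matches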
+    def set_config(
+        self,
+        model_path,
+        target='quant',
+    ):
+        config = Config(model_path)
+        config.disable_gpu()
+        config.switch_specify_input_names(True)
+        config.switch_ir_optim(True)
+        config.switch_use_feed_fetch_ops(True)
+        config.enable_mkldnn()
+        if target == 'int8':
+            config.enable_mkldnn_int8(self._quantized_ops)
+            config.delete_pass(
+                "constant_folding_pass"
+            )  # same reason as in analyzer_ernie_int8_tester.cc
+        return config
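Note: set_config builds one paddle.inference.Config per target — MKL-DNN is always enabled, and only the 'int8' target additionally switches on INT8 kernels for self._quantized_ops and deletes constant_folding_pass. A minimal standalone sketch of the same flow, with an illustrative model path (not from this PR):

    from paddle.inference import Config, create_predictor

    config = Config('./ernie_model')  # illustrative path
    config.disable_gpu()
    config.enable_mkldnn()
    config.enable_mkldnn_int8({'fused_matmul', 'matmul', 'matmul_v2', 'slice'})
    config.delete_pass('constant_folding_pass')
    predictor = create_predictor(config)  # ready for run() via input handles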
     def _predict(
         self,
         test_reader=None,
@@ -156,119 +171,89 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         batch_size=1,
         batch_num=1,
         skip_batch_num=0,
-        target='quant',
+        target='fp32',
     ):
         assert target in ['quant', 'int8', 'fp32']
-        place = paddle.CPUPlace()
-        exe = paddle.static.Executor(place)
-        inference_scope = paddle.static.global_scope()
-        with paddle.static.scope_guard(inference_scope):
-            if os.path.exists(os.path.join(model_path, '__model__')):
-                [
-                    inference_program,
-                    feed_target_names,
-                    fetch_targets,
-                ] = paddle.fluid.io.load_inference_model(model_path, exe)
-            else:
-                [
-                    inference_program,
-                    feed_target_names,
-                    fetch_targets,
-                ] = paddle.static.load_inference_model(
-                    model_path,
-                    exe,
-                    model_filename='model',
-                    params_filename='params',
-                )
-
-            graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
-            if self._debug:
-                graph.draw('.', 'quant_orig', graph.all_op_nodes())
-            if target != 'quant':
-                quant_transform_pass = Quant2Int8MkldnnPass(
-                    self._quantized_ops,
-                    _op_ids_to_skip=self._op_ids_to_skip,
-                    _scope=inference_scope,
-                    _place=place,
-                    _core=core,
-                    _debug=self._debug,
-                )
-                if target == 'int8':
-                    graph = quant_transform_pass.apply(graph)
-                else:  # target == fp32
-                    graph = quant_transform_pass.prepare_and_optimize_fp32(
-                        graph
-                    )
-
-            inference_program = graph.to_program()
-
-            total_correct = 0
-            total_samples = 0
-            batch_times = []
-            ppses = []  # predictions per second
-            iters = 0
-            infer_start_time = time.time()
-            for data in test_reader():
-                if batch_num > 0 and iters >= batch_num:
-                    break
-                if iters == skip_batch_num:
-                    total_samples = 0
-                    infer_start_time = time.time()
-                input0 = np.array([x[0] for x in data]).astype('int64')
-                input1 = np.array([x[1] for x in data]).astype('int64')
-                labels = np.array([x[2] for x in data]).astype('int64')
-
-                start = time.time()
-                out = exe.run(
-                    inference_program,
-                    feed={
-                        feed_target_names[0]: input0,
-                        feed_target_names[1]: input1,
-                    },
-                    fetch_list=fetch_targets,
-                )
-                batch_time = (time.time() - start) * 1000  # in milliseconds
-                batch_times.append(batch_time)
-                batch_correct = self._get_batch_correct(out, labels)
-                batch_len = len(data)
-                total_samples += batch_len
-                total_correct += batch_correct
-                batch_acc = float(batch_correct) / float(batch_len)
-                pps = batch_len / batch_time * 1000
-                ppses.append(pps)
-                latency = batch_time / batch_len
-                iters += 1
-                appx = ' (warm-up)' if iters <= skip_batch_num else ''
-                _logger.info(
-                    'batch {}{}, acc: {:.4f}, latency: {:.4f} ms, predictions per sec: {:.2f}'.format(
-                        iters, appx, batch_acc, latency, pps
-                    )
-                )
-
-            # Postprocess benchmark data
-            infer_total_time = time.time() - infer_start_time
-            batch_latencies = batch_times[skip_batch_num:]
-            batch_latency_avg = np.average(batch_latencies)
-            latency_avg = batch_latency_avg / batch_size
-            ppses = ppses[skip_batch_num:]
-            pps_avg = np.average(ppses)
-            acc_avg = float(np.sum(total_correct)) / float(total_samples)
-            _logger.info(f'Total inference run time: {infer_total_time:.2f} s')
-
-            return acc_avg, pps_avg, latency_avg
print(f"target: {target}, model path: {model_path}")
config = self.set_config(
model_path,
target,
)
predictor = create_predictor(config)
input_names = predictor.get_input_names()
output_names = predictor.get_output_names()
total_correct = 0
total_samples = 0
batch_times = []
ppses = [] # predictions per second
iters = 0
infer_start_time = time.time()
for data in test_reader():
if batch_num > 0 and iters >= batch_num:
break
if iters == skip_batch_num:
total_samples = 0
infer_start_time = time.time()
# check data
inputs = []
inputs.append(np.array([x[0] for x in data]))
inputs.append(np.array([x[1] for x in data]))
labels = np.array([x[2] for x in data])
for i, name in enumerate(input_names):
input_tensor = predictor.get_input_handle(name)
input_tensor.reshape(inputs[i].shape)
input_tensor.copy_from_cpu(inputs[i].copy())
start = time.time()
predictor.run()
batch_time = (time.time() - start) * 1000 # in miliseconds
out = []
out = predictor.get_output_handle(output_names[0]).copy_to_cpu()
batch_times.append(batch_time)
batch_correct = self._get_batch_correct(out, labels)
batch_len = len(labels)
total_samples += batch_len
total_correct += batch_correct
batch_acc = float(batch_correct) / float(batch_len)
pps = batch_len / batch_time * 1000
ppses.append(pps)
latency = batch_time / batch_len
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
_logger.info(
f'batch {iters}{appx}, acc: {batch_acc:.4f}, latency: {latency:.4f} ms, predictions per sec: {pps:.2f}'
)
# Postprocess benchmark data
infer_total_time = time.time() - infer_start_time
batch_latencies = batch_times[skip_batch_num:]
batch_latency_avg = np.average(batch_latencies)
latency_avg = batch_latency_avg / batch_size
ppses = ppses[skip_batch_num:]
pps_avg = np.average(ppses)
acc_avg = float(np.sum(total_correct)) / float(total_samples)
_logger.info(f'Total inference run time: {infer_total_time:.2f} s')
return acc_avg, pps_avg, latency_avg
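Note: the warm-up handling resets total_samples and the timer once iters reaches skip_batch_num, and per-batch throughput is derived from wall time. A worked example of the arithmetic, with illustrative numbers:

    batch_len = 50        # samples in the batch
    batch_time = 125.0    # ms spent in predictor.run()
    pps = batch_len / batch_time * 1000   # 400.0 predictions per second
    latency = batch_time / batch_len      # 2.5 ms per sample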
     def _print_performance(self, title, pps, lat):
         _logger.info(
-            '{}: avg predictions per sec: {:.2f}, avg latency: {:.4f} ms'.format(
-                title, pps, lat
-            )
+            f'{title}: avg predictions per sec: {pps:.2f}, avg latency: {lat:.4f} ms'
         )

     def _print_accuracy(self, title, acc):
         _logger.info(f'{title}: avg accuracy: {acc:.6f}')

-    def _summarize_performance(self, int8_pps, int8_lat, fp32_pps, fp32_lat):
+    def _summarize_performance(
+        self, quant_pps, quant_lat, int8_pps, int8_lat, fp32_pps, fp32_lat
+    ):
         _logger.info('--- Performance summary ---')
+        self._print_performance('QUANT', quant_pps, quant_lat)
         self._print_performance('INT8', int8_pps, int8_lat)
         if fp32_lat >= 0:
             self._print_performance('FP32', fp32_pps, fp32_lat)
@@ -282,9 +267,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
     def _compare_accuracy(self, threshold, quant_acc, int8_acc):
         _logger.info(
-            'Accepted accuracy drop threshold: {}. (condition: (Quant_acc - INT8_acc) <= threshold)'.format(
-                threshold
-            )
+            f'Accepted accuracy drop threshold: {threshold}. (condition: (Quant_acc - INT8_acc) <= threshold)'
         )
         # Random outputs give an accuracy of about 0.33; we require at least 0.5 for a valid run
         assert quant_acc > 0.5
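Note: the accepted condition only bounds the accuracy drop from the Quant model to the INT8 model, as stated in the log line above. With illustrative numbers:

    quant_acc, int8_acc, threshold = 0.80, 0.79, 0.01
    assert quant_acc > 0.5                      # above the ~0.33 random baseline
    assert quant_acc - int8_acc <= threshold    # INT8 drop stays within threshold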
@@ -298,7 +281,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         return set(map(int, string.split(',')))

     def test_graph_transformation(self):
-        if not core.is_compiled_with_mkldnn():
+        if not fluid.core.is_compiled_with_mkldnn():
             return

         quant_model_path = test_case_args.quant_model
@@ -411,8 +394,10 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
             self._print_performance('FP32', fp32_pps, fp32_lat)
             self._print_accuracy('FP32', fp32_acc)

-        if {'int8', 'fp32'}.issubset(self._targets):
-            self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat)
+        if {'int8', 'quant', 'fp32'}.issubset(self._targets):
+            self._summarize_performance(
+                quant_pps, quant_lat, int8_pps, int8_lat, fp32_pps, fp32_lat
+            )
         if {'int8', 'quant'}.issubset(self._targets):
             self._summarize_accuracy(quant_acc, int8_acc, fp32_acc)
             self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc)