Unverified commit 9ba52165, authored by Wojciech Uss, committed by GitHub

Added an option to use external FP32 model in QAT comparison test (#22858) (#22873)

Parent: 781f8b2f
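The diff below swaps the `--save_model` argument for a new `--fp32_model` option: when the flag is left empty, the FP32 run falls back to the QAT model path, so existing invocations keep working. The standalone sketch below mirrors that fallback logic; the option names and help texts come from the diff, while the `ArgumentParser` setup and the `__main__` demo block are illustrative additions, not part of the test file.

```python
import argparse


def parse_args():
    # Mirrors only the options touched by this commit; the test's other
    # options (--infer_data, --labels, batch settings, ...) are omitted.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--qat_model', type=str, default='', help='A path to a QAT model.')
    parser.add_argument(
        '--fp32_model',
        type=str,
        default='',
        help='A path to an FP32 model. If empty, the QAT model will be used '
        'for FP32 inference.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    assert args.qat_model, 'The QAT model path cannot be empty. Please, use the --qat_model option.'
    # Fallback introduced by this commit: reuse the QAT model for the FP32
    # run when no separate FP32 model is given.
    fp32_model_path = args.fp32_model if args.fp32_model else args.qat_model
    print('QAT model:  {0}'.format(args.qat_model))
    print('FP32 model: {0}'.format(fp32_model_path))
```

Run with only `--qat_model` to get the previous behavior (the QAT model serves as the FP32 baseline), or add `--fp32_model` to benchmark against a separately trained FP32 model.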
@@ -48,9 +48,11 @@ def parse_args():
     parser.add_argument(
         '--qat_model', type=str, default='', help='A path to a QAT model.')
     parser.add_argument(
-        '--save_model',
-        action='store_true',
-        help='If used, the QAT model will be saved after all transformations')
+        '--fp32_model',
+        type=str,
+        default='',
+        help='A path to an FP32 model. If empty, the QAT model will be used for FP32 inference.'
+    )
     parser.add_argument('--infer_data', type=str, default='', help='Data file.')
     parser.add_argument(
         '--labels', type=str, default='', help='File with labels.')
@@ -239,7 +241,10 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
             return
         qat_model_path = test_case_args.qat_model
         assert qat_model_path, 'The QAT model path cannot be empty. Please, use the --qat_model option.'
+        fp32_model_path = test_case_args.fp32_model if test_case_args.fp32_model else qat_model_path
         data_path = test_case_args.infer_data
         assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.'
         labels_path = test_case_args.labels
         batch_size = test_case_args.batch_size
         batch_num = test_case_args.batch_num
@@ -250,6 +255,7 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
         _logger.info('QAT FP32 & INT8 prediction run.')
         _logger.info('QAT model: {0}'.format(qat_model_path))
+        _logger.info('FP32 model: {0}'.format(fp32_model_path))
         _logger.info('Dataset: {0}'.format(data_path))
         _logger.info('Labels: {0}'.format(labels_path))
         _logger.info('Batch size: {0}'.format(batch_size))
@@ -262,11 +268,12 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
             self._reader_creator(data_path, labels_path), batch_size=batch_size)
         fp32_acc, fp32_pps, fp32_lat = self._predict(
             val_reader,
-            qat_model_path,
+            fp32_model_path,
             batch_size,
             batch_num,
             skip_batch_num,
             transform_to_int8=False)
+        _logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
         _logger.info('--- QAT INT8 prediction start ---')
         val_reader = paddle.batch(
             self._reader_creator(data_path, labels_path), batch_size=batch_size)
@@ -277,6 +284,7 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
             batch_num,
             skip_batch_num,
             transform_to_int8=True)
+        _logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc))
         self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat)
         self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)
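The test ends by comparing the FP32 and INT8 results via `self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)`. That helper's body is not part of this diff, so the snippet below is only an assumed sketch of such a threshold check, useful for seeing what the comparison enforces.

```python
def compare_accuracy(fp32_acc, int8_acc, threshold):
    # Assumed sketch (not the test's actual _compare_accuracy): the INT8
    # accuracy may drop at most `threshold` below the FP32 baseline.
    drop = fp32_acc - int8_acc
    assert drop <= threshold, (
        'INT8 accuracy dropped by {0:.6f}, which exceeds the allowed '
        'threshold of {1}.'.format(drop, threshold))


# Example: a 0.003 accuracy drop passes with a 0.01 threshold.
compare_accuracy(fp32_acc=0.802, int8_acc=0.799, threshold=0.01)
```

With an external `--fp32_model`, the FP32 baseline accuracy now comes from that model; otherwise the QAT model itself is run in FP32, as before.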