diff --git a/python/paddle/fluid/contrib/slim/tests/qat_int8_nlp_comparison.py b/python/paddle/fluid/contrib/slim/tests/qat_int8_nlp_comparison.py index b6eb825b09d2c0aad0d45049feafc09106d1e8c7..1a00f31df5b4c73c903c75b188b39952907c101e 100644 --- a/python/paddle/fluid/contrib/slim/tests/qat_int8_nlp_comparison.py +++ b/python/paddle/fluid/contrib/slim/tests/qat_int8_nlp_comparison.py @@ -48,9 +48,11 @@ def parse_args(): parser.add_argument( '--qat_model', type=str, default='', help='A path to a QAT model.') parser.add_argument( - '--save_model', - action='store_true', - help='If used, the QAT model will be saved after all transformations') + '--fp32_model', + type=str, + default='', + help='A path to an FP32 model. If empty, the QAT model will be used for FP32 inference.' + ) parser.add_argument('--infer_data', type=str, default='', help='Data file.') parser.add_argument( '--labels', type=str, default='', help='File with labels.') @@ -239,7 +241,10 @@ class QatInt8NLPComparisonTest(unittest.TestCase): return qat_model_path = test_case_args.qat_model + assert qat_model_path, 'The QAT model path cannot be empty. Please, use the --qat_model option.' + fp32_model_path = test_case_args.fp32_model if test_case_args.fp32_model else qat_model_path data_path = test_case_args.infer_data + assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.' labels_path = test_case_args.labels batch_size = test_case_args.batch_size batch_num = test_case_args.batch_num @@ -250,6 +255,7 @@ class QatInt8NLPComparisonTest(unittest.TestCase): _logger.info('QAT FP32 & INT8 prediction run.') _logger.info('QAT model: {0}'.format(qat_model_path)) + _logger.info('FP32 model: {0}'.format(fp32_model_path)) _logger.info('Dataset: {0}'.format(data_path)) _logger.info('Labels: {0}'.format(labels_path)) _logger.info('Batch size: {0}'.format(batch_size)) @@ -262,11 +268,12 @@ class QatInt8NLPComparisonTest(unittest.TestCase): self._reader_creator(data_path, labels_path), batch_size=batch_size) fp32_acc, fp32_pps, fp32_lat = self._predict( val_reader, - qat_model_path, + fp32_model_path, batch_size, batch_num, skip_batch_num, transform_to_int8=False) + _logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc)) _logger.info('--- QAT INT8 prediction start ---') val_reader = paddle.batch( self._reader_creator(data_path, labels_path), batch_size=batch_size) @@ -277,6 +284,7 @@ class QatInt8NLPComparisonTest(unittest.TestCase): batch_num, skip_batch_num, transform_to_int8=True) + _logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc)) self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat) self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)