未验证 提交 43f3d0cc 编写于 作者: W Wojciech Uss 提交者: GitHub

Add an option to choose inference targets in Quant tests (#25582)

test=develop
上级 b158a21b
......@@ -68,6 +68,12 @@ def parse_args():
type=str,
default='',
help='A comma separated list of operator ids to skip in quantization.')
parser.add_argument(
'--targets',
type=str,
default='quant,int8,fp32',
help='A comma separated list of inference types to run ("int8", "fp32", "quant"). Default: "quant,int8,fp32"'
)
parser.add_argument(
'--debug',
action='store_true',
......@@ -310,6 +316,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
assert int8_acc1 > 0.5
assert quant_acc1 - int8_acc1 <= threshold
def _strings_from_csv(self, string):
return set(s.strip() for s in string.split(','))
def _ints_from_csv(self, string):
return set(map(int, string.split(',')))
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
return
......@@ -326,14 +338,19 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
self._debug = test_case_args.debug
self._quantized_ops = set()
if len(test_case_args.ops_to_quantize) > 0:
self._quantized_ops = set(
op.strip() for op in test_case_args.ops_to_quantize.split(','))
if test_case_args.ops_to_quantize:
self._quantized_ops = self._strings_from_csv(
test_case_args.ops_to_quantize)
self._op_ids_to_skip = set([-1])
if len(test_case_args.op_ids_to_skip) > 0:
self._op_ids_to_skip = set(
map(int, test_case_args.op_ids_to_skip.split(',')))
if test_case_args.op_ids_to_skip:
self._op_ids_to_skip = self._ints_from_csv(
test_case_args.op_ids_to_skip)
self._targets = self._strings_from_csv(test_case_args.targets)
assert self._targets.intersection(
{'quant', 'int8', 'fp32'}
), 'The --targets option, if used, must contain at least one of the targets: "quant", "int8", "fp32".'
_logger.info('Quant & INT8 prediction run.')
_logger.info('Quant model: {}'.format(quant_model_path))
......@@ -348,35 +365,38 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
_logger.info('Op ids to skip quantization: {}.'.format(','.join(
map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
else 'none'))
_logger.info('Targets: {}.'.format(','.join(self._targets)))
_logger.info('--- Quant prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
quant_output, quant_acc1, quant_acc5, quant_fps, quant_lat = self._predict(
val_reader,
quant_model_path,
batch_size,
batch_num,
skip_batch_num,
target='quant')
self._print_performance('Quant', quant_fps, quant_lat)
self._print_accuracy('Quant', quant_acc1, quant_acc5)
if 'quant' in self._targets:
_logger.info('--- Quant prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
quant_output, quant_acc1, quant_acc5, quant_fps, quant_lat = self._predict(
val_reader,
quant_model_path,
batch_size,
batch_num,
skip_batch_num,
target='quant')
self._print_performance('Quant', quant_fps, quant_lat)
self._print_accuracy('Quant', quant_acc1, quant_acc5)
_logger.info('--- INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
val_reader,
quant_model_path,
batch_size,
batch_num,
skip_batch_num,
target='int8')
self._print_performance('INT8', int8_fps, int8_lat)
self._print_accuracy('INT8', int8_acc1, int8_acc5)
if 'int8' in self._targets:
_logger.info('--- INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
val_reader,
quant_model_path,
batch_size,
batch_num,
skip_batch_num,
target='int8')
self._print_performance('INT8', int8_fps, int8_lat)
self._print_accuracy('INT8', int8_acc1, int8_acc5)
fp32_acc1 = fp32_acc5 = fp32_fps = fp32_lat = -1
if fp32_model_path:
if 'fp32' in self._targets and fp32_model_path:
_logger.info('--- FP32 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
......@@ -390,10 +410,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
self._print_performance('FP32', fp32_fps, fp32_lat)
self._print_accuracy('FP32', fp32_acc1, fp32_acc5)
self._summarize_performance(int8_fps, int8_lat, fp32_fps, fp32_lat)
self._summarize_accuracy(quant_acc1, quant_acc5, int8_acc1, int8_acc5,
fp32_acc1, fp32_acc5)
self._compare_accuracy(acc_diff_threshold, quant_acc1, int8_acc1)
if {'int8', 'fp32'}.issubset(self._targets):
self._summarize_performance(int8_fps, int8_lat, fp32_fps, fp32_lat)
if {'int8', 'quant'}.issubset(self._targets):
self._summarize_accuracy(quant_acc1, quant_acc5, int8_acc1,
int8_acc5, fp32_acc1, fp32_acc5)
self._compare_accuracy(acc_diff_threshold, quant_acc1, int8_acc1)
if __name__ == '__main__':
......
......@@ -72,6 +72,12 @@ def parse_args():
type=str,
default='',
help='A comma separated list of operator ids to skip in quantization.')
parser.add_argument(
'--targets',
type=str,
default='quant,int8,fp32',
help='A comma separated list of inference types to run ("int8", "fp32", "quant"). Default: "quant,int8,fp32"'
)
parser.add_argument(
'--debug',
action='store_true',
......@@ -256,6 +262,12 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
assert int8_acc > 0.5
assert quant_acc - int8_acc <= threshold
def _strings_from_csv(self, string):
return set(s.strip() for s in string.split(','))
def _ints_from_csv(self, string):
return set(map(int, string.split(',')))
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
return
......@@ -274,13 +286,18 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
self._quantized_ops = set()
if test_case_args.ops_to_quantize:
self._quantized_ops = set(
op.strip() for op in test_case_args.ops_to_quantize.split(','))
self._quantized_ops = self._strings_from_csv(
test_case_args.ops_to_quantize)
self._op_ids_to_skip = set([-1])
if test_case_args.op_ids_to_skip:
self._op_ids_to_skip = set(
map(int, test_case_args.op_ids_to_skip.split(',')))
self._op_ids_to_skip = self._ints_from_csv(
test_case_args.op_ids_to_skip)
self._targets = self._strings_from_csv(test_case_args.targets)
assert self._targets.intersection(
{'quant', 'int8', 'fp32'}
), 'The --targets option, if used, must contain at least one of the targets: "quant", "int8", "fp32".'
_logger.info('Quant & INT8 prediction run.')
_logger.info('Quant model: {}'.format(quant_model_path))
......@@ -296,35 +313,40 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
_logger.info('Op ids to skip quantization: {}.'.format(','.join(
map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
else 'none'))
_logger.info('Targets: {}.'.format(','.join(self._targets)))
_logger.info('--- Quant prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size)
quant_acc, quant_pps, quant_lat = self._predict(
val_reader,
quant_model_path,
batch_size,
batch_num,
skip_batch_num,
target='quant')
self._print_performance('Quant', quant_pps, quant_lat)
self._print_accuracy('Quant', quant_acc)
if 'quant' in self._targets:
_logger.info('--- Quant prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path),
batch_size=batch_size)
quant_acc, quant_pps, quant_lat = self._predict(
val_reader,
quant_model_path,
batch_size,
batch_num,
skip_batch_num,
target='quant')
self._print_performance('Quant', quant_pps, quant_lat)
self._print_accuracy('Quant', quant_acc)
_logger.info('--- INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size)
int8_acc, int8_pps, int8_lat = self._predict(
val_reader,
quant_model_path,
batch_size,
batch_num,
skip_batch_num,
target='int8')
self._print_performance('INT8', int8_pps, int8_lat)
self._print_accuracy('INT8', int8_acc)
if 'int8' in self._targets:
_logger.info('--- INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path),
batch_size=batch_size)
int8_acc, int8_pps, int8_lat = self._predict(
val_reader,
quant_model_path,
batch_size,
batch_num,
skip_batch_num,
target='int8')
self._print_performance('INT8', int8_pps, int8_lat)
self._print_accuracy('INT8', int8_acc)
fp32_acc = fp32_pps = fp32_lat = -1
if fp32_model_path:
if 'fp32' in self._targets and fp32_model_path:
_logger.info('--- FP32 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path),
......@@ -339,9 +361,11 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
self._print_performance('FP32', fp32_pps, fp32_lat)
self._print_accuracy('FP32', fp32_acc)
self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat)
self._summarize_accuracy(quant_acc, int8_acc, fp32_acc)
self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc)
if {'int8', 'fp32'}.issubset(self._targets):
self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat)
if {'int8', 'quant'}.issubset(self._targets):
self._summarize_accuracy(quant_acc, int8_acc, fp32_acc)
self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册