Unverified commit 43f3d0cc, authored by Wojciech Uss, committed by GitHub

Add an option to choose inference targets in Quant tests (#25582)

test=develop
Parent b158a21b
@@ -68,6 +68,12 @@ def parse_args():
         type=str,
         default='',
         help='A comma separated list of operator ids to skip in quantization.')
+    parser.add_argument(
+        '--targets',
+        type=str,
+        default='quant,int8,fp32',
+        help='A comma separated list of inference types to run ("int8", "fp32", "quant"). Default: "quant,int8,fp32"'
+    )
     parser.add_argument(
         '--debug',
         action='store_true',
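A quick note on how the new flag behaves when parsed; the sketch below is illustrative only (it builds a throwaway parser with just this one flag, not the test's actual parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--targets',
    type=str,
    default='quant,int8,fp32',
    help='A comma separated list of inference types to run.')

# With no flag given, all three targets are requested.
print(parser.parse_args([]).targets)                           # 'quant,int8,fp32'
# Restrict the run to the INT8 and FP32 models only.
print(parser.parse_args(['--targets', 'int8,fp32']).targets)   # 'int8,fp32'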
@@ -310,6 +316,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
         assert int8_acc1 > 0.5
         assert quant_acc1 - int8_acc1 <= threshold
 
+    def _strings_from_csv(self, string):
+        return set(s.strip() for s in string.split(','))
+
+    def _ints_from_csv(self, string):
+        return set(map(int, string.split(',')))
+
     def test_graph_transformation(self):
         if not fluid.core.is_compiled_with_mkldnn():
             return
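The two helpers added above only normalize comma-separated command-line values into sets. A quick illustration of what they return (written as standalone functions here, since the full test class is not shown):

def _strings_from_csv(string):
    # Split on commas and strip whitespace around each item.
    return set(s.strip() for s in string.split(','))

def _ints_from_csv(string):
    # Split on commas and convert each item to an int.
    return set(map(int, string.split(',')))

print(_strings_from_csv('conv2d, pool2d,matmul'))  # {'conv2d', 'pool2d', 'matmul'} (order may vary)
print(_ints_from_csv('3,7,15'))                    # {3, 7, 15} (order may vary)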
@@ -326,14 +338,19 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
         self._debug = test_case_args.debug
 
         self._quantized_ops = set()
-        if len(test_case_args.ops_to_quantize) > 0:
-            self._quantized_ops = set(
-                op.strip() for op in test_case_args.ops_to_quantize.split(','))
+        if test_case_args.ops_to_quantize:
+            self._quantized_ops = self._strings_from_csv(
+                test_case_args.ops_to_quantize)
 
         self._op_ids_to_skip = set([-1])
-        if len(test_case_args.op_ids_to_skip) > 0:
-            self._op_ids_to_skip = set(
-                map(int, test_case_args.op_ids_to_skip.split(',')))
+        if test_case_args.op_ids_to_skip:
+            self._op_ids_to_skip = self._ints_from_csv(
+                test_case_args.op_ids_to_skip)
+
+        self._targets = self._strings_from_csv(test_case_args.targets)
+        assert self._targets.intersection(
+            {'quant', 'int8', 'fp32'}
+        ), 'The --targets option, if used, must contain at least one of the targets: "quant", "int8", "fp32".'
 
         _logger.info('Quant & INT8 prediction run.')
         _logger.info('Quant model: {}'.format(quant_model_path))
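The new assertion only requires that at least one recognized target name appears in --targets; an unrecognized name does not fail the check by itself, it simply never matches any of the later "in self._targets" tests. A small sketch of that validation logic (values are illustrative):

valid = {'quant', 'int8', 'fp32'}

targets = {'int8', 'fp32'}
assert targets.intersection(valid)   # passes: two recognized targets

targets = {'int8', 'bf16'}
assert targets.intersection(valid)   # still passes: 'bf16' is simply ignored later

targets = {'bf16'}
# assert targets.intersection(valid) would raise AssertionError here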
@@ -348,35 +365,38 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
         _logger.info('Op ids to skip quantization: {}.'.format(','.join(
             map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
             else 'none'))
+        _logger.info('Targets: {}.'.format(','.join(self._targets)))
 
-        _logger.info('--- Quant prediction start ---')
-        val_reader = paddle.batch(
-            self._reader_creator(data_path), batch_size=batch_size)
-        quant_output, quant_acc1, quant_acc5, quant_fps, quant_lat = self._predict(
-            val_reader,
-            quant_model_path,
-            batch_size,
-            batch_num,
-            skip_batch_num,
-            target='quant')
-        self._print_performance('Quant', quant_fps, quant_lat)
-        self._print_accuracy('Quant', quant_acc1, quant_acc5)
+        if 'quant' in self._targets:
+            _logger.info('--- Quant prediction start ---')
+            val_reader = paddle.batch(
+                self._reader_creator(data_path), batch_size=batch_size)
+            quant_output, quant_acc1, quant_acc5, quant_fps, quant_lat = self._predict(
+                val_reader,
+                quant_model_path,
+                batch_size,
+                batch_num,
+                skip_batch_num,
+                target='quant')
+            self._print_performance('Quant', quant_fps, quant_lat)
+            self._print_accuracy('Quant', quant_acc1, quant_acc5)
 
-        _logger.info('--- INT8 prediction start ---')
-        val_reader = paddle.batch(
-            self._reader_creator(data_path), batch_size=batch_size)
-        int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
-            val_reader,
-            quant_model_path,
-            batch_size,
-            batch_num,
-            skip_batch_num,
-            target='int8')
-        self._print_performance('INT8', int8_fps, int8_lat)
-        self._print_accuracy('INT8', int8_acc1, int8_acc5)
+        if 'int8' in self._targets:
+            _logger.info('--- INT8 prediction start ---')
+            val_reader = paddle.batch(
+                self._reader_creator(data_path), batch_size=batch_size)
+            int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
+                val_reader,
+                quant_model_path,
+                batch_size,
+                batch_num,
+                skip_batch_num,
+                target='int8')
+            self._print_performance('INT8', int8_fps, int8_lat)
+            self._print_accuracy('INT8', int8_acc1, int8_acc5)
 
         fp32_acc1 = fp32_acc5 = fp32_fps = fp32_lat = -1
-        if fp32_model_path:
+        if 'fp32' in self._targets and fp32_model_path:
             _logger.info('--- FP32 prediction start ---')
             val_reader = paddle.batch(
                 self._reader_creator(data_path), batch_size=batch_size)
@@ -390,10 +410,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
             self._print_performance('FP32', fp32_fps, fp32_lat)
             self._print_accuracy('FP32', fp32_acc1, fp32_acc5)
 
-        self._summarize_performance(int8_fps, int8_lat, fp32_fps, fp32_lat)
-        self._summarize_accuracy(quant_acc1, quant_acc5, int8_acc1, int8_acc5,
-                                 fp32_acc1, fp32_acc5)
-        self._compare_accuracy(acc_diff_threshold, quant_acc1, int8_acc1)
+        if {'int8', 'fp32'}.issubset(self._targets):
+            self._summarize_performance(int8_fps, int8_lat, fp32_fps, fp32_lat)
+        if {'int8', 'quant'}.issubset(self._targets):
+            self._summarize_accuracy(quant_acc1, quant_acc5, int8_acc1,
+                                     int8_acc5, fp32_acc1, fp32_acc5)
+            self._compare_accuracy(acc_diff_threshold, quant_acc1, int8_acc1)
 
 
 if __name__ == '__main__':
...
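The summary and comparison steps at the end are gated on set containment: the performance summary needs both the INT8 and FP32 runs to have happened, and the accuracy summary/comparison needs both the INT8 and Quant runs. A minimal sketch of that gating (the printed strings are illustrative, not the test's actual log output):

targets = {'quant', 'int8'}

if {'int8', 'fp32'}.issubset(targets):
    print('summarize INT8 vs FP32 performance')             # skipped: FP32 did not run
if {'int8', 'quant'}.issubset(targets):
    print('summarize and compare Quant vs INT8 accuracy')   # executed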
@@ -72,6 +72,12 @@ def parse_args():
         type=str,
         default='',
         help='A comma separated list of operator ids to skip in quantization.')
+    parser.add_argument(
+        '--targets',
+        type=str,
+        default='quant,int8,fp32',
+        help='A comma separated list of inference types to run ("int8", "fp32", "quant"). Default: "quant,int8,fp32"'
+    )
     parser.add_argument(
         '--debug',
         action='store_true',
@@ -256,6 +262,12 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         assert int8_acc > 0.5
         assert quant_acc - int8_acc <= threshold
 
+    def _strings_from_csv(self, string):
+        return set(s.strip() for s in string.split(','))
+
+    def _ints_from_csv(self, string):
+        return set(map(int, string.split(',')))
+
     def test_graph_transformation(self):
         if not fluid.core.is_compiled_with_mkldnn():
             return
@@ -274,13 +286,18 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
 
         self._quantized_ops = set()
         if test_case_args.ops_to_quantize:
-            self._quantized_ops = set(
-                op.strip() for op in test_case_args.ops_to_quantize.split(','))
+            self._quantized_ops = self._strings_from_csv(
+                test_case_args.ops_to_quantize)
 
         self._op_ids_to_skip = set([-1])
         if test_case_args.op_ids_to_skip:
-            self._op_ids_to_skip = set(
-                map(int, test_case_args.op_ids_to_skip.split(',')))
+            self._op_ids_to_skip = self._ints_from_csv(
+                test_case_args.op_ids_to_skip)
+
+        self._targets = self._strings_from_csv(test_case_args.targets)
+        assert self._targets.intersection(
+            {'quant', 'int8', 'fp32'}
+        ), 'The --targets option, if used, must contain at least one of the targets: "quant", "int8", "fp32".'
 
         _logger.info('Quant & INT8 prediction run.')
         _logger.info('Quant model: {}'.format(quant_model_path))
@@ -296,35 +313,40 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         _logger.info('Op ids to skip quantization: {}.'.format(','.join(
             map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
             else 'none'))
+        _logger.info('Targets: {}.'.format(','.join(self._targets)))
 
-        _logger.info('--- Quant prediction start ---')
-        val_reader = paddle.batch(
-            self._reader_creator(data_path, labels_path), batch_size=batch_size)
-        quant_acc, quant_pps, quant_lat = self._predict(
-            val_reader,
-            quant_model_path,
-            batch_size,
-            batch_num,
-            skip_batch_num,
-            target='quant')
-        self._print_performance('Quant', quant_pps, quant_lat)
-        self._print_accuracy('Quant', quant_acc)
+        if 'quant' in self._targets:
+            _logger.info('--- Quant prediction start ---')
+            val_reader = paddle.batch(
+                self._reader_creator(data_path, labels_path),
+                batch_size=batch_size)
+            quant_acc, quant_pps, quant_lat = self._predict(
+                val_reader,
+                quant_model_path,
+                batch_size,
+                batch_num,
+                skip_batch_num,
+                target='quant')
+            self._print_performance('Quant', quant_pps, quant_lat)
+            self._print_accuracy('Quant', quant_acc)
 
-        _logger.info('--- INT8 prediction start ---')
-        val_reader = paddle.batch(
-            self._reader_creator(data_path, labels_path), batch_size=batch_size)
-        int8_acc, int8_pps, int8_lat = self._predict(
-            val_reader,
-            quant_model_path,
-            batch_size,
-            batch_num,
-            skip_batch_num,
-            target='int8')
-        self._print_performance('INT8', int8_pps, int8_lat)
-        self._print_accuracy('INT8', int8_acc)
+        if 'int8' in self._targets:
+            _logger.info('--- INT8 prediction start ---')
+            val_reader = paddle.batch(
+                self._reader_creator(data_path, labels_path),
+                batch_size=batch_size)
+            int8_acc, int8_pps, int8_lat = self._predict(
+                val_reader,
+                quant_model_path,
+                batch_size,
+                batch_num,
+                skip_batch_num,
+                target='int8')
+            self._print_performance('INT8', int8_pps, int8_lat)
+            self._print_accuracy('INT8', int8_acc)
 
         fp32_acc = fp32_pps = fp32_lat = -1
-        if fp32_model_path:
+        if 'fp32' in self._targets and fp32_model_path:
             _logger.info('--- FP32 prediction start ---')
             val_reader = paddle.batch(
                 self._reader_creator(data_path, labels_path),
@@ -339,9 +361,11 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
             self._print_performance('FP32', fp32_pps, fp32_lat)
             self._print_accuracy('FP32', fp32_acc)
 
-        self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat)
-        self._summarize_accuracy(quant_acc, int8_acc, fp32_acc)
-        self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc)
+        if {'int8', 'fp32'}.issubset(self._targets):
+            self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat)
+        if {'int8', 'quant'}.issubset(self._targets):
+            self._summarize_accuracy(quant_acc, int8_acc, fp32_acc)
+            self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc)
 
 
 if __name__ == '__main__':
...
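The NLP comparison test applies the same pattern as the image classification test; end to end, the --targets value flows from the parser through _strings_from_csv into the per-target gates. A compact sketch of that flow, under the same assumptions as the examples above (the printed strings are illustrative):

def _strings_from_csv(string):
    return set(s.strip() for s in string.split(','))

raw = 'quant,int8'                 # e.g. the value passed to --targets
targets = _strings_from_csv(raw)   # {'quant', 'int8'}

for name in ('quant', 'int8', 'fp32'):
    if name in targets:
        print('run {} prediction'.format(name))
# Only the Quant-vs-INT8 accuracy comparison would run afterwards; the
# INT8-vs-FP32 performance summary is skipped because 'fp32' was not requested.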