未验证 提交 645e81f0 编写于 作者: F Frank Lin 提交者: GitHub

Improve stablity of Paddle-TensorRT FP16 UT GitHub (1) (#51554)

* Improve Readability and Overall Clarity of Logging

* Adds the set_input_type API for specifying input data types

* Specifying input data types
上级 4b85e5db
...@@ -36,7 +36,8 @@ import paddle ...@@ -36,7 +36,8 @@ import paddle
import paddle.inference as paddle_infer import paddle.inference as paddle_infer
from paddle.fluid.core import PassVersionChecker from paddle.fluid.core import PassVersionChecker
logging.basicConfig(level=logging.INFO, format="%(message)s") LOGLEVEL = os.environ.get("PADDLE_TEST_LOGLEVEL", "INFO").upper()
logging.basicConfig(level=LOGLEVEL, format="%(message)s")
settings.register_profile( settings.register_profile(
"ci", "ci",
...@@ -57,8 +58,8 @@ settings.register_profile( ...@@ -57,8 +58,8 @@ settings.register_profile(
report_multiple_bugs=False, report_multiple_bugs=False,
) )
if ( if (
float(os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')) < 1 float(os.getenv("TEST_NUM_PERCENT_CASES", default="1.0")) < 1
or os.getenv('HYPOTHESIS_TEST_PROFILE', 'dev') == 'ci' or os.getenv("HYPOTHESIS_TEST_PROFILE", "dev") == "ci"
): ):
settings.load_profile("ci") settings.load_profile("ci")
else: else:
...@@ -100,10 +101,10 @@ class AutoScanTest(unittest.TestCase): ...@@ -100,10 +101,10 @@ class AutoScanTest(unittest.TestCase):
@abc.abstractmethod @abc.abstractmethod
def sample_program_configs(self): def sample_program_configs(self):
''' """
Generate all config with the combination of different Input tensor shape and Generate all config with the combination of different Input tensor shape and
different Attr values. different Attr values.
''' """
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
...@@ -125,9 +126,9 @@ class AutoScanTest(unittest.TestCase): ...@@ -125,9 +126,9 @@ class AutoScanTest(unittest.TestCase):
def run_test_config( def run_test_config(
self, model, params, prog_config, pred_config, feed_data self, model, params, prog_config, pred_config, feed_data
) -> Dict[str, np.ndarray]: ) -> Dict[str, np.ndarray]:
''' """
Test a single case. Test a single case.
''' """
pred_config.set_model_buffer(model, len(model), params, len(params)) pred_config.set_model_buffer(model, len(model), params, len(params))
predictor = paddle_infer.create_predictor(pred_config) predictor = paddle_infer.create_predictor(pred_config)
self.available_passes_in_framework = ( self.available_passes_in_framework = (
...@@ -136,9 +137,9 @@ class AutoScanTest(unittest.TestCase): ...@@ -136,9 +137,9 @@ class AutoScanTest(unittest.TestCase):
) )
for name, _ in prog_config.inputs.items(): for name, _ in prog_config.inputs.items():
input_tensor = predictor.get_input_handle(name) input_tensor = predictor.get_input_handle(name)
input_tensor.copy_from_cpu(feed_data[name]['data']) input_tensor.copy_from_cpu(feed_data[name]["data"])
if feed_data[name]['lod'] is not None: if feed_data[name]["lod"] is not None:
input_tensor.set_lod(feed_data[name]['lod']) input_tensor.set_lod(feed_data[name]["lod"])
predictor.run() predictor.run()
result = {} result = {}
for out_name, o_name in zip( for out_name, o_name in zip(
...@@ -158,10 +159,7 @@ class AutoScanTest(unittest.TestCase): ...@@ -158,10 +159,7 @@ class AutoScanTest(unittest.TestCase):
for key, arr in tensor.items(): for key, arr in tensor.items():
self.assertTrue( self.assertTrue(
baseline[key].shape == arr.shape, baseline[key].shape == arr.shape,
"The output shapes are not equal, the baseline shape is " f"The output shapes are not equal, the baseline shape is {baseline[key].shape}, but got {str(arr.shape)}",
+ str(baseline[key].shape)
+ ', but got '
+ str(arr.shape),
) )
diff = abs(baseline[key] - arr) diff = abs(baseline[key] - arr)
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -169,9 +167,7 @@ class AutoScanTest(unittest.TestCase): ...@@ -169,9 +167,7 @@ class AutoScanTest(unittest.TestCase):
arr, arr,
rtol=rtol, rtol=rtol,
atol=atol, atol=atol,
err_msg='Output has diff, Maximum absolute error: {}'.format( err_msg=f"Output has diff, Maximum absolute error: {np.amax(diff)}",
np.amax(diff)
),
) )
@abc.abstractmethod @abc.abstractmethod
...@@ -207,15 +203,19 @@ class AutoScanTest(unittest.TestCase): ...@@ -207,15 +203,19 @@ class AutoScanTest(unittest.TestCase):
@abc.abstractmethod @abc.abstractmethod
def ignore_log(self, msg: str): def ignore_log(self, msg: str):
logging.warning("SKIP: " + msg) logging.debug(f"SKIP: {msg}")
@abc.abstractmethod @abc.abstractmethod
def fail_log(self, msg: str): def fail_log(self, msg: str):
logging.error("FAIL: " + msg) logging.error(f"FAIL: {msg}")
@abc.abstractmethod
def info_log(self, msg: str):
logging.debug(f"INFO: {msg}")
@abc.abstractmethod @abc.abstractmethod
def success_log(self, msg: str): def success_log(self, msg: str):
logging.info("SUCCESS: " + msg) logging.debug(f"SUCCESS: {msg}")
@abc.abstractmethod @abc.abstractmethod
def create_inference_config( def create_inference_config(
...@@ -263,20 +263,22 @@ class MkldnnAutoScanTest(AutoScanTest): ...@@ -263,20 +263,22 @@ class MkldnnAutoScanTest(AutoScanTest):
feed_data = {} feed_data = {}
for name, tensor_config in prog_config.inputs.items(): for name, tensor_config in prog_config.inputs.items():
feed_data[name] = { feed_data[name] = {
'data': tensor_config.data, "data": tensor_config.data,
'lod': tensor_config.lod, "lod": tensor_config.lod,
} }
results: List[Dict[str, np.ndarray]] = [] results: List[Dict[str, np.ndarray]] = []
# baseline: cpu no ir_optim run # baseline: cpu no ir_optim run
base_config = self.create_inference_config(ir_optim=False) base_config = self.create_inference_config(ir_optim=False)
logging.info('RUN program_config: ' + str(prog_config))
results.append( results.append(
self.run_test_config( self.run_test_config(
model, params, prog_config, base_config, feed_data model, params, prog_config, base_config, feed_data
) )
) )
self.success_log('RUN_CPU_BASELINE done') self.success_log(f"basline program_config: {prog_config}")
self.success_log(
f"basline predictor_config: {self.inference_config_str(base_config)}"
)
for pred_config, (atol, rtol) in self.sample_predictor_configs( for pred_config, (atol, rtol) in self.sample_predictor_configs(
prog_config prog_config
...@@ -291,11 +293,7 @@ class MkldnnAutoScanTest(AutoScanTest): ...@@ -291,11 +293,7 @@ class MkldnnAutoScanTest(AutoScanTest):
== IgnoreReasons.MKLDNN_ACCURACY_ERROR == IgnoreReasons.MKLDNN_ACCURACY_ERROR
): ):
self.ignore_log( self.ignore_log(
"[MKLDNN_ACCURACY_ERROR] " f"[MKLDNN_ACCURACY_ERROR] {ignore_info[2]} vs {self.inference_config_str(pred_config)}"
+ ignore_info[2]
+ ' '
+ ' vs '
+ self.inference_config_str(pred_config)
) )
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -315,28 +313,29 @@ class MkldnnAutoScanTest(AutoScanTest): ...@@ -315,28 +313,29 @@ class MkldnnAutoScanTest(AutoScanTest):
self.assert_tensors_near( self.assert_tensors_near(
atol, rtol, results[-1], results[0] atol, rtol, results[-1], results[0]
) )
self.success_log(f"program_config: {prog_config}")
self.success_log(
f"predictor_config: {self.inference_config_str(pred_config)}"
)
except Exception as e: except Exception as e:
self.fail_log(f"program_config: {prog_config}")
self.fail_log( self.fail_log(
self.inference_config_str(pred_config) f"predictor_config: {self.inference_config_str(pred_config)}"
+ f'\033[1;31m \nERROR INFO: {str(e)}\033[0m'
) )
self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
if not ignore_flag: if not ignore_flag:
status = False status = False
continue continue
self.success_log(
'RUN predictor_config '
+ self.inference_config_str(pred_config)
+ ' done'
)
self.assertTrue(status) self.assertTrue(status)
def inference_config_str(self, config) -> str: def inference_config_str(self, config) -> str:
dic = {} dic = {}
enable_mkldnn = config.mkldnn_enabled() enable_mkldnn = config.mkldnn_enabled()
dic['use_mkldnn'] = enable_mkldnn dic["use_mkldnn"] = enable_mkldnn
enable_gpu = config.use_gpu() enable_gpu = config.use_gpu()
dic['use_gpu'] = enable_gpu dic["use_gpu"] = enable_gpu
return str(dic) return str(dic)
...@@ -351,7 +350,7 @@ class PassAutoScanTest(AutoScanTest): ...@@ -351,7 +350,7 @@ class PassAutoScanTest(AutoScanTest):
if pass_name not in self.available_passes_in_framework: if pass_name not in self.available_passes_in_framework:
continue continue
if not PassVersionChecker.IsCompatible(pass_name): if not PassVersionChecker.IsCompatible(pass_name):
self.fail_log(f'{pass_name} version check failed.') self.fail_log(f"{pass_name} version check failed.")
status = False status = False
return status return status
...@@ -368,9 +367,7 @@ class PassAutoScanTest(AutoScanTest): ...@@ -368,9 +367,7 @@ class PassAutoScanTest(AutoScanTest):
) )
if not os.path.exists(last_passed_program): if not os.path.exists(last_passed_program):
raise ValueError( raise ValueError(
"Cannot find file {}, please make sure that your pass name is correct".format( f"Cannot find file {last_passed_program}, please make sure that your pass name is correct"
last_passed_program
)
) )
model_bytes = paddle.static.load_from_file(last_passed_program) model_bytes = paddle.static.load_from_file(last_passed_program)
pg = paddle.static.deserialize_program(model_bytes) pg = paddle.static.deserialize_program(model_bytes)
...@@ -382,9 +379,7 @@ class PassAutoScanTest(AutoScanTest): ...@@ -382,9 +379,7 @@ class PassAutoScanTest(AutoScanTest):
after_op_list.append(main_block.op(i).type()) after_op_list.append(main_block.op(i).type())
self.assertTrue( self.assertTrue(
op_list_after_fusion == after_op_list, op_list_after_fusion == after_op_list,
"Expected operator list after fusion is {}, but now it's {}".format( f"Expected operator list after fusion is {op_list_after_fusion}, but now it's {after_op_list}",
op_list_after_fusion, after_op_list
),
) )
def run_and_statis( def run_and_statis(
...@@ -396,10 +391,10 @@ class PassAutoScanTest(AutoScanTest): ...@@ -396,10 +391,10 @@ class PassAutoScanTest(AutoScanTest):
max_duration=180, max_duration=180,
passes=None, passes=None,
): ):
if os.getenv('HYPOTHESIS_TEST_PROFILE', 'ci') == "dev": if os.getenv("HYPOTHESIS_TEST_PROFILE", "ci") == "dev":
max_examples *= 10 max_examples *= 10
min_success_num *= 10 min_success_num *= 10
# while at ce phase, there's no limit on time # while at ce phase, there"s no limit on time
max_duration = -1 max_duration = -1
start_time = time.time() start_time = time.time()
settings.register_profile( settings.register_profile(
...@@ -431,13 +426,11 @@ class PassAutoScanTest(AutoScanTest): ...@@ -431,13 +426,11 @@ class PassAutoScanTest(AutoScanTest):
loop_func = reproduce(loop_func) loop_func = reproduce(loop_func)
logging.info(f"Start to running test of {type(self)}") logging.info(f"Start to running test of {type(self)}")
loop_func() loop_func()
logging.info( self.info_log(
"===================Statistical Information===================" "===================Statistical Information==================="
) )
logging.info( self.info_log(
"Number of Generated Programs: {}".format( f"Number of Generated Programs: {self.num_ran_programs + self.num_invalid_programs}"
self.num_ran_programs + self.num_invalid_programs
)
) )
logging.info(f"Number of Invalid Programs: {self.num_invalid_programs}") logging.info(f"Number of Invalid Programs: {self.num_invalid_programs}")
logging.info(f"Number of Ran Programs: {self.num_ran_programs}") logging.info(f"Number of Ran Programs: {self.num_ran_programs}")
...@@ -446,27 +439,21 @@ class PassAutoScanTest(AutoScanTest): ...@@ -446,27 +439,21 @@ class PassAutoScanTest(AutoScanTest):
self.num_ran_programs self.num_ran_programs
- self.num_ignore_tests / max(self.num_predictor_kinds, 1) - self.num_ignore_tests / max(self.num_predictor_kinds, 1)
) )
logging.info( self.info_log(
"Number of successfully ran programs approximately equal to {}".format( f"Number of successfully ran programs approximately equal to {successful_ran_programs}"
successful_ran_programs
)
) )
if successful_ran_programs < min_success_num: if successful_ran_programs < min_success_num:
logging.warning( self.fail_log(
"satisfied_programs = ran_programs - num_ignore_tests / num_predictor_kinds" "satisfied_programs = ran_programs - num_ignore_tests / num_predictor_kinds"
) )
logging.error( self.fail_log(
"At least {} programs need to ran successfully, but now only about {} programs satisfied.".format( f"At least {min_success_num} programs need to ran successfully, but now only about {successful_ran_programs} programs satisfied."
min_success_num, successful_ran_programs
)
) )
raise AssertionError() raise AssertionError()
used_time = time.time() - start_time used_time = time.time() - start_time
if max_duration > 0 and used_time > max_duration: if max_duration > 0 and used_time > max_duration:
logging.error( self.fail_log(
"The duration exceeds {} seconds, if this is necessary, try to set a larger number for parameter `max_duration`.".format( f"The duration exceeds {max_duration} seconds, if this is necessary, try to set a larger number for parameter `max_duration`."
max_duration
)
) )
raise AssertionError() raise AssertionError()
...@@ -486,11 +473,10 @@ class PassAutoScanTest(AutoScanTest): ...@@ -486,11 +473,10 @@ class PassAutoScanTest(AutoScanTest):
feed_data = {} feed_data = {}
for name, tensor_config in prog_config.inputs.items(): for name, tensor_config in prog_config.inputs.items():
feed_data[name] = { feed_data[name] = {
'data': tensor_config.data, "data": tensor_config.data,
'lod': tensor_config.lod, "lod": tensor_config.lod,
} }
logging.info('RUN program_config: ' + str(prog_config))
self.num_predictor_kinds = 0 self.num_predictor_kinds = 0
for ( for (
pred_config, pred_config,
...@@ -507,11 +493,7 @@ class PassAutoScanTest(AutoScanTest): ...@@ -507,11 +493,7 @@ class PassAutoScanTest(AutoScanTest):
self.num_ignore_tests += 1 self.num_ignore_tests += 1
if ignore_info[1] == IgnoreReasons.PASS_ACCURACY_ERROR: if ignore_info[1] == IgnoreReasons.PASS_ACCURACY_ERROR:
self.ignore_log( self.ignore_log(
"[PASS_ACCURACY_ERROR] " f"[PASS_ACCURACY_ERROR] {ignore_info[2]} vs {self.inference_config_str(pred_config)}"
+ ignore_info[2]
+ ' '
+ ' vs '
+ self.inference_config_str(pred_config)
) )
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -532,9 +514,7 @@ class PassAutoScanTest(AutoScanTest): ...@@ -532,9 +514,7 @@ class PassAutoScanTest(AutoScanTest):
model, params, prog_config, base_config, feed_data model, params, prog_config, base_config, feed_data
) )
self.success_log( self.success_log(
'RUN_BASELINE ' f"baseline program_config: {self.inference_config_str(base_config)}"
+ self.inference_config_str(base_config)
+ ' done'
) )
if os.path.exists(self.cache_dir): if os.path.exists(self.cache_dir):
...@@ -549,19 +529,19 @@ class PassAutoScanTest(AutoScanTest): ...@@ -549,19 +529,19 @@ class PassAutoScanTest(AutoScanTest):
if not ignore_flag: if not ignore_flag:
self.assert_op_list(op_list) self.assert_op_list(op_list)
self.success_log(f"program_config: {prog_config}")
self.success_log(
f"predictor_config: {self.inference_config_str(pred_config)}"
)
except Exception as e: except Exception as e:
self.fail_log(f"program_config: {prog_config}")
self.fail_log( self.fail_log(
self.inference_config_str(pred_config) f"predictor_config: {self.inference_config_str(pred_config)}"
+ f'\033[1;31m \nERROR INFO: {str(e)}\033[0m'
) )
self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
if not ignore_flag: if not ignore_flag:
status = False status = False
continue continue
self.success_log(
'RUN predictor_config '
+ self.inference_config_str(pred_config)
+ ' done'
)
status = self.check_op_version() and status status = self.check_op_version() and status
self.assertTrue(status) self.assertTrue(status)
...@@ -569,23 +549,23 @@ class PassAutoScanTest(AutoScanTest): ...@@ -569,23 +549,23 @@ class PassAutoScanTest(AutoScanTest):
def inference_config_str(self, config) -> str: def inference_config_str(self, config) -> str:
dic = {} dic = {}
enable_mkldnn = config.mkldnn_enabled() enable_mkldnn = config.mkldnn_enabled()
dic['use_mkldnn'] = enable_mkldnn dic["use_mkldnn"] = enable_mkldnn
enable_gpu = config.use_gpu() enable_gpu = config.use_gpu()
dic['use_gpu'] = enable_gpu dic['use_gpu'] = enable_gpu
enable_xpu = config.use_xpu() enable_xpu = config.use_xpu()
dic['use_xpu'] = enable_xpu dic['use_xpu'] = enable_xpu
if not self.passes: if not self.passes:
dic['passes'] = self.passes dic["passes"] = self.passes
enable_trt = config.tensorrt_engine_enabled() enable_trt = config.tensorrt_engine_enabled()
trt_precison = config.tensorrt_precision_mode() trt_precison = config.tensorrt_precision_mode()
trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled() trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
if enable_trt: if enable_trt:
dic['use_trt'] = True dic["use_trt"] = True
dic['trt_precision'] = trt_precison dic["trt_precision"] = trt_precison
dic['use_dynamic_shape'] = trt_dynamic_shape dic["use_dynamic_shape"] = trt_dynamic_shape
else: else:
dic['use_trt'] = False dic["use_trt"] = False
return str(dic) return str(dic)
def create_trt_inference_config(self) -> paddle_infer.Config: def create_trt_inference_config(self) -> paddle_infer.Config:
...@@ -599,9 +579,9 @@ class PassAutoScanTest(AutoScanTest): ...@@ -599,9 +579,9 @@ class PassAutoScanTest(AutoScanTest):
class TrtLayerAutoScanTest(AutoScanTest): class TrtLayerAutoScanTest(AutoScanTest):
class TensorRTParam: class TensorRTParam:
''' """
TensorRT subgraph engine parameters. TensorRT subgraph engine parameters.
''' """
def __init__( def __init__(
self, self,
...@@ -620,9 +600,9 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -620,9 +600,9 @@ class TrtLayerAutoScanTest(AutoScanTest):
self.use_calib_mode = use_calib_mode self.use_calib_mode = use_calib_mode
class DynamicShapeParam: class DynamicShapeParam:
''' """
Prepare TensorRT subgraph engine dynamic shape parameters. Prepare TensorRT subgraph engine dynamic shape parameters.
''' """
def __init__( def __init__(
self, self,
...@@ -648,7 +628,7 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -648,7 +628,7 @@ class TrtLayerAutoScanTest(AutoScanTest):
) )
self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False) self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)
self.num_percent_cases = float( self.num_percent_cases = float(
os.getenv('TEST_NUM_PERCENT_CASES', default='1.0') os.getenv("TEST_NUM_PERCENT_CASES", default="1.0")
) )
# Use a separate random generator for skipping tests # Use a separate random generator for skipping tests
...@@ -682,6 +662,9 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -682,6 +662,9 @@ class TrtLayerAutoScanTest(AutoScanTest):
) )
return config return config
def get_avalible_input_type(self) -> List[np.dtype]:
return [np.float32]
def assert_tensors_near( def assert_tensors_near(
self, self,
atol: float, atol: float,
...@@ -693,39 +676,32 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -693,39 +676,32 @@ class TrtLayerAutoScanTest(AutoScanTest):
self.assertEqual( self.assertEqual(
baseline[key].shape, baseline[key].shape,
arr.shape, arr.shape,
'The output shapes are not equal, the baseline shape is ' f"The output shapes are not equal, the baseline shape is {baseline[key].shape}, but got {str(arr.shape)}",
+ str(baseline[key].shape)
+ ', but got '
+ str(arr.shape),
) )
np.testing.assert_allclose(baseline[key], arr, rtol=rtol, atol=atol) np.testing.assert_allclose(arr, baseline[key], rtol=rtol, atol=atol)
def assert_op_size(self, trt_engine_num, paddle_op_num): def assert_op_size(self, trt_engine_num, paddle_op_num):
last_passed_program = os.path.join( last_passed_program = os.path.join(
self.cache_dir, 'transpose_flatten_concat_fuse_pass.pdmodel' self.cache_dir, "transpose_flatten_concat_fuse_pass.pdmodel"
) )
model_bytes = paddle.static.load_from_file(last_passed_program) model_bytes = paddle.static.load_from_file(last_passed_program)
pg = paddle.static.deserialize_program(model_bytes) pg = paddle.static.deserialize_program(model_bytes)
main_block = pg.desc.block(0) main_block = pg.desc.block(0)
op_size = main_block.op_size() op_size = main_block.op_size()
op_types = [ op_types = [
main_block.op(i).type() == 'tensorrt_engine' for i in range(op_size) main_block.op(i).type() == "tensorrt_engine" for i in range(op_size)
] ]
trt_engine_size = sum(op_types) trt_engine_size = sum(op_types)
paddle_op_size = op_size - trt_engine_size paddle_op_size = op_size - trt_engine_size
self.assertEqual( self.assertEqual(
trt_engine_num, trt_engine_num,
trt_engine_size, trt_engine_size,
'Expected trt_engine_num is {}, but got {}!'.format( f"Expected trt_engine_num is {trt_engine_num}, but got {trt_engine_size}!",
trt_engine_num, trt_engine_size
),
) )
self.assertEqual( self.assertEqual(
paddle_op_num, paddle_op_num,
paddle_op_size, paddle_op_size,
'Expected paddle_op_num is {}, but got {}!'.format( f"Expected paddle_op_num is {paddle_op_num}, but got {paddle_op_size}!",
paddle_op_num, paddle_op_size
),
) )
def inference_config_str(self, config: paddle_infer.Config) -> str: def inference_config_str(self, config: paddle_infer.Config) -> str:
...@@ -734,11 +710,11 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -734,11 +710,11 @@ class TrtLayerAutoScanTest(AutoScanTest):
trt_precison = config.tensorrt_precision_mode() trt_precison = config.tensorrt_precision_mode()
trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled() trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
if enable_trt: if enable_trt:
dic['use_trt'] = True dic["use_trt"] = True
dic['trt_precision'] = trt_precison dic["trt_precision"] = trt_precison
dic['use_dynamic_shape'] = trt_dynamic_shape dic["use_dynamic_shape"] = trt_dynamic_shape
else: else:
dic['use_trt'] = False dic["use_trt"] = False
return str(dic) return str(dic)
def run_test(self, quant=False, skip_baseline=False, *args, **kwargs): def run_test(self, quant=False, skip_baseline=False, *args, **kwargs):
...@@ -765,110 +741,103 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -765,110 +741,103 @@ class TrtLayerAutoScanTest(AutoScanTest):
feed_data = {} feed_data = {}
for name, tensor_config in prog_config.inputs.items(): for name, tensor_config in prog_config.inputs.items():
feed_data[name] = { feed_data[name] = {
'data': tensor_config.data, "data": tensor_config.data,
'lod': tensor_config.lod, "lod": tensor_config.lod,
} }
results: List[Dict[str, np.ndarray]] = []
if not skip_baseline: if not skip_baseline:
# baseline: gpu run # baseline: gpu run, we only test float32
logging.info('RUN program_config: ' + str(prog_config))
gpu_config = self.create_inference_config(use_trt=False) gpu_config = self.create_inference_config(use_trt=False)
results.append( baseline_result = self.run_test_config(
self.run_test_config( model,
model, params, prog_config, gpu_config, feed_data params,
) prog_config.set_input_type(np.float32),
gpu_config,
feed_data,
) )
self.success_log('RUN_GPU_BASELINE done') self.success_log(f"basline program_config: {prog_config}")
for ( for (
pred_config, pred_config,
nodes_num, nodes_num,
threshold, threshold,
) in self.sample_predictor_configs(prog_config): ) in self.sample_predictor_configs(prog_config):
for input_type in self.get_avalible_input_type():
prog_config = prog_config.set_input_type(input_type)
if os.path.exists(self.cache_dir):
shutil.rmtree(self.cache_dir)
if os.path.exists(self.cache_dir): if isinstance(threshold, float):
shutil.rmtree(self.cache_dir) atol = threshold
rtol = 1e-8
if isinstance(threshold, float): elif isinstance(threshold, list) or isinstance(
atol = threshold threshold, tuple
rtol = 1e-8 ):
elif isinstance(threshold, (list, tuple)): atol = threshold[0]
atol = threshold[0] rtol = threshold[1]
rtol = threshold[1] else:
else: raise NotImplementedError
raise NotImplementedError
is_fp8 = (
if ( pred_config.tensorrt_precision_mode()
pred_config.tensorrt_precision_mode() == paddle_infer.PrecisionType.Int8
!= paddle_infer.PrecisionType.Int8 )
and quant if (not is_fp8 and quant) or (is_fp8 and not quant):
): continue
continue
if ( ignore_flag = False
pred_config.tensorrt_precision_mode() for teller, reason, note in self.ignore_cases:
== paddle_infer.PrecisionType.Int8 if teller(prog_config, pred_config):
and not quant ignore_flag = True
): if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED:
continue self.ignore_log(
f"[TRT_NOT_IMPLEMENTED] {note} vs {self.inference_config_str(pred_config)}"
ignore_flag = False
for teller, reason, note in self.ignore_cases:
if teller(prog_config, pred_config):
ignore_flag = True
if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED:
self.ignore_log(
'[TRT_NOT_IMPLEMENTED] {} vs {}'.format(
note, self.inference_config_str(pred_config)
) )
) elif reason == IgnoreReasons.TRT_NOT_SUPPORT:
elif reason == IgnoreReasons.TRT_NOT_SUPPORT: self.ignore_log(
self.ignore_log( f"[TRT_NOT_SUPPORT] {note} vs {self.inference_config_str(pred_config)}"
'[TRT_NOT_SUPPORT] {} vs {}'.format(
note, self.inference_config_str(pred_config)
) )
) else:
else: raise NotImplementedError
raise NotImplementedError break
break
if ignore_flag: if ignore_flag:
continue continue
try: try:
pred_config_deserialize = paddle_infer.Config(pred_config) pred_config_deserialize = paddle_infer.Config(
results.append( pred_config
self.run_test_config( )
trt_result = self.run_test_config(
model, params, prog_config, pred_config, feed_data model, params, prog_config, pred_config, feed_data
) )
) self.assert_tensors_near(
self.assert_tensors_near( atol, rtol, trt_result, baseline_result
atol, rtol, results[-1], results[0]
)
trt_engine_num, paddle_op_num = nodes_num
self.assert_op_size(trt_engine_num, paddle_op_num)
# deserialize test
if trt_engine_num > 0:
self.run_test_config(
model,
params,
prog_config,
pred_config_deserialize,
feed_data,
) )
trt_engine_num, paddle_op_num = nodes_num
self.assert_op_size(trt_engine_num, paddle_op_num)
# deserialize test
if trt_engine_num > 0:
self.run_test_config(
model,
params,
prog_config,
pred_config_deserialize,
feed_data,
)
self.success_log( self.success_log(f"program_config: {prog_config}")
'RUN predictor_config {} done'.format( self.success_log(
self.inference_config_str(pred_config) f"predictor_config: {self.inference_config_str(pred_config)}"
) )
) except Exception as e:
except Exception as e: self.fail_log(f"program_config: {prog_config}")
self.fail_log( self.fail_log(
self.inference_config_str(pred_config) f"predictor_config: {self.inference_config_str(pred_config)}"
+ f'\033[1;31m \nERROR INFO: {str(e)}\033[0m' )
) self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
all_passes = False all_passes = False
self.assertTrue(all_passes) self.assertTrue(all_passes)
......
...@@ -54,8 +54,8 @@ class TensorConfig: ...@@ -54,8 +54,8 @@ class TensorConfig:
if data_gen is not None: if data_gen is not None:
self.data_gen = data_gen self.data_gen = data_gen
self.data = data_gen() self.data = data_gen()
self.dtype = data_gen().dtype self.dtype = self.data.dtype
self.shape = data_gen().shape self.shape = self.data.shape
else: else:
assert ( assert (
shape is not None shape is not None
...@@ -67,6 +67,11 @@ class TensorConfig: ...@@ -67,6 +67,11 @@ class TensorConfig:
def __repr__(self): def __repr__(self):
return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype}) return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype})
def astype(self, type: np.dtype):
self.data = self.data.astype(type)
self.dtype = self.data.dtype
return self
class VarType(enum.Enum): class VarType(enum.Enum):
LOD_TENSOR = 1 LOD_TENSOR = 1
...@@ -270,6 +275,16 @@ class ProgramConfig: ...@@ -270,6 +275,16 @@ class ProgramConfig:
return log_str return log_str
def set_input_type(self, type: np.dtype):
for inp in self.inputs.values():
inp.astype(type)
for weight in self.weights.values():
weight.astype(type)
return self
def get_input_type(self) -> np.dtype:
return next(iter(self.inputs.values())).dtype
def create_fake_model(program_config): def create_fake_model(program_config):
'''Create a Paddle model(in memory) according to the given config.''' '''Create a Paddle model(in memory) according to the given config.'''
......
...@@ -33,6 +33,9 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest): ...@@ -33,6 +33,9 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
def generate_elewise_input(): def generate_elewise_input():
return np.random.random([1, 12, 128, 128]).astype(np.float32) return np.random.random([1, 12, 128, 128]).astype(np.float32)
def generate_weight(shape):
return np.random.random(shape).astype(np.float32)
mul_0 = OpConfig( mul_0 = OpConfig(
"mul", "mul",
inputs={"X": ["mul_x"], "Y": ["mul_0_w"]}, inputs={"X": ["mul_x"], "Y": ["mul_0_w"]},
...@@ -195,13 +198,27 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest): ...@@ -195,13 +198,27 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
), ),
}, },
weights={ weights={
"mul_0_w": TensorConfig(shape=[768, 768]), "mul_0_w": TensorConfig(
"mul_1_w": TensorConfig(shape=[768, 768]), data_gen=partial(generate_weight, [768, 768])
"mul_2_w": TensorConfig(shape=[768, 768]), ),
"mul_3_w": TensorConfig(shape=[768, 768]), "mul_1_w": TensorConfig(
"ele_0_w": TensorConfig(shape=[768]), data_gen=partial(generate_weight, [768, 768])
"ele_1_w": TensorConfig(shape=[768]), ),
"ele_2_w": TensorConfig(shape=[768]), "mul_2_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_3_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"ele_0_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_1_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_2_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
}, },
outputs=[ops[-1].outputs["Out"][0]], outputs=[ops[-1].outputs["Out"][0]],
) )
......
...@@ -103,11 +103,11 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -103,11 +103,11 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
ver = paddle_infer.get_trt_compile_version() ver = paddle_infer.get_trt_compile_version()
trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
if trt_version >= 8400: if trt_version >= 8400:
if self.dims == 1 and not dynamic_shape: if self.dims == 1:
return 0, 3 return 0, 3
return 1, 2 return 1, 2
else: else:
if (self.dims == 1 and not dynamic_shape) or ( if self.dims <= 2 or (
program_config.inputs['input_data'].dtype program_config.inputs['input_data'].dtype
in ['bool', 'int8', 'uint8'] in ['bool', 'int8', 'uint8']
): ):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册