diff --git a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py
index 6957a4ceb26deab9036db08cdc02ff6e10294a33..3ac185fbb04aca5a278ea50a86630e421ba62e49 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py
@@ -25,6 +25,7 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.inference as paddle_infer
+import shutil
 from paddle import compat as cpt
 from typing import Optional, List, Callable, Dict, Any, Set
 
@@ -68,18 +69,21 @@ class TrtLayerAutoScanTest(AutoScanTest):
             max_batch_size=4,
             min_subgraph_size=0,
             precision=paddle_infer.PrecisionType.Float32,
-            use_static=False,
+            use_static=True,
             use_calib_mode=False)
         self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)
         self.num_percent_cases = float(
             os.getenv(
                 'TEST_NUM_PERCENT_CASES', default='1.0'))
+        abs_dir = os.path.abspath(os.path.dirname(__file__))
+        cache_dir = str(self.__module__) + '_trt_cache_dir'
+        self.trt_cache_dir = os.path.join(abs_dir, cache_dir)
 
     def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
         config = paddle_infer.Config()
-        config.disable_glog_info()
+        # config.disable_glog_info()
         config.enable_use_gpu(100, 0)
-        config.set_optim_cache_dir('trt_convert_cache_dir')
+        config.set_optim_cache_dir(self.trt_cache_dir)
         if use_trt:
             config.switch_ir_debug()
             config.enable_tensorrt_engine(
@@ -218,6 +222,9 @@ class TrtLayerAutoScanTest(AutoScanTest):
 
             for pred_config, nodes_num, threshold in self.sample_predictor_configs(
                     prog_config):
+                if os.path.exists(self.trt_cache_dir):
+                    shutil.rmtree(self.trt_cache_dir)
+
                 if isinstance(threshold, float):
                     atol = threshold
                     rtol = 1e-8
@@ -261,9 +268,9 @@ class TrtLayerAutoScanTest(AutoScanTest):
                     if not skip_flag:
                         self.assert_op_size(nodes_num[0], nodes_num[1])
                     # deserialize test
-                    #if nodes_num[0] > 0:
-                    #    self.run_test_config(model, params, prog_config,
-                    #        pred_config_deserialize, feed_data)
+                    if nodes_num[0] > 0:
+                        self.run_test_config(model, params, prog_config,
+                                             pred_config_deserialize, feed_data)
                 except Exception as e:
                     self.fail_log(
                         str(prog_config) + ' vs ' + self.inference_config_str(