Unverified commit ba71421c, authored by: W Wilber, committed by: GitHub

trt support serialize and deserialize (#35828)

Parent commit: 2fff5a58
...@@ -25,6 +25,7 @@ import paddle ...@@ -25,6 +25,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.inference as paddle_infer import paddle.inference as paddle_infer
import shutil
from paddle import compat as cpt from paddle import compat as cpt
from typing import Optional, List, Callable, Dict, Any, Set from typing import Optional, List, Callable, Dict, Any, Set
...@@ -68,18 +69,21 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -68,18 +69,21 @@ class TrtLayerAutoScanTest(AutoScanTest):
max_batch_size=4, max_batch_size=4,
min_subgraph_size=0, min_subgraph_size=0,
precision=paddle_infer.PrecisionType.Float32, precision=paddle_infer.PrecisionType.Float32,
use_static=False, use_static=True,
use_calib_mode=False) use_calib_mode=False)
self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False) self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)
self.num_percent_cases = float( self.num_percent_cases = float(
os.getenv( os.getenv(
'TEST_NUM_PERCENT_CASES', default='1.0')) 'TEST_NUM_PERCENT_CASES', default='1.0'))
abs_dir = os.path.abspath(os.path.dirname(__file__))
cache_dir = str(self.__module__) + '_trt_cache_dir'
self.trt_cache_dir = os.path.join(abs_dir, cache_dir)
def create_inference_config(self, use_trt=True) -> paddle_infer.Config: def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
config = paddle_infer.Config() config = paddle_infer.Config()
config.disable_glog_info() # config.disable_glog_info()
config.enable_use_gpu(100, 0) config.enable_use_gpu(100, 0)
config.set_optim_cache_dir('trt_convert_cache_dir') config.set_optim_cache_dir(self.trt_cache_dir)
if use_trt: if use_trt:
config.switch_ir_debug() config.switch_ir_debug()
config.enable_tensorrt_engine( config.enable_tensorrt_engine(
...@@ -218,6 +222,9 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -218,6 +222,9 @@ class TrtLayerAutoScanTest(AutoScanTest):
for pred_config, nodes_num, threshold in self.sample_predictor_configs( for pred_config, nodes_num, threshold in self.sample_predictor_configs(
prog_config): prog_config):
if os.path.exists(self.trt_cache_dir):
shutil.rmtree(self.trt_cache_dir)
if isinstance(threshold, float): if isinstance(threshold, float):
atol = threshold atol = threshold
rtol = 1e-8 rtol = 1e-8
...@@ -261,9 +268,9 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -261,9 +268,9 @@ class TrtLayerAutoScanTest(AutoScanTest):
if not skip_flag: if not skip_flag:
self.assert_op_size(nodes_num[0], nodes_num[1]) self.assert_op_size(nodes_num[0], nodes_num[1])
# deserialize test # deserialize test
#if nodes_num[0] > 0: if nodes_num[0] > 0:
# self.run_test_config(model, params, prog_config, self.run_test_config(model, params, prog_config,
# pred_config_deserialize, feed_data) pred_config_deserialize, feed_data)
except Exception as e: except Exception as e:
self.fail_log( self.fail_log(
str(prog_config) + ' vs ' + self.inference_config_str( str(prog_config) + ' vs ' + self.inference_config_str(
......
Markdown is supported.
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.