未验证 提交 18934c53 编写于 作者: W Wilber 提交者: GitHub

update trt ut. (#35458)

上级 ffc3d364
...@@ -378,7 +378,7 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -378,7 +378,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// \return bool Whether the trt dynamic_shape is used. /// \return bool Whether the trt dynamic_shape is used.
/// ///
bool tensorrt_dynamic_shape_enabled() const { bool tensorrt_dynamic_shape_enabled() const {
return min_input_shape_.empty(); return !min_input_shape_.empty();
} }
/// ///
/// \brief Prevent ops running in Paddle-TRT /// \brief Prevent ops running in Paddle-TRT
......
...@@ -35,14 +35,11 @@ class SkipReasons(enum.Enum): ...@@ -35,14 +35,11 @@ class SkipReasons(enum.Enum):
TRT_NOT_IMPLEMENTED = 0 TRT_NOT_IMPLEMENTED = 0
# TRT not support. # TRT not support.
TRT_NOT_SUPPORT = 1 TRT_NOT_SUPPORT = 1
# Implement wrong.
ALGO_WRONG = 2
# Quant model, only to run in INT8 mode.
QUANT_MODEL = 3
class AutoScanTest(unittest.TestCase): class AutoScanTest(unittest.TestCase):
def __init__(self, methodName='runTest'): def __init__(self, methodName='runTest'):
np.random.seed(1024)
paddle.enable_static() paddle.enable_static()
super(AutoScanTest, self).__init__(methodName) super(AutoScanTest, self).__init__(methodName)
self.skip_cases = [] self.skip_cases = []
...@@ -68,7 +65,7 @@ class AutoScanTest(unittest.TestCase): ...@@ -68,7 +65,7 @@ class AutoScanTest(unittest.TestCase):
self.skip_cases.append((teller, reason, note)) self.skip_cases.append((teller, reason, note))
@abc.abstractmethod @abc.abstractmethod
def check_program_validity(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
raise NotImplementedError raise NotImplementedError
def run_test_config(self, model, params, prog_config, pred_config, def run_test_config(self, model, params, prog_config, pred_config,
......
...@@ -78,7 +78,17 @@ class ProgramConfig: ...@@ -78,7 +78,17 @@ class ProgramConfig:
inputs: Dict[str, TensorConfig], inputs: Dict[str, TensorConfig],
outputs: List[str]): outputs: List[str]):
self.ops = ops self.ops = ops
self.weights = weights # if no weight need to save, we create a place_holder to help seriazlie params.
if not weights:
def generate_weight():
return np.array([1]).astype(np.float32)
self.weights = {
"place_holder_weight": TensorConfig(data_gen=generate_weight)
}
else:
self.weights = weights
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
......
...@@ -21,7 +21,7 @@ from typing import Optional, List, Callable, Dict, Any, Set ...@@ -21,7 +21,7 @@ from typing import Optional, List, Callable, Dict, Any, Set
class TrtConvertConv2dTest(TrtLayerAutoScanTest): class TrtConvertConv2dTest(TrtLayerAutoScanTest):
def check_program_validity(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
# TODO: This is just the example to remove the wrong attrs. # TODO: This is just the example to remove the wrong attrs.
inputs = program_config.inputs inputs = program_config.inputs
weights = program_config.weights weights = program_config.weights
...@@ -51,19 +51,19 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -51,19 +51,19 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
def generate_weight1(attrs: List[Dict[str, Any]]): def generate_weight1(attrs: List[Dict[str, Any]]):
return np.random.random([24, 3, 3, 3]).astype(np.float32) return np.random.random([24, 3, 3, 3]).astype(np.float32)
# for strides in [[1,1], [2,2]]: # for strides in [[1, 1], [2, 2], [1, 2], [2, 3]]:
# for paddings in [[0,3], [3,1]]: # for paddings in [[0, 3], [3, 1], [1, 1, 1, 1]]:
# for groups in [1]: # for groups in [1, 2]:
# for padding_algotithm in ['EXPLICIT']: # for padding_algotithm in ['EXPLICIT', 'SAME', 'VALID']:
# for dilations in [[1,1]]: # for dilations in [[1, 1], [1, 2]]:
# for data_format in ['NCHW']: # for data_format in ['NCHW']:
for strides in [[1, 1], [2, 2]]:
for strides in [[1, 1], [2, 2], [1, 2], [2, 3]]: for paddings in [[0, 3], [3, 1]]:
for paddings in [[0, 3], [3, 1], [1, 1, 1, 1], [2, 1, 1, 3]]: for groups in [1]:
for groups in [1, 2]: for padding_algotithm in ['EXPLICIT']:
for padding_algotithm in ['EXPLICIT', 'SAME', 'VALID']: for dilations in [[1, 1]]:
for dilations in [[1, 1], [1, 2]]:
for data_format in ['NCHW']: for data_format in ['NCHW']:
dics = [{ dics = [{
"data_fromat": data_format, "data_fromat": data_format,
"dilations": dilations, "dilations": dilations,
...@@ -110,11 +110,6 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -110,11 +110,6 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
}, },
outputs=["relu_output_data"]) outputs=["relu_output_data"])
# if config is invalid, we should skip that cases.
if not self.check_program_validity(
program_config):
continue
yield program_config yield program_config
def sample_predictor_configs( def sample_predictor_configs(
...@@ -144,10 +139,15 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -144,10 +139,15 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
"input_data": [1, 3, 64, 64] "input_data": [1, 3, 64, 64]
} }
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape): def generate_trt_nodes_num(attrs, dynamic_shape):
# TODO: This is just the example, need to be fixed. # TODO: This is just the example, need to be fixed.
if len(attrs[0]['paddings']) == 4: if len(attrs[0]['paddings']) == 4:
return 0, 3 return 1, 2
else: else:
return 1, 2 return 1, 2
...@@ -157,6 +157,7 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -157,6 +157,7 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
] ]
# for static_shape # for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5 attrs, False), 1e-5
...@@ -182,25 +183,16 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -182,25 +183,16 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
def add_skip_trt_case(self): def add_skip_trt_case(self):
# TODO(wilber): This is just the example to illustrate the skip usage. # TODO(wilber): This is just the example to illustrate the skip usage.
def teller1(program_config, predictor_config): def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['groups'] == 2:
return True
return False
self.add_skip_case(
teller1, SkipReasons.ALGO_WRONG,
"Need to repair the case: ......TODO, just for the example")
def teller2(program_config, predictor_config):
if len(program_config.ops[0].attrs['paddings']) == 4: if len(program_config.ops[0].attrs['paddings']) == 4:
return True return True
return False return False
self.add_skip_case( self.add_skip_case(
teller2, SkipReasons.TRT_NOT_IMPLEMENTED, teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"NOT Implemented: we need to add support in the future ....TODO, just for the example" "NOT Implemented: we need to add support in the future ....TODO, just for the example"
) )
def teller3(program_config, predictor_config): def teller2(program_config, predictor_config):
if ( if (
program_config.ops[0].attrs['dilations'][0] == 1 and program_config.ops[0].attrs['dilations'][0] == 1 and
program_config.ops[0].attrs['dilations'][0] == 2 program_config.ops[0].attrs['dilations'][0] == 2
...@@ -208,19 +200,9 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -208,19 +200,9 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
return True return True
return False return False
self.add_skip_case(teller3, SkipReasons.TRT_NOT_SUPPORT, self.add_skip_case(teller2, SkipReasons.TRT_NOT_SUPPORT,
"TODO, just for the example")
def teller4(program_config, predictor_config):
if program_config.ops[0].attrs['strides'][0] != program_config.ops[
0].attrs['strides'][1] or program_config.ops[0].attrs[
'strides'][0] == program_config.ops[0].attrs['strides'][
1] == 2:
return True
return False
self.add_skip_case(teller4, SkipReasons.TRT_NOT_SUPPORT,
"TODO, just for the example") "TODO, just for the example")
pass
def test(self): def test(self):
self.add_skip_trt_case() self.add_skip_trt_case()
......
...@@ -51,11 +51,11 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -51,11 +51,11 @@ class TrtLayerAutoScanTest(AutoScanTest):
Prepare TensorRT subgraph engine dynamic shape parameters. Prepare TensorRT subgraph engine dynamic shape parameters.
''' '''
def __init__(self, min_input_shape, max_input_shape, optim_input_shape, def __init__(self, min_input_shape, max_input_shape, opt_input_shape,
disable_trt_plugin_fp16): disable_trt_plugin_fp16):
self.min_input_shape = min_input_shape self.min_input_shape = min_input_shape
self.max_input_shape = max_input_shape self.max_input_shape = max_input_shape
self.optim_input_shape = optim_input_shape self.opt_input_shape = opt_input_shape
self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16 self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16
def __init__(self, methodName='runTest'): def __init__(self, methodName='runTest'):
...@@ -161,28 +161,13 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -161,28 +161,13 @@ class TrtLayerAutoScanTest(AutoScanTest):
return str(dic) return str(dic)
def run_test(self, quant=False): def run_test(self, quant=False):
if quant: status = True
def teller(program_config, predictor_config):
if predictor_config.tensorrt_precision_mode(
) == paddle_infer.PrecisionType.Int8:
return False
return True
self.add_skip_case(teller, SkipReasons.QUANT_MODEL,
"Only test QUANT model")
else:
def teller(program_config, predictor_config):
if predictor_config.tensorrt_precision_mode(
) == paddle_infer.PrecisionType.Int8:
return True
return False
self.add_skip_case(teller, SkipReasons.QUANT_MODEL,
"Not test QUANT model")
for prog_config in self.sample_program_configs(): for prog_config in self.sample_program_configs():
# if program is invalid, we should skip that cases.
if not self.is_program_valid(prog_config):
continue
model, params = create_fake_model(prog_config) model, params = create_fake_model(prog_config)
if quant: if quant:
model, params = create_quant_model(model, params) model, params = create_quant_model(model, params)
...@@ -206,15 +191,18 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -206,15 +191,18 @@ class TrtLayerAutoScanTest(AutoScanTest):
for pred_config, nodes_num, threshold in self.sample_predictor_configs( for pred_config, nodes_num, threshold in self.sample_predictor_configs(
prog_config): prog_config):
if quant and pred_config.tensorrt_precision_mode(
) != paddle_infer.PrecisionType.Int8:
continue
if pred_config.tensorrt_precision_mode(
) == paddle_infer.PrecisionType.Int8 and not quant:
continue
skip_flag = False skip_flag = False
for skip_info in self.skip_cases: for skip_info in self.skip_cases:
if skip_info[0](prog_config, pred_config): if skip_info[0](prog_config, pred_config):
skip_flag = True skip_flag = True
if skip_info[1] == SkipReasons.ALGO_WRONG: if skip_info[1] == SkipReasons.TRT_NOT_IMPLEMENTED:
self.skip_log("[ALGO_WRONG] " + skip_info[
2] + ' ' + repr(prog_config) + ' vs ' + self.
inference_config_str(pred_config))
elif skip_info[1] == SkipReasons.TRT_NOT_IMPLEMENTED:
self.skip_log("[TRT_NOT_IMPLEMENTED] " + skip_info[ self.skip_log("[TRT_NOT_IMPLEMENTED] " + skip_info[
2] + ' ' + repr(prog_config) + ' vs ' + self. 2] + ' ' + repr(prog_config) + ' vs ' + self.
inference_config_str(pred_config)) inference_config_str(pred_config))
...@@ -222,24 +210,28 @@ class TrtLayerAutoScanTest(AutoScanTest): ...@@ -222,24 +210,28 @@ class TrtLayerAutoScanTest(AutoScanTest):
self.skip_log("[TRT_NOT_SUPPORT] " + skip_info[ self.skip_log("[TRT_NOT_SUPPORT] " + skip_info[
2] + ' ' + repr(prog_config) + ' vs ' + self. 2] + ' ' + repr(prog_config) + ' vs ' + self.
inference_config_str(pred_config)) inference_config_str(pred_config))
elif skip_info[1] == SkipReasons.QUANT_MODEL:
pass
else: else:
raise NotImplementedError raise NotImplementedError
if skip_flag: break
continue
try: try:
results.append( results.append(
self.run_test_config(model, params, prog_config, self.run_test_config(model, params, prog_config,
pred_config, feed_data)) pred_config, feed_data))
self.assert_tensors_near(threshold, results[-1], results[0]) self.assert_tensors_near(threshold, results[-1], results[0])
self.assert_op_size(nodes_num[0], nodes_num[1])
except Exception as e: except Exception as e:
self.fail_log( self.fail_log(
str(prog_config) + ' vs ' + self.inference_config_str( str(prog_config) + ' vs ' + self.inference_config_str(
pred_config) + str(e)) pred_config) +
'\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e)))
status = False
continue continue
if not skip_flag:
self.assert_op_size(nodes_num[0], nodes_num[1])
self.success_log('RUN ' + str(prog_config) + ' vs ' + self.success_log('RUN ' + str(prog_config) + ' vs ' +
self.inference_config_str(pred_config)) self.inference_config_str(pred_config))
# In the first step, we found the problem, and after the subsequent repairs, the assert assertion will be enabled
# self.assertTrue(status)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册