未验证 提交 18934c53 编写于 作者: W Wilber 提交者: GitHub

update trt ut. (#35458)

上级 ffc3d364
......@@ -378,7 +378,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// \return bool Whether the trt dynamic_shape is used.
///
bool tensorrt_dynamic_shape_enabled() const {
return min_input_shape_.empty();
return !min_input_shape_.empty();
}
///
/// \brief Prevent ops running in Paddle-TRT
......
......@@ -35,14 +35,11 @@ class SkipReasons(enum.Enum):
TRT_NOT_IMPLEMENTED = 0
# TRT not support.
TRT_NOT_SUPPORT = 1
# Implement wrong.
ALGO_WRONG = 2
# Quant model, only to run in INT8 mode.
QUANT_MODEL = 3
class AutoScanTest(unittest.TestCase):
def __init__(self, methodName='runTest'):
np.random.seed(1024)
paddle.enable_static()
super(AutoScanTest, self).__init__(methodName)
self.skip_cases = []
......@@ -68,7 +65,7 @@ class AutoScanTest(unittest.TestCase):
self.skip_cases.append((teller, reason, note))
@abc.abstractmethod
def check_program_validity(self, program_config: ProgramConfig) -> bool:
def is_program_valid(self, program_config: ProgramConfig) -> bool:
raise NotImplementedError
def run_test_config(self, model, params, prog_config, pred_config,
......
......@@ -78,7 +78,17 @@ class ProgramConfig:
inputs: Dict[str, TensorConfig],
outputs: List[str]):
self.ops = ops
self.weights = weights
# if no weight need to save, we create a place_holder to help seriazlie params.
if not weights:
def generate_weight():
return np.array([1]).astype(np.float32)
self.weights = {
"place_holder_weight": TensorConfig(data_gen=generate_weight)
}
else:
self.weights = weights
self.inputs = inputs
self.outputs = outputs
......
......@@ -21,7 +21,7 @@ from typing import Optional, List, Callable, Dict, Any, Set
class TrtConvertConv2dTest(TrtLayerAutoScanTest):
def check_program_validity(self, program_config: ProgramConfig) -> bool:
def is_program_valid(self, program_config: ProgramConfig) -> bool:
# TODO: This is just the example to remove the wrong attrs.
inputs = program_config.inputs
weights = program_config.weights
......@@ -51,19 +51,19 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
def generate_weight1(attrs: List[Dict[str, Any]]):
return np.random.random([24, 3, 3, 3]).astype(np.float32)
# for strides in [[1,1], [2,2]]:
# for paddings in [[0,3], [3,1]]:
# for groups in [1]:
# for padding_algotithm in ['EXPLICIT']:
# for dilations in [[1,1]]:
# for strides in [[1, 1], [2, 2], [1, 2], [2, 3]]:
# for paddings in [[0, 3], [3, 1], [1, 1, 1, 1]]:
# for groups in [1, 2]:
# for padding_algotithm in ['EXPLICIT', 'SAME', 'VALID']:
# for dilations in [[1, 1], [1, 2]]:
# for data_format in ['NCHW']:
for strides in [[1, 1], [2, 2], [1, 2], [2, 3]]:
for paddings in [[0, 3], [3, 1], [1, 1, 1, 1], [2, 1, 1, 3]]:
for groups in [1, 2]:
for padding_algotithm in ['EXPLICIT', 'SAME', 'VALID']:
for dilations in [[1, 1], [1, 2]]:
for strides in [[1, 1], [2, 2]]:
for paddings in [[0, 3], [3, 1]]:
for groups in [1]:
for padding_algotithm in ['EXPLICIT']:
for dilations in [[1, 1]]:
for data_format in ['NCHW']:
dics = [{
"data_fromat": data_format,
"dilations": dilations,
......@@ -110,11 +110,6 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
},
outputs=["relu_output_data"])
# if config is invalid, we should skip that cases.
if not self.check_program_validity(
program_config):
continue
yield program_config
def sample_predictor_configs(
......@@ -144,10 +139,15 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
"input_data": [1, 3, 64, 64]
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
# TODO: This is just the example, need to be fixed.
if len(attrs[0]['paddings']) == 4:
return 0, 3
return 1, 2
else:
return 1, 2
......@@ -157,6 +157,7 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
......@@ -182,25 +183,16 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
def add_skip_trt_case(self):
# TODO(wilber): This is just the example to illustrate the skip usage.
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['groups'] == 2:
return True
return False
self.add_skip_case(
teller1, SkipReasons.ALGO_WRONG,
"Need to repair the case: ......TODO, just for the example")
def teller2(program_config, predictor_config):
if len(program_config.ops[0].attrs['paddings']) == 4:
return True
return False
self.add_skip_case(
teller2, SkipReasons.TRT_NOT_IMPLEMENTED,
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"NOT Implemented: we need to add support in the future ....TODO, just for the example"
)
def teller3(program_config, predictor_config):
def teller2(program_config, predictor_config):
if (
program_config.ops[0].attrs['dilations'][0] == 1 and
program_config.ops[0].attrs['dilations'][0] == 2
......@@ -208,19 +200,9 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
return True
return False
self.add_skip_case(teller3, SkipReasons.TRT_NOT_SUPPORT,
"TODO, just for the example")
def teller4(program_config, predictor_config):
if program_config.ops[0].attrs['strides'][0] != program_config.ops[
0].attrs['strides'][1] or program_config.ops[0].attrs[
'strides'][0] == program_config.ops[0].attrs['strides'][
1] == 2:
return True
return False
self.add_skip_case(teller4, SkipReasons.TRT_NOT_SUPPORT,
self.add_skip_case(teller2, SkipReasons.TRT_NOT_SUPPORT,
"TODO, just for the example")
pass
def test(self):
self.add_skip_trt_case()
......
......@@ -51,11 +51,11 @@ class TrtLayerAutoScanTest(AutoScanTest):
Prepare TensorRT subgraph engine dynamic shape parameters.
'''
def __init__(self, min_input_shape, max_input_shape, optim_input_shape,
def __init__(self, min_input_shape, max_input_shape, opt_input_shape,
disable_trt_plugin_fp16):
self.min_input_shape = min_input_shape
self.max_input_shape = max_input_shape
self.optim_input_shape = optim_input_shape
self.opt_input_shape = opt_input_shape
self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16
def __init__(self, methodName='runTest'):
......@@ -161,28 +161,13 @@ class TrtLayerAutoScanTest(AutoScanTest):
return str(dic)
def run_test(self, quant=False):
if quant:
def teller(program_config, predictor_config):
if predictor_config.tensorrt_precision_mode(
) == paddle_infer.PrecisionType.Int8:
return False
return True
self.add_skip_case(teller, SkipReasons.QUANT_MODEL,
"Only test QUANT model")
else:
def teller(program_config, predictor_config):
if predictor_config.tensorrt_precision_mode(
) == paddle_infer.PrecisionType.Int8:
return True
return False
self.add_skip_case(teller, SkipReasons.QUANT_MODEL,
"Not test QUANT model")
status = True
for prog_config in self.sample_program_configs():
# if program is invalid, we should skip that cases.
if not self.is_program_valid(prog_config):
continue
model, params = create_fake_model(prog_config)
if quant:
model, params = create_quant_model(model, params)
......@@ -206,15 +191,18 @@ class TrtLayerAutoScanTest(AutoScanTest):
for pred_config, nodes_num, threshold in self.sample_predictor_configs(
prog_config):
if quant and pred_config.tensorrt_precision_mode(
) != paddle_infer.PrecisionType.Int8:
continue
if pred_config.tensorrt_precision_mode(
) == paddle_infer.PrecisionType.Int8 and not quant:
continue
skip_flag = False
for skip_info in self.skip_cases:
if skip_info[0](prog_config, pred_config):
skip_flag = True
if skip_info[1] == SkipReasons.ALGO_WRONG:
self.skip_log("[ALGO_WRONG] " + skip_info[
2] + ' ' + repr(prog_config) + ' vs ' + self.
inference_config_str(pred_config))
elif skip_info[1] == SkipReasons.TRT_NOT_IMPLEMENTED:
if skip_info[1] == SkipReasons.TRT_NOT_IMPLEMENTED:
self.skip_log("[TRT_NOT_IMPLEMENTED] " + skip_info[
2] + ' ' + repr(prog_config) + ' vs ' + self.
inference_config_str(pred_config))
......@@ -222,24 +210,28 @@ class TrtLayerAutoScanTest(AutoScanTest):
self.skip_log("[TRT_NOT_SUPPORT] " + skip_info[
2] + ' ' + repr(prog_config) + ' vs ' + self.
inference_config_str(pred_config))
elif skip_info[1] == SkipReasons.QUANT_MODEL:
pass
else:
raise NotImplementedError
if skip_flag:
continue
break
try:
results.append(
self.run_test_config(model, params, prog_config,
pred_config, feed_data))
self.assert_tensors_near(threshold, results[-1], results[0])
self.assert_op_size(nodes_num[0], nodes_num[1])
except Exception as e:
self.fail_log(
str(prog_config) + ' vs ' + self.inference_config_str(
pred_config) + str(e))
pred_config) +
'\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e)))
status = False
continue
if not skip_flag:
self.assert_op_size(nodes_num[0], nodes_num[1])
self.success_log('RUN ' + str(prog_config) + ' vs ' +
self.inference_config_str(pred_config))
# In the first step, we found the problem, and after the subsequent repairs, the assert assertion will be enabled
# self.assertTrue(status)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册