Unverified commit 645e81f0 authored by: F Frank Lin, committed by: GitHub

Improve stability of Paddle-TensorRT FP16 UT (1) (#51554)

* Improve Readability and Overall Clarity of Logging

* Adds the set_input_type API for specifying input data types

* Specifying input data types
Parent 4b85e5db
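As a rough usage sketch of the new API (the test class name and import path below are hypothetical; `get_avalible_input_type` and `set_input_type` are the hooks this commit adds), an FP16-capable TensorRT layer test now only declares the extra dtype and the runner casts and replays each generated program:

```python
# Hypothetical opt-in to FP16 testing; assumes the auto_scan_test module
# from this diff is importable. Only the override is needed: run_test()
# calls prog_config.set_input_type(input_type) for each returned dtype.
import numpy as np

from auto_scan_test import TrtLayerAutoScanTest  # assumed import path


class TrtConvertExampleTest(TrtLayerAutoScanTest):  # placeholder name
    def get_avalible_input_type(self):
        # The base class returns [np.float32]; adding np.float16 makes the
        # runner replay every ProgramConfig with its inputs and weights
        # cast to FP16, while the GPU baseline stays float32.
        return [np.float32, np.float16]
```

The baseline always runs with `prog_config.set_input_type(np.float32)` on the GPU, so FP16 results are checked against an FP32 reference within the per-config `atol`/`rtol` thresholds.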
@@ -36,7 +36,8 @@ import paddle
import paddle.inference as paddle_infer
from paddle.fluid.core import PassVersionChecker
logging.basicConfig(level=logging.INFO, format="%(message)s")
LOGLEVEL = os.environ.get("PADDLE_TEST_LOGLEVEL", "INFO").upper()
logging.basicConfig(level=LOGLEVEL, format="%(message)s")
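# Running with PADDLE_TEST_LOGLEVEL=DEBUG restores the per-case
# SKIP/INFO/SUCCESS messages, which ignore_log/info_log/success_log
# below now emit at DEBUG level.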
settings.register_profile(
"ci",
@@ -57,8 +58,8 @@ settings.register_profile(
report_multiple_bugs=False,
)
if (
float(os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')) < 1
or os.getenv('HYPOTHESIS_TEST_PROFILE', 'dev') == 'ci'
float(os.getenv("TEST_NUM_PERCENT_CASES", default="1.0")) < 1
or os.getenv("HYPOTHESIS_TEST_PROFILE", "dev") == "ci"
):
settings.load_profile("ci")
else:
@@ -100,10 +101,10 @@ class AutoScanTest(unittest.TestCase):
@abc.abstractmethod
def sample_program_configs(self):
'''
"""
Generate all config with the combination of different Input tensor shape and
different Attr values.
'''
"""
raise NotImplementedError
@abc.abstractmethod
@@ -125,9 +126,9 @@ class AutoScanTest(unittest.TestCase):
def run_test_config(
self, model, params, prog_config, pred_config, feed_data
) -> Dict[str, np.ndarray]:
'''
"""
Test a single case.
'''
"""
pred_config.set_model_buffer(model, len(model), params, len(params))
predictor = paddle_infer.create_predictor(pred_config)
self.available_passes_in_framework = (
@@ -136,9 +137,9 @@ class AutoScanTest(unittest.TestCase):
)
for name, _ in prog_config.inputs.items():
input_tensor = predictor.get_input_handle(name)
input_tensor.copy_from_cpu(feed_data[name]['data'])
if feed_data[name]['lod'] is not None:
input_tensor.set_lod(feed_data[name]['lod'])
input_tensor.copy_from_cpu(feed_data[name]["data"])
if feed_data[name]["lod"] is not None:
input_tensor.set_lod(feed_data[name]["lod"])
predictor.run()
result = {}
for out_name, o_name in zip(
@@ -158,10 +159,7 @@ class AutoScanTest(unittest.TestCase):
for key, arr in tensor.items():
self.assertTrue(
baseline[key].shape == arr.shape,
"The output shapes are not equal, the baseline shape is "
+ str(baseline[key].shape)
+ ', but got '
+ str(arr.shape),
f"The output shapes are not equal, the baseline shape is {baseline[key].shape}, but got {str(arr.shape)}",
)
diff = abs(baseline[key] - arr)
np.testing.assert_allclose(
@@ -169,9 +167,7 @@ class AutoScanTest(unittest.TestCase):
arr,
rtol=rtol,
atol=atol,
err_msg='Output has diff, Maximum absolute error: {}'.format(
np.amax(diff)
),
err_msg=f"Output has diff, Maximum absolute error: {np.amax(diff)}",
)
@abc.abstractmethod
@@ -207,15 +203,19 @@ class AutoScanTest(unittest.TestCase):
@abc.abstractmethod
def ignore_log(self, msg: str):
logging.warning("SKIP: " + msg)
logging.debug(f"SKIP: {msg}")
@abc.abstractmethod
def fail_log(self, msg: str):
logging.error("FAIL: " + msg)
logging.error(f"FAIL: {msg}")
@abc.abstractmethod
def info_log(self, msg: str):
logging.debug(f"INFO: {msg}")
@abc.abstractmethod
def success_log(self, msg: str):
logging.info("SUCCESS: " + msg)
logging.debug(f"SUCCESS: {msg}")
@abc.abstractmethod
def create_inference_config(
@@ -263,20 +263,22 @@ class MkldnnAutoScanTest(AutoScanTest):
feed_data = {}
for name, tensor_config in prog_config.inputs.items():
feed_data[name] = {
'data': tensor_config.data,
'lod': tensor_config.lod,
"data": tensor_config.data,
"lod": tensor_config.lod,
}
results: List[Dict[str, np.ndarray]] = []
# baseline: cpu no ir_optim run
base_config = self.create_inference_config(ir_optim=False)
logging.info('RUN program_config: ' + str(prog_config))
results.append(
self.run_test_config(
model, params, prog_config, base_config, feed_data
)
)
self.success_log('RUN_CPU_BASELINE done')
self.success_log(f"basline program_config: {prog_config}")
self.success_log(
f"basline predictor_config: {self.inference_config_str(base_config)}"
)
for pred_config, (atol, rtol) in self.sample_predictor_configs(
prog_config
@@ -291,11 +293,7 @@ class MkldnnAutoScanTest(AutoScanTest):
== IgnoreReasons.MKLDNN_ACCURACY_ERROR
):
self.ignore_log(
"[MKLDNN_ACCURACY_ERROR] "
+ ignore_info[2]
+ ' '
+ ' vs '
+ self.inference_config_str(pred_config)
f"[MKLDNN_ACCURACY_ERROR] {ignore_info[2]} vs {self.inference_config_str(pred_config)}"
)
else:
raise NotImplementedError
@@ -315,28 +313,29 @@ class MkldnnAutoScanTest(AutoScanTest):
self.assert_tensors_near(
atol, rtol, results[-1], results[0]
)
self.success_log(f"program_config: {prog_config}")
self.success_log(
f"predictor_config: {self.inference_config_str(pred_config)}"
)
except Exception as e:
self.fail_log(f"program_config: {prog_config}")
self.fail_log(
self.inference_config_str(pred_config)
+ f'\033[1;31m \nERROR INFO: {str(e)}\033[0m'
f"predictor_config: {self.inference_config_str(pred_config)}"
)
self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
if not ignore_flag:
status = False
continue
self.success_log(
'RUN predictor_config '
+ self.inference_config_str(pred_config)
+ ' done'
)
self.assertTrue(status)
def inference_config_str(self, config) -> str:
dic = {}
enable_mkldnn = config.mkldnn_enabled()
dic['use_mkldnn'] = enable_mkldnn
dic["use_mkldnn"] = enable_mkldnn
enable_gpu = config.use_gpu()
dic['use_gpu'] = enable_gpu
dic["use_gpu"] = enable_gpu
return str(dic)
@@ -351,7 +350,7 @@ class PassAutoScanTest(AutoScanTest):
if pass_name not in self.available_passes_in_framework:
continue
if not PassVersionChecker.IsCompatible(pass_name):
self.fail_log(f'{pass_name} version check failed.')
self.fail_log(f"{pass_name} version check failed.")
status = False
return status
@@ -368,9 +367,7 @@ class PassAutoScanTest(AutoScanTest):
)
if not os.path.exists(last_passed_program):
raise ValueError(
"Cannot find file {}, please make sure that your pass name is correct".format(
last_passed_program
)
f"Cannot find file {last_passed_program}, please make sure that your pass name is correct"
)
model_bytes = paddle.static.load_from_file(last_passed_program)
pg = paddle.static.deserialize_program(model_bytes)
@@ -382,9 +379,7 @@ class PassAutoScanTest(AutoScanTest):
after_op_list.append(main_block.op(i).type())
self.assertTrue(
op_list_after_fusion == after_op_list,
"Expected operator list after fusion is {}, but now it's {}".format(
op_list_after_fusion, after_op_list
),
f"Expected operator list after fusion is {op_list_after_fusion}, but now it's {after_op_list}",
)
def run_and_statis(
@@ -396,10 +391,10 @@
max_duration=180,
passes=None,
):
if os.getenv('HYPOTHESIS_TEST_PROFILE', 'ci') == "dev":
if os.getenv("HYPOTHESIS_TEST_PROFILE", "ci") == "dev":
max_examples *= 10
min_success_num *= 10
# while at ce phase, there's no limit on time
# while at ce phase, there"s no limit on time
max_duration = -1
start_time = time.time()
settings.register_profile(
@@ -431,13 +426,11 @@ class PassAutoScanTest(AutoScanTest):
loop_func = reproduce(loop_func)
logging.info(f"Start to running test of {type(self)}")
loop_func()
logging.info(
self.info_log(
"===================Statistical Information==================="
)
logging.info(
"Number of Generated Programs: {}".format(
self.num_ran_programs + self.num_invalid_programs
)
self.info_log(
f"Number of Generated Programs: {self.num_ran_programs + self.num_invalid_programs}"
)
logging.info(f"Number of Invalid Programs: {self.num_invalid_programs}")
logging.info(f"Number of Ran Programs: {self.num_ran_programs}")
@@ -446,27 +439,21 @@ class PassAutoScanTest(AutoScanTest):
self.num_ran_programs
- self.num_ignore_tests / max(self.num_predictor_kinds, 1)
)
logging.info(
"Number of successfully ran programs approximately equal to {}".format(
successful_ran_programs
)
self.info_log(
f"Number of successfully ran programs approximately equal to {successful_ran_programs}"
)
if successful_ran_programs < min_success_num:
logging.warning(
self.fail_log(
"satisfied_programs = ran_programs - num_ignore_tests / num_predictor_kinds"
)
logging.error(
"At least {} programs need to ran successfully, but now only about {} programs satisfied.".format(
min_success_num, successful_ran_programs
)
self.fail_log(
f"At least {min_success_num} programs need to ran successfully, but now only about {successful_ran_programs} programs satisfied."
)
raise AssertionError()
used_time = time.time() - start_time
if max_duration > 0 and used_time > max_duration:
logging.error(
"The duration exceeds {} seconds, if this is necessary, try to set a larger number for parameter `max_duration`.".format(
max_duration
)
self.fail_log(
f"The duration exceeds {max_duration} seconds, if this is necessary, try to set a larger number for parameter `max_duration`."
)
raise AssertionError()
@@ -486,11 +473,10 @@ class PassAutoScanTest(AutoScanTest):
feed_data = {}
for name, tensor_config in prog_config.inputs.items():
feed_data[name] = {
'data': tensor_config.data,
'lod': tensor_config.lod,
"data": tensor_config.data,
"lod": tensor_config.lod,
}
logging.info('RUN program_config: ' + str(prog_config))
self.num_predictor_kinds = 0
for (
pred_config,
@@ -507,11 +493,7 @@ class PassAutoScanTest(AutoScanTest):
self.num_ignore_tests += 1
if ignore_info[1] == IgnoreReasons.PASS_ACCURACY_ERROR:
self.ignore_log(
"[PASS_ACCURACY_ERROR] "
+ ignore_info[2]
+ ' '
+ ' vs '
+ self.inference_config_str(pred_config)
f"[PASS_ACCURACY_ERROR] {ignore_info[2]} vs {self.inference_config_str(pred_config)}"
)
else:
raise NotImplementedError
@@ -532,9 +514,7 @@ class PassAutoScanTest(AutoScanTest):
model, params, prog_config, base_config, feed_data
)
self.success_log(
'RUN_BASELINE '
+ self.inference_config_str(base_config)
+ ' done'
f"baseline program_config: {self.inference_config_str(base_config)}"
)
if os.path.exists(self.cache_dir):
@@ -549,19 +529,19 @@ class PassAutoScanTest(AutoScanTest):
if not ignore_flag:
self.assert_op_list(op_list)
self.success_log(f"program_config: {prog_config}")
self.success_log(
f"predictor_config: {self.inference_config_str(pred_config)}"
)
except Exception as e:
self.fail_log(f"program_config: {prog_config}")
self.fail_log(
self.inference_config_str(pred_config)
+ f'\033[1;31m \nERROR INFO: {str(e)}\033[0m'
f"predictor_config: {self.inference_config_str(pred_config)}"
)
self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
if not ignore_flag:
status = False
continue
self.success_log(
'RUN predictor_config '
+ self.inference_config_str(pred_config)
+ ' done'
)
status = self.check_op_version() and status
self.assertTrue(status)
@@ -569,23 +549,23 @@
def inference_config_str(self, config) -> str:
dic = {}
enable_mkldnn = config.mkldnn_enabled()
dic['use_mkldnn'] = enable_mkldnn
dic["use_mkldnn"] = enable_mkldnn
enable_gpu = config.use_gpu()
dic['use_gpu'] = enable_gpu
enable_xpu = config.use_xpu()
dic['use_xpu'] = enable_xpu
if not self.passes:
dic['passes'] = self.passes
dic["passes"] = self.passes
enable_trt = config.tensorrt_engine_enabled()
trt_precison = config.tensorrt_precision_mode()
trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
if enable_trt:
dic['use_trt'] = True
dic['trt_precision'] = trt_precison
dic['use_dynamic_shape'] = trt_dynamic_shape
dic["use_trt"] = True
dic["trt_precision"] = trt_precison
dic["use_dynamic_shape"] = trt_dynamic_shape
else:
dic['use_trt'] = False
dic["use_trt"] = False
return str(dic)
def create_trt_inference_config(self) -> paddle_infer.Config:
@@ -599,9 +579,9 @@
class TrtLayerAutoScanTest(AutoScanTest):
class TensorRTParam:
'''
"""
TensorRT subgraph engine parameters.
'''
"""
def __init__(
self,
@@ -620,9 +600,9 @@ class TrtLayerAutoScanTest(AutoScanTest):
self.use_calib_mode = use_calib_mode
class DynamicShapeParam:
'''
"""
Prepare TensorRT subgraph engine dynamic shape parameters.
'''
"""
def __init__(
self,
@@ -648,7 +628,7 @@ class TrtLayerAutoScanTest(AutoScanTest):
)
self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)
self.num_percent_cases = float(
os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')
os.getenv("TEST_NUM_PERCENT_CASES", default="1.0")
)
# Use a separate random generator for skipping tests
@@ -682,6 +662,9 @@ class TrtLayerAutoScanTest(AutoScanTest):
)
return config
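# New hook added in this diff: subclasses override it to return extra
# input dtypes (e.g. np.float16); run_test() replays every generated
# program once per returned type via prog_config.set_input_type(...).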
def get_avalible_input_type(self) -> List[np.dtype]:
return [np.float32]
def assert_tensors_near(
self,
atol: float,
@@ -693,39 +676,32 @@
self.assertEqual(
baseline[key].shape,
arr.shape,
'The output shapes are not equal, the baseline shape is '
+ str(baseline[key].shape)
+ ', but got '
+ str(arr.shape),
f"The output shapes are not equal, the baseline shape is {baseline[key].shape}, but got {str(arr.shape)}",
)
np.testing.assert_allclose(baseline[key], arr, rtol=rtol, atol=atol)
np.testing.assert_allclose(arr, baseline[key], rtol=rtol, atol=atol)
def assert_op_size(self, trt_engine_num, paddle_op_num):
last_passed_program = os.path.join(
self.cache_dir, 'transpose_flatten_concat_fuse_pass.pdmodel'
self.cache_dir, "transpose_flatten_concat_fuse_pass.pdmodel"
)
model_bytes = paddle.static.load_from_file(last_passed_program)
pg = paddle.static.deserialize_program(model_bytes)
main_block = pg.desc.block(0)
op_size = main_block.op_size()
op_types = [
main_block.op(i).type() == 'tensorrt_engine' for i in range(op_size)
main_block.op(i).type() == "tensorrt_engine" for i in range(op_size)
]
trt_engine_size = sum(op_types)
paddle_op_size = op_size - trt_engine_size
self.assertEqual(
trt_engine_num,
trt_engine_size,
'Expected trt_engine_num is {}, but got {}!'.format(
trt_engine_num, trt_engine_size
),
f"Expected trt_engine_num is {trt_engine_num}, but got {trt_engine_size}!",
)
self.assertEqual(
paddle_op_num,
paddle_op_size,
'Expected paddle_op_num is {}, but got {}!'.format(
paddle_op_num, paddle_op_size
),
f"Expected paddle_op_num is {paddle_op_num}, but got {paddle_op_size}!",
)
def inference_config_str(self, config: paddle_infer.Config) -> str:
@@ -734,11 +710,11 @@
trt_precison = config.tensorrt_precision_mode()
trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
if enable_trt:
dic['use_trt'] = True
dic['trt_precision'] = trt_precison
dic['use_dynamic_shape'] = trt_dynamic_shape
dic["use_trt"] = True
dic["trt_precision"] = trt_precison
dic["use_dynamic_shape"] = trt_dynamic_shape
else:
dic['use_trt'] = False
dic["use_trt"] = False
return str(dic)
def run_test(self, quant=False, skip_baseline=False, *args, **kwargs):
@@ -765,110 +741,103 @@ class TrtLayerAutoScanTest(AutoScanTest):
feed_data = {}
for name, tensor_config in prog_config.inputs.items():
feed_data[name] = {
'data': tensor_config.data,
'lod': tensor_config.lod,
"data": tensor_config.data,
"lod": tensor_config.lod,
}
results: List[Dict[str, np.ndarray]] = []
if not skip_baseline:
# baseline: gpu run
logging.info('RUN program_config: ' + str(prog_config))
# baseline: gpu run, we only test float32
gpu_config = self.create_inference_config(use_trt=False)
results.append(
self.run_test_config(
model, params, prog_config, gpu_config, feed_data
)
baseline_result = self.run_test_config(
model,
params,
prog_config.set_input_type(np.float32),
gpu_config,
feed_data,
)
self.success_log('RUN_GPU_BASELINE done')
self.success_log(f"basline program_config: {prog_config}")
for (
pred_config,
nodes_num,
threshold,
) in self.sample_predictor_configs(prog_config):
for input_type in self.get_avalible_input_type():
prog_config = prog_config.set_input_type(input_type)
if os.path.exists(self.cache_dir):
shutil.rmtree(self.cache_dir)
if os.path.exists(self.cache_dir):
shutil.rmtree(self.cache_dir)
if isinstance(threshold, float):
atol = threshold
rtol = 1e-8
elif isinstance(threshold, (list, tuple)):
atol = threshold[0]
rtol = threshold[1]
else:
raise NotImplementedError
if (
pred_config.tensorrt_precision_mode()
!= paddle_infer.PrecisionType.Int8
and quant
):
continue
if (
pred_config.tensorrt_precision_mode()
== paddle_infer.PrecisionType.Int8
and not quant
):
continue
ignore_flag = False
for teller, reason, note in self.ignore_cases:
if teller(prog_config, pred_config):
ignore_flag = True
if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED:
self.ignore_log(
'[TRT_NOT_IMPLEMENTED] {} vs {}'.format(
note, self.inference_config_str(pred_config)
if isinstance(threshold, float):
atol = threshold
rtol = 1e-8
elif isinstance(threshold, (list, tuple)):
atol = threshold[0]
rtol = threshold[1]
else:
raise NotImplementedError
is_int8 = (
pred_config.tensorrt_precision_mode()
== paddle_infer.PrecisionType.Int8
)
if (not is_int8 and quant) or (is_int8 and not quant):
continue
ignore_flag = False
for teller, reason, note in self.ignore_cases:
if teller(prog_config, pred_config):
ignore_flag = True
if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED:
self.ignore_log(
f"[TRT_NOT_IMPLEMENTED] {note} vs {self.inference_config_str(pred_config)}"
)
)
elif reason == IgnoreReasons.TRT_NOT_SUPPORT:
self.ignore_log(
'[TRT_NOT_SUPPORT] {} vs {}'.format(
note, self.inference_config_str(pred_config)
elif reason == IgnoreReasons.TRT_NOT_SUPPORT:
self.ignore_log(
f"[TRT_NOT_SUPPORT] {note} vs {self.inference_config_str(pred_config)}"
)
)
else:
raise NotImplementedError
break
else:
raise NotImplementedError
break
if ignore_flag:
continue
if ignore_flag:
continue
try:
pred_config_deserialize = paddle_infer.Config(pred_config)
results.append(
self.run_test_config(
try:
pred_config_deserialize = paddle_infer.Config(
pred_config
)
trt_result = self.run_test_config(
model, params, prog_config, pred_config, feed_data
)
)
self.assert_tensors_near(
atol, rtol, results[-1], results[0]
)
trt_engine_num, paddle_op_num = nodes_num
self.assert_op_size(trt_engine_num, paddle_op_num)
# deserialize test
if trt_engine_num > 0:
self.run_test_config(
model,
params,
prog_config,
pred_config_deserialize,
feed_data,
self.assert_tensors_near(
atol, rtol, trt_result, baseline_result
)
trt_engine_num, paddle_op_num = nodes_num
self.assert_op_size(trt_engine_num, paddle_op_num)
# deserialize test
if trt_engine_num > 0:
self.run_test_config(
model,
params,
prog_config,
pred_config_deserialize,
feed_data,
)
self.success_log(
'RUN predictor_config {} done'.format(
self.inference_config_str(pred_config)
self.success_log(f"program_config: {prog_config}")
self.success_log(
f"predictor_config: {self.inference_config_str(pred_config)}"
)
)
except Exception as e:
self.fail_log(
self.inference_config_str(pred_config)
+ f'\033[1;31m \nERROR INFO: {str(e)}\033[0m'
)
all_passes = False
except Exception as e:
self.fail_log(f"program_config: {prog_config}")
self.fail_log(
f"predictor_config: {self.inference_config_str(pred_config)}"
)
self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
all_passes = False
self.assertTrue(all_passes)
@@ -54,8 +54,8 @@ class TensorConfig:
if data_gen is not None:
self.data_gen = data_gen
self.data = data_gen()
self.dtype = data_gen().dtype
self.shape = data_gen().shape
self.dtype = self.data.dtype
self.shape = self.data.shape
else:
assert (
shape is not None
@@ -67,6 +67,11 @@
def __repr__(self):
return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype})
def astype(self, type: np.dtype):
self.data = self.data.astype(type)
self.dtype = self.data.dtype
return self
class VarType(enum.Enum):
LOD_TENSOR = 1
@@ -270,6 +275,16 @@ class ProgramConfig:
return log_str
def set_input_type(self, type: np.dtype):
for inp in self.inputs.values():
inp.astype(type)
for weight in self.weights.values():
weight.astype(type)
return self
def get_input_type(self) -> np.dtype:
return next(iter(self.inputs.values())).dtype
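# Note: astype()/set_input_type() cast in place and return self, so a
# config can be cast inline, e.g. prog_config.set_input_type(np.float16).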
def create_fake_model(program_config):
'''Create a Paddle model (in memory) according to the given config.'''
@@ -33,6 +33,9 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
def generate_elewise_input():
return np.random.random([1, 12, 128, 128]).astype(np.float32)
def generate_weight(shape):
return np.random.random(shape).astype(np.float32)
mul_0 = OpConfig(
"mul",
inputs={"X": ["mul_x"], "Y": ["mul_0_w"]},
@@ -195,13 +198,27 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
),
},
weights={
"mul_0_w": TensorConfig(shape=[768, 768]),
"mul_1_w": TensorConfig(shape=[768, 768]),
"mul_2_w": TensorConfig(shape=[768, 768]),
"mul_3_w": TensorConfig(shape=[768, 768]),
"ele_0_w": TensorConfig(shape=[768]),
"ele_1_w": TensorConfig(shape=[768]),
"ele_2_w": TensorConfig(shape=[768]),
"mul_0_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_1_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_2_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_3_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"ele_0_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_1_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_2_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
},
outputs=[ops[-1].outputs["Out"][0]],
)
@@ -103,11 +103,11 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
ver = paddle_infer.get_trt_compile_version()
trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
if trt_version >= 8400:
if self.dims == 1 and not dynamic_shape:
if self.dims == 1:
return 0, 3
return 1, 2
else:
if (self.dims == 1 and not dynamic_shape) or (
if self.dims <= 2 or (
program_config.inputs['input_data'].dtype
in ['bool', 'int8', 'uint8']
):