diff --git a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
index 0f6421134c21655b9ffb4313d3459541d59a659e..12800bd26dae50d8d474e49b49691a8eb9c852b9 100644
--- a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
@@ -314,14 +314,17 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
     new_op_desc.SetType("fused_embedding_eltwise_layernorm");
     new_op_desc.SetInput("Ids", ids);
     new_op_desc.SetInput("Embs", embs);
-    new_op_desc.SetInput("WordId", {ids[0]});
     new_op_desc.SetInput("PosId", {ids[1]});
-    new_op_desc.SetInput("SentId", {ids[2]});
+    if (ids.size() > 2) {
+      new_op_desc.SetInput("SentId", {ids[2]});
+    }
     new_op_desc.SetInput("WordEmbedding", {embs[0]});
     new_op_desc.SetInput("PosEmbedding", {embs[1]});
-    new_op_desc.SetInput("SentEmbedding", {embs[2]});
+    if (embs.size() > 2) {
+      new_op_desc.SetInput("SentEmbedding", {embs[2]});
+    }
     new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
     new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
@@ -380,7 +383,6 @@ EmbeddingEltwiseLayerNormFusePass::EmbeddingEltwiseLayerNormFusePass() {
       .IsTensor()
       .End()
       .AddAttr("axis")
-      .IsIntIn({0, -1})
      .End();

  AddOpCompat(OpCompat("layer_norm"))
@@ -430,6 +432,6 @@ REGISTER_PASS(embedding_eltwise_layernorm_fuse_pass,
 REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("lookup_table", 0)
+            .LE("lookup_table", 1)
             .LE("lookup_table_v2", 1)
-            .EQ("elementweise_add", 0));
+            .LE("elementwise_add", 1));
diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat
index 1e7e0a3638fa758250288b9c925c14ef9f809d96..34a8f10458d7b39cf98747ba59e85d27cd6083e8 100644
--- a/paddle/scripts/paddle_build.bat
+++ b/paddle/scripts/paddle_build.bat
@@ -675,6 +675,8 @@ setlocal enabledelayedexpansion
 :: for /F %%# in ('cmd /C nvidia-smi -L ^|find "GPU" /C') do set CUDA_DEVICE_COUNT=%%#
 set CUDA_DEVICE_COUNT=1
+:: For hypothesis tests (mkldnn ops and inference passes), use the 'ci' profile.
+set HYPOTHESIS_TEST_PROFILE=ci
 echo cmake .. -G %GENERATOR% -DCMAKE_BUILD_TYPE=Release -DWITH_AVX=%WITH_AVX% -DWITH_GPU=%WITH_GPU% -DWITH_MKL=%WITH_MKL% ^
 -DWITH_TESTING=%WITH_TESTING% -DWITH_PYTHON=%WITH_PYTHON% -DON_INFER=%ON_INFER% ^
 -DWITH_INFERENCE_API_TEST=%WITH_INFERENCE_API_TEST% -DTHIRD_PARTY_PATH=%THIRD_PARTY_PATH% ^
@@ -692,6 +694,8 @@ echo ========================================
 echo Running CPU unit tests in parallel way ...
echo ========================================
+:: For hypothesis tests (mkldnn ops and inference passes), use the 'ci' profile.
+set HYPOTHESIS_TEST_PROFILE=ci
 %cache_dir%\tools\busybox64.exe bash %work_dir%\tools\windows\run_unittests.sh %NIGHTLY_MODE% %PRECISION_TEST% %WITH_GPU%
 goto:eof
diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh
index 3077d5ba55201db65d4850f25004d5a43d962ce2..50c30ba89be1921aa89e9383272603caef637645 100755
--- a/paddle/scripts/paddle_build.sh
+++ b/paddle/scripts/paddle_build.sh
@@ -581,12 +581,16 @@ EOF
     if [ "$1" == "cp36-cp36m" ]; then
         pip3.6 install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl
+        pip3.6 install --user hypothesis
     elif [ "$1" == "cp37-cp37m" ]; then
         pip3.7 install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl
+        pip3.7 install --user hypothesis
     elif [ "$1" == "cp38-cp38" ]; then
         pip3.8 install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl
+        pip3.8 install --user hypothesis
     elif [ "$1" == "cp39-cp39" ]; then
         pip3.9 install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl
+        pip3.9 install --user hypothesis
     fi
     tmpfile_rand=`date +%s%N`
     tmpfile=$tmp_dir/$tmpfile_rand
@@ -1893,6 +1897,7 @@ set -ex
 function parallel_test() {
     mkdir -p ${PADDLE_ROOT}/build
     cd ${PADDLE_ROOT}/build
+    pip install hypothesis
     pip install ${PADDLE_ROOT}/build/python/dist/*whl
     cp ${PADDLE_ROOT}/build/python/paddle/fluid/tests/unittests/op_test.py ${PADDLE_ROOT}/build/python
     ut_total_startTime_s=`date +%s`
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index 927456b396ea5b6c6545209e33161b988b68f29a..43cdb85e75edd204c484e1b421363ed152f5aef3 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -71,4 +71,9 @@ set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100)
 set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60)
 set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60)
 set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30)
+set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120)
+
+if (WITH_MKLDNN)
+    set_tests_properties(test_mkldnn_prelu_op PROPERTIES TIMEOUT 300)
+endif()
 endif()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
index 6fc6ec875c68f6919cb8377b1f744d7d7ba11b8b..337098cde3c0deaf44c4fb7399ec2ca0d4ca61fb 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
@@ -17,35 +17,70 @@
 import unittest
 import abc
 import os
 import enum
+import time
 import logging
+import shutil
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.initializer import NumpyArrayInitializer
+from paddle.fluid.core import PassVersionChecker
 import paddle.fluid.core as core
 from paddle import compat as cpt
 import paddle.inference as paddle_infer
 from typing import Optional, List, Callable, Dict, Any, Set
 from program_config import TensorConfig, OpConfig, ProgramConfig, create_fake_model, create_quant_model
+import hypothesis
+from hypothesis import given, settings, seed, example, assume
+
 logging.basicConfig(level=logging.INFO, format="%(message)s")

+settings.register_profile(
+    "ci",
max_examples=100, + suppress_health_check=hypothesis.HealthCheck.all(), + deadline=None, + print_blob=True, + derandomize=True, + report_multiple_bugs=False) +settings.register_profile( + "dev", + max_examples=1000, + suppress_health_check=hypothesis.HealthCheck.all(), + deadline=None, + print_blob=True, + derandomize=True, + report_multiple_bugs=False) +if float(os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')) < 1 or \ + os.getenv('HYPOTHESIS_TEST_PROFILE', 'dev') == 'ci': + settings.load_profile("ci") +else: + settings.load_profile("dev") + class SkipReasons(enum.Enum): # Paddle not support, but trt support, we need to add the feature. TRT_NOT_IMPLEMENTED = 0 # TRT not support. TRT_NOT_SUPPORT = 1 + # Accuracy is abnormal after enabling pass. + PASS_ACCURACY_ERROR = 2 + # Accuracy is abnormal after enabling mkldnn. + MKLDNN_ACCURACY_ERROR = 3 class AutoScanTest(unittest.TestCase): - def __init__(self, methodName='runTest'): + def __init__(self, *args, **kwargs): np.random.seed(1024) paddle.enable_static() - super(AutoScanTest, self).__init__(methodName) + super(AutoScanTest, self).__init__(*args, **kwargs) self.skip_cases = [] + abs_dir = os.path.abspath(os.path.dirname(__file__)) + self.cache_dir = os.path.join(abs_dir, + str(self.__module__) + '_cache_dir') @abc.abstractmethod - def sample_program_configs(self) -> List[ProgramConfig]: + def sample_program_configs(self): ''' Generate all config with the combination of different Input tensor shape and different Attr values. @@ -53,7 +88,7 @@ class AutoScanTest(unittest.TestCase): raise NotImplementedError @abc.abstractmethod - def sample_predictor_configs(self) -> List[paddle_infer.Config]: + def sample_predictor_configs(self): raise NotImplementedError @abc.abstractmethod @@ -88,21 +123,488 @@ class AutoScanTest(unittest.TestCase): result[out_name] = predictor.get_output_handle(o_name).copy_to_cpu() return result + @abc.abstractmethod def assert_tensors_near(self, - threshold: float, - tensors: List[Dict[str, np.array]]): - assert len(tensors) > 1 - first = tensors[0] - for group in tensors[1:]: - for key, arr in group.items(): - self.assertTrue( - first[key].shape == arr.shape, - "The output shape of GPU and TensorRT are not equal.") - self.assertTrue( - np.allclose( - first[key], arr, atol=threshold), - "Output has diff between GPU and TensorRT. ") + atol: float, + rtol: float, + tensor: Dict[str, np.array], + baseline: Dict[str, np.array]): + for key, arr in tensor.items(): + self.assertTrue( + baseline[key].shape == arr.shape, + "The output shapes are not equal, the baseline shape is " + + str(baseline[key].shape) + ', but got ' + str(arr.shape)) + self.assertTrue( + np.allclose( + baseline[key], arr, atol=atol, rtol=rtol), + "Output has diff. 
") @abc.abstractmethod def run_test(self, quant=False): raise NotImplementedError + + def generate_op_config(self, + ops_config: List[Dict[str, Any]]) -> List[OpConfig]: + ops = [] + for i in range(len(ops_config)): + op_config = ops_config[i] + ops.append( + OpConfig( + type=op_config['op_type'], + inputs=op_config['op_inputs'], + outputs=op_config['op_outputs'], + attrs=op_config['op_attrs'])) + return ops + + @abc.abstractmethod + def skip_log(self, msg: str): + logging.warning("SKIP: " + msg) + + @abc.abstractmethod + def fail_log(self, msg: str): + logging.error("FAILE: " + msg) + + @abc.abstractmethod + def success_log(self, msg: str): + logging.info("SUCCESS: " + msg) + + @abc.abstractmethod + def create_inference_config(self, + passes: Optional[List[str]]=None, + use_gpu: bool=False, + use_mkldnn: bool=False, + ir_optim: Optional[bool]=None): + config = paddle_infer.Config() + config.switch_ir_debug(True) + config.set_optim_cache_dir(self.cache_dir) + config.disable_glog_info() + if ir_optim is not None: + config.switch_ir_optim(ir_optim) + if use_gpu: + config.enable_use_gpu(100, 0) + if use_mkldnn: + config.enable_mkldnn() + if passes is not None: + config.pass_builder().set_passes(passes) + self.passes = passes + return config + + +class MkldnnAutoScanTest(AutoScanTest): + def __init__(self, *args, **kwargs): + super(MkldnnAutoScanTest, self).__init__(*args, **kwargs) + + def run_test(self, quant=False, *args, **kwargs): + status = True + + for prog_config in self.sample_program_configs(*args, **kwargs): + # if program is invalid, we should skip that cases. + if not self.is_program_valid(prog_config): + continue + + model, params = create_fake_model(prog_config) + if quant: + model, params = create_quant_model(model, params) + + feed_data = {} + for name, tensor_config in prog_config.inputs.items(): + feed_data[name] = { + 'data': tensor_config.data, + 'lod': tensor_config.lod + } + results: List[Dict[str, np.ndarray]] = [] + + # baseline: cpu no ir_optim run + base_config = self.create_inference_config(ir_optim=False) + logging.info('RUN program_config: ' + str(prog_config)) + results.append( + self.run_test_config(model, params, prog_config, base_config, + feed_data)) + self.success_log('RUN_CPU_BASELINE done') + + for pred_config, ( + atol, rtol) in self.sample_predictor_configs(prog_config): + # skip info + skip_flag = False + for skip_info in self.skip_cases: + if skip_info[0](prog_config, pred_config): + skip_flag = True + if skip_info[1] == SkipReasons.MKLDNN_ACCURACY_ERROR: + self.skip_log("[MKLDNN_ACCURACY_ERROR] " + + skip_info[2] + ' ' + ' vs ' + self. + inference_config_str(pred_config)) + else: + raise NotImplementedError + break + + if os.path.exists(self.cache_dir): + shutil.rmtree(self.cache_dir) + if not os.path.exists(self.cache_dir): + os.mkdir(self.cache_dir) + + try: + results.append( + self.run_test_config(model, params, prog_config, + pred_config, feed_data)) + self.assert_tensors_near(atol, rtol, results[-1], + results[0]) + except Exception as e: + self.fail_log( + self.inference_config_str(pred_config) + + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))) + if not skip_flag: + status = False + continue + self.success_log('RUN predictor_config ' + self. 
+                                 inference_config_str(pred_config) + ' done')
+
+        self.assertTrue(status)
+
+    def inference_config_str(self, config) -> str:
+        dic = {}
+        enable_mkldnn = config.mkldnn_enabled()
+        dic['use_mkldnn'] = enable_mkldnn
+        enable_gpu = config.use_gpu()
+        dic['use_gpu'] = enable_gpu
+        return str(dic)
+
+
+class PassAutoScanTest(AutoScanTest):
+    def __init__(self, *args, **kwargs):
+        super(PassAutoScanTest, self).__init__(*args, **kwargs)
+        self.passes = []
+
+    def check_op_version(self):
+        status = True
+        for pass_name in self.passes:
+            if not PassVersionChecker.IsCompatible(pass_name):
+                self.fail_log('{} version check failed.'.format(pass_name))
+                status = False
+        return status
+
+    def assert_op_size(self, fusion_before_num, fusion_after_num, origin_model):
+        if not self.passes:
+            raise ValueError(
+                'In PassAutoScan you should give a valid pass name.')
+        last_passed_program = os.path.join(self.cache_dir,
+                                           self.passes[-1] + '.pdmodel')
+        model_bytes = paddle.static.load_from_file(last_passed_program)
+        pg = paddle.static.deserialize_program(model_bytes)
+        main_block = pg.desc.block(0)
+        after_op_size = main_block.op_size()
+        pg = paddle.static.deserialize_program(origin_model)
+        main_block = pg.desc.block(0)
+        before_op_size = main_block.op_size()
+        self.assertTrue(before_op_size == fusion_before_num,
+                        'expected op size before fusion is {}, but got {}!'.
+                        format(fusion_before_num, before_op_size))
+        self.assertTrue(after_op_size == fusion_after_num,
+                        'expected op size after fusion is {}, but got {}!'.
+                        format(fusion_after_num, after_op_size))
+
+    def run_test(self, quant=False, *args, **kwargs):
+        status = True
+
+        for prog_config in self.sample_program_configs(*args, **kwargs):
+            # If the program is invalid, skip this case.
+            if not self.is_program_valid(prog_config):
+                continue
+
+            model, params = create_fake_model(prog_config)
+            if quant:
+                model, params = create_quant_model(model, params)
+
+            feed_data = {}
+            for name, tensor_config in prog_config.inputs.items():
+                feed_data[name] = {
+                    'data': tensor_config.data,
+                    'lod': tensor_config.lod
+                }
+            results: List[Dict[str, np.ndarray]] = []
+
+            # baseline: cpu no ir_optim run
+            base_config = self.create_inference_config(ir_optim=False)
+            logging.info('RUN program_config: ' + str(prog_config))
+            results.append(
+                self.run_test_config(model, params, prog_config, base_config,
+                                     feed_data))
+            self.success_log('RUN_CPU_BASELINE done')
+
+            for pred_config, nodes_num, (
+                    atol, rtol) in self.sample_predictor_configs(prog_config):
+                # skip info
+                skip_flag = False
+                for skip_info in self.skip_cases:
+                    if skip_info[0](prog_config, pred_config):
+                        skip_flag = True
+                        if skip_info[1] == SkipReasons.PASS_ACCURACY_ERROR:
+                            self.skip_log("[PASS_ACCURACY_ERROR] " + skip_info[
+                                2] + ' ' + ' vs ' + self.inference_config_str(
+                                    pred_config))
+                        else:
+                            raise NotImplementedError
+                        break
+
+                if os.path.exists(self.cache_dir):
+                    shutil.rmtree(self.cache_dir)
+                if not os.path.exists(self.cache_dir):
+                    os.mkdir(self.cache_dir)
+
+                try:
+                    results.append(
+                        self.run_test_config(model, params, prog_config,
+                                             pred_config, feed_data))
+                    self.assert_tensors_near(atol, rtol, results[-1],
+                                             results[0])
+                    if not skip_flag:
+                        self.assert_op_size(nodes_num[0], nodes_num[1], model)
+
+                except Exception as e:
+                    self.fail_log(
+                        self.inference_config_str(pred_config) +
+                        '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e)))
+                    if not skip_flag:
+                        status = False
+                    continue
+                self.success_log('RUN predictor_config ' + self.
+                                 inference_config_str(pred_config) + ' done')
+
+        status = self.check_op_version() and status
+        self.assertTrue(status)
+
+    def inference_config_str(self, config) -> str:
+        dic = {}
+        enable_mkldnn = config.mkldnn_enabled()
+        dic['use_mkldnn'] = enable_mkldnn
+        enable_gpu = config.use_gpu()
+        dic['use_gpu'] = enable_gpu
+        if self.passes:
+            dic['passes'] = self.passes
+
+        enable_trt = config.tensorrt_engine_enabled()
+        trt_precision = config.tensorrt_precision_mode()
+        trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
+        if enable_trt:
+            dic['use_trt'] = True
+            dic['trt_precision'] = trt_precision
+            dic['use_dynamic_shape'] = trt_dynamic_shape
+        else:
+            dic['use_trt'] = False
+        return str(dic)
+
+    def create_trt_inference_config(self) -> paddle_infer.Config:
+        config = paddle_infer.Config()
+        config.disable_glog_info()
+        config.enable_use_gpu(100, 0)
+        config.set_optim_cache_dir(self.cache_dir)
+        config.switch_ir_debug()
+        # for assert_op_size.
+        self.passes = ['transpose_flatten_concat_fuse_pass']
+        return config
+
+
+class TrtLayerAutoScanTest(AutoScanTest):
+    class TensorRTParam:
+        '''
+        TensorRT subgraph engine parameters.
+        '''
+
+        def __init__(self, workspace_size, max_batch_size, min_subgraph_size,
+                     precision, use_static, use_calib_mode):
+            self.workspace_size = workspace_size
+            self.max_batch_size = max_batch_size
+            self.min_subgraph_size = min_subgraph_size
+            self.precision = precision
+            self.use_static = use_static
+            self.use_calib_mode = use_calib_mode
+
+    class DynamicShapeParam:
+        '''
+        Prepare TensorRT subgraph engine dynamic shape parameters.
+        '''
+
+        def __init__(self, min_input_shape, max_input_shape, opt_input_shape,
+                     disable_trt_plugin_fp16):
+            self.min_input_shape = min_input_shape
+            self.max_input_shape = max_input_shape
+            self.opt_input_shape = opt_input_shape
+            self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16
+
+    def __init__(self, *args, **kwargs):
+        super(TrtLayerAutoScanTest, self).__init__(*args, **kwargs)
+        self.trt_param = self.TensorRTParam(
+            workspace_size=1024,
+            max_batch_size=4,
+            min_subgraph_size=0,
+            precision=paddle_infer.PrecisionType.Float32,
+            use_static=True,
+            use_calib_mode=False)
+        self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)
+        self.num_percent_cases = float(
+            os.getenv(
+                'TEST_NUM_PERCENT_CASES', default='1.0'))
+        # Choose a different subset of tests each week.
+        np.random.seed(int(time.strftime("%W")))
+
+    def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
+        config = paddle_infer.Config()
+        config.disable_glog_info()
+        config.enable_use_gpu(100, 0)
+        config.set_optim_cache_dir(self.cache_dir)
+        if use_trt:
+            config.switch_ir_debug()
+            config.enable_tensorrt_engine(
+                max_batch_size=self.trt_param.max_batch_size,
+                workspace_size=self.trt_param.workspace_size,
+                min_subgraph_size=self.trt_param.min_subgraph_size,
+                precision_mode=self.trt_param.precision,
+                use_static=self.trt_param.use_static,
+                use_calib_mode=self.trt_param.use_calib_mode)
+            if len(self.dynamic_shape.min_input_shape
+                   ) != 0 and self.dynamic_shape.min_input_shape.keys(
+                   ) == self.dynamic_shape.max_input_shape.keys(
+                   ) and self.dynamic_shape.min_input_shape.keys(
+                   ) == self.dynamic_shape.opt_input_shape.keys():
+                config.set_trt_dynamic_shape_info(
+                    self.dynamic_shape.min_input_shape,
+                    self.dynamic_shape.max_input_shape,
+                    self.dynamic_shape.opt_input_shape,
+                    self.dynamic_shape.disable_trt_plugin_fp16)
+        return config
+
+    def assert_op_size(self, trt_engine_num, paddle_op_num):
+        last_passed_program = os.path.join(
self.cache_dir, 'transpose_flatten_concat_fuse_pass.pdmodel') + model_bytes = paddle.static.load_from_file(last_passed_program) + pg = paddle.static.deserialize_program(model_bytes) + main_block = pg.desc.block(0) + op_size = main_block.op_size() + op_types = [ + main_block.op(i).type() == 'tensorrt_engine' for i in range(op_size) + ] + trt_engine_size = sum(op_types) + paddle_op_size = op_size - trt_engine_size + self.assertTrue(trt_engine_size == trt_engine_num, + 'trt_engine_num is {}, but got {}!'.format( + trt_engine_size, trt_engine_num)) + self.assertTrue(paddle_op_size == paddle_op_num, + 'paddle_op_num is {}, but got {}!'.format( + paddle_op_size, paddle_op_num)) + + def inference_config_str(self, config: paddle_infer.Config): + dic = {} + enable_trt = config.tensorrt_engine_enabled() + trt_precison = config.tensorrt_precision_mode() + trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled() + if enable_trt: + dic['use_trt'] = True + dic['trt_precision'] = trt_precison + dic['use_dynamic_shape'] = trt_dynamic_shape + else: + dic['use_trt'] = False + return str(dic) + + def run_test(self, quant=False, *args, **kwargs): + status = True + run_flags = [] + for prog_config in self.sample_program_configs(*args, **kwargs): + # In CI, only run 10% cases + if np.random.rand() < self.num_percent_cases: + run_flags.append(True) + else: + run_flags.append(False) + + for prog_config, run_flags in zip( + self.sample_program_configs(*args, **kwargs), run_flags): + if not run_flags: + continue + + # if program is invalid, we should skip that cases. + if not self.is_program_valid(prog_config): + continue + + model, params = create_fake_model(prog_config) + if quant: + model, params = create_quant_model(model, params) + + feed_data = {} + for name, tensor_config in prog_config.inputs.items(): + feed_data[name] = { + 'data': tensor_config.data, + 'lod': tensor_config.lod + } + + results: List[Dict[str, np.ndarray]] = [] + + # baseline: gpu run + logging.info('RUN program_config: ' + str(prog_config)) + gpu_config = self.create_inference_config(use_trt=False) + results.append( + self.run_test_config(model, params, prog_config, gpu_config, + feed_data)) + self.success_log('RUN_GPU_BASELINE done') + + for pred_config, nodes_num, threshold in self.sample_predictor_configs( + prog_config): + + if os.path.exists(self.cache_dir): + shutil.rmtree(self.cache_dir) + + if isinstance(threshold, float): + atol = threshold + rtol = 1e-8 + elif isinstance(threshold, list) or isinstance(threshold, + tuple): + atol = threshold[0] + rtol = threshold[1] + else: + raise NotImplementedError + + if quant and pred_config.tensorrt_precision_mode( + ) != paddle_infer.PrecisionType.Int8: + continue + if pred_config.tensorrt_precision_mode( + ) == paddle_infer.PrecisionType.Int8 and not quant: + continue + + skip_flag = False + for skip_info in self.skip_cases: + if skip_info[0](prog_config, pred_config): + skip_flag = True + if skip_info[1] == SkipReasons.TRT_NOT_IMPLEMENTED: + self.skip_log("[TRT_NOT_IMPLEMENTED] " + skip_info[ + 2] + ' ' + ' vs ' + self.inference_config_str( + pred_config)) + elif skip_info[1] == SkipReasons.TRT_NOT_SUPPORT: + self.skip_log("[TRT_NOT_SUPPORT] " + skip_info[ + 2] + ' ' + ' vs ' + self.inference_config_str( + pred_config)) + else: + raise NotImplementedError + break + + try: + pred_config_deserialize = paddle_infer.Config(pred_config) + results.append( + self.run_test_config(model, params, prog_config, + pred_config, feed_data)) + self.assert_tensors_near(atol, rtol, results[-1], + 
results[0]) + if not skip_flag: + self.assert_op_size(nodes_num[0], nodes_num[1]) + # deserialize test + if nodes_num[0] > 0: + self.run_test_config(model, params, prog_config, + pred_config_deserialize, feed_data) + except Exception as e: + self.fail_log( + str(prog_config) + ' vs ' + self.inference_config_str( + pred_config) + + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))) + if not skip_flag: + status = False + continue + self.success_log('RUN predictor_config ' + self. + inference_config_str(pred_config) + ' done') + + self.assertTrue(status) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..2046307e5c518cce8161ef441d3286d9c8e5585c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py @@ -0,0 +1,321 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest): + ''' + in_var1 emb_var in_var2 emb_var in_var3 emb_var in_var emb_var + | | | | | | | | + lookup_table lookup_table lookup_table ... lookup_table + | | | | + lkt_var lkt_var lkt_var lkt_var + \ / | ... | + elementwise_add | | + \ / | + elementwise_add | + | | + elt_var / + \ / + elementwise_add + | + layer_norm + ''' + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + # is_sparse is only support False + if program_config.ops[0].attrs['is_sparse'] == True: + return False + + # is_distributed only support False + if program_config.ops[0].attrs['is_distributed'] == True: + return False + + # axis only support -1 and the last dim. 
+ if program_config.ops[3].attrs['axis'] not in [-1, 2]: + return False + + if not (program_config.ops[5].attrs['epsilon'] >= 0 and + program_config.ops[5].attrs['epsilon'] <= 0.001): + return False + + if program_config.ops[5].attrs['begin_norm_axis'] != 2: + return False + + # input check + if program_config.weights['embedding_weight1'].shape[ + 1] != program_config.weights['layer_norm_scale'].shape[0]: + return False + + return True + + def sample_program_configs(self, *args, **kwargs): + def generate_input(attrs): + if attrs[0]['op_type'] == 'lookup_table': + return np.random.randint( + 0, + attrs[3]['weight_size'][0], + size=(attrs[3]['batch_size'], attrs[3]['input_dim'], + 1)).astype(np.int64) + else: + return np.random.randint( + 0, + attrs[3]['weight_size'][0], + size=(attrs[3]['batch_size'], + attrs[3]['input_dim'])).astype(np.int64) + + def generate_weight1(attrs): + # set embedding weight by attrs + return np.random.random(attrs['weight_size']).astype(np.float32) + + def generate_weight2(attrs): + # set layernorm weight by attrs + if attrs[2]['begin_norm_axis'] == 1: + return np.random.random( + attrs[3]['input_dim'] * + attrs[3]['weight_size'][1]).astype(np.float32) + else: + return np.random.random(attrs[3]['weight_size'][1]).astype( + np.float32) + + attrs = [{ + 'is_sparse': kwargs['is_sparse'], + 'is_distributed': kwargs['is_distributed'], + 'padding_idx': kwargs['padding_idx'], + 'op_type': kwargs['op_type'] + }, { + 'axis': kwargs['axis'] + }, { + 'begin_norm_axis': kwargs['begin_norm_axis'], + 'epsilon': kwargs['epsilon'] + }, { + 'batch_size': kwargs['batch_size'], + 'input_dim': kwargs['input_dim'], + 'weight_size': kwargs['weight_size'] + }] + + ops_config = [{ + "op_type": attrs[0]['op_type'], + "op_inputs": { + "Ids": ["input_data1"], + "W": ["embedding_weight1"] + }, + "op_outputs": { + "Out": ["embedding_output1"] + }, + "op_attrs": { + 'is_sparse': attrs[0]['is_sparse'], + 'is_distributed': attrs[0]['is_distributed'], + 'padding_idx': attrs[0]['padding_idx'], + } + }, { + "op_type": attrs[0]['op_type'], + "op_inputs": { + "Ids": ["input_data2"], + "W": ["embedding_weight2"] + }, + "op_outputs": { + "Out": ["embedding_output2"] + }, + "op_attrs": { + 'is_sparse': attrs[0]['is_sparse'], + 'is_distributed': attrs[0]['is_distributed'], + 'padding_idx': attrs[0]['padding_idx'], + }, + }, { + "op_type": attrs[0]['op_type'], + "op_inputs": { + "Ids": ["input_data3"], + "W": ["embedding_weight3"] + }, + "op_outputs": { + "Out": ["embedding_output3"] + }, + "op_attrs": { + 'is_sparse': attrs[0]['is_sparse'], + 'is_distributed': attrs[0]['is_distributed'], + 'padding_idx': attrs[0]['padding_idx'], + }, + }, { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["embedding_output2"], + "Y": ["embedding_output3"] + }, + "op_outputs": { + "Out": ["elementwise_add_output1"] + }, + "op_attrs": { + "axis": attrs[1]['axis'], + } + }, { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["elementwise_add_output1"], + "Y": ["embedding_output1"] + }, + "op_outputs": { + "Out": ["elementwise_add_output2"] + }, + "op_attrs": { + "axis": attrs[1]['axis'], + } + }, { + "op_type": "layer_norm", + "op_inputs": { + "X": ["elementwise_add_output2"], + "Bias": ["layer_norm_bias"], + "Scale": ["layer_norm_scale"] + }, + "op_outputs": { + "Y": ["layer_norm_output1"], + "Mean": ["layer_norm_output2"], + "Variance": ["layer_norm_output3"] + }, + "op_attrs": { + 'begin_norm_axis': attrs[2]['begin_norm_axis'], + 'epsilon': attrs[2]['epsilon'], + } + }] + + ops = 
self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "embedding_weight1": + TensorConfig(data_gen=partial(generate_weight1, attrs[3])), + "embedding_weight2": + TensorConfig(data_gen=partial(generate_weight1, attrs[3])), + "embedding_weight3": + TensorConfig(data_gen=partial(generate_weight1, attrs[3])), + "layer_norm_bias": + TensorConfig(data_gen=partial(generate_weight2, attrs)), + "layer_norm_scale": + TensorConfig(data_gen=partial(generate_weight2, attrs)) + }, + inputs={ + "input_data1": + TensorConfig(data_gen=partial(generate_input, attrs)), + "input_data2": + TensorConfig(data_gen=partial(generate_input, attrs)), + "input_data3": + TensorConfig(data_gen=partial(generate_input, attrs)) + }, + outputs=["layer_norm_output1"]) + + yield program_config + + def sample_predictor_configs(self, program_config): + # only used in gpu passes and trt passes. + config = self.create_inference_config( + passes=['embedding_eltwise_layernorm_fuse_pass'], use_gpu=True) + yield config, (10, 5), (1e-5, 1e-5) + # trt static_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=4, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + yield config, (10, 3), (1e-5, 1e-5) + # trt dynamic_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=4, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + if program_config.ops[0].type == 'lookup_table': + config.set_trt_dynamic_shape_info({ + "input_data1": [1, 4, 1], + "input_data2": [1, 4, 1], + "input_data3": [1, 4, 1] + }, { + "input_data1": [4, 512, 1], + "input_data2": [4, 512, 1], + "input_data3": [4, 512, 1] + }, { + "input_data1": [2, 128, 1], + "input_data2": [2, 128, 1], + "input_data3": [2, 128, 1] + }) + else: + config.set_trt_dynamic_shape_info({ + "input_data1": [1, 4], + "input_data2": [1, 4], + "input_data3": [1, 4] + }, { + "input_data1": [4, 512], + "input_data2": [4, 512], + "input_data3": [4, 512] + }, { + "input_data1": [2, 128], + "input_data2": [2, 128], + "input_data3": [2, 128] + }) + yield config, (10, 3), (1e-5, 1e-5) + + def add_skip_pass_case(self): + def teller1(program_config, predictor_config): + if program_config.ops[3].attrs['axis'] in [ + -1, 2 + ] and program_config.ops[5].attrs[ + 'begin_norm_axis'] == 2 and program_config.weights[ + 'embedding_weight1'].shape in [(64, 32), (64, 64)]: + return True + return False + + self.add_skip_case(teller1, SkipReasons.PASS_ACCURACY_ERROR, + "The pass output has diff in a specific case.") + + @given( + is_sparse=st.booleans(), + is_distributed=st.booleans(), + padding_idx=st.integers(), + axis=st.integers( + min_value=-4, max_value=4), + op_type=st.sampled_from(['lookup_table', 'lookup_table_v2']), + epsilon=st.floats( + min_value=0, max_value=0.001), + begin_norm_axis=st.integers( + min_value=-4, max_value=4), + batch_size=st.integers( + min_value=1, max_value=4), + input_dim=st.sampled_from([32, 64]), + weight_size=st.sampled_from([[64, 64], [64, 32]])) + def test(self, *args, **kwargs): + assume(kwargs['begin_norm_axis'] == 2) + + self.add_skip_pass_case() + self.run_test(quant=False, *args, **kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_prelu_op.py 
b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_prelu_op.py new file mode 100644 index 0000000000000000000000000000000000000000..32642096c76c6ebadbb4eead4e590ceb05ee7d87 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_prelu_op.py @@ -0,0 +1,103 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import MkldnnAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +class TestMkldnnPreluOp(MkldnnAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + # if mode is channel, and in_shape is 1 rank + if len(program_config.inputs['input_data']. + shape) == 1 and program_config.ops[0].attrs['mode'] == 'channel': + return False + return True + + def sample_program_configs(self, *args, **kwargs): + def generate_input(*args, **kwargs): + return np.random.random(kwargs['in_shape']).astype(np.float32) + + def generate_alpha(*args, **kwargs): + if kwargs["mode"] == "all": + return np.random.random(size=(1)).astype(np.float32) + elif kwargs["mode"] == "channel": + if len(kwargs['in_shape']) <= 1: + # not valid case, just return 0 + return np.zeros((1)).astype(np.float32) + return np.random.random(kwargs['in_shape'][1]).astype( + np.float32) + else: + if len(kwargs['in_shape']) <= 1: + # not valid case, just return 0 + return np.zeros((1)).astype(np.float32) + return np.random.random(kwargs['in_shape']).astype(np.float32) + + ops_config = [{ + "op_type": "prelu", + "op_inputs": { + "X": ["input_data"], + "Alpha": ["alpha_weight"] + }, + "op_outputs": { + "Out": ["output_data"] + }, + "op_attrs": { + "mode": kwargs['mode'] + } + }] + + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "alpha_weight": + TensorConfig(data_gen=partial(generate_alpha, *args, **kwargs)) + }, + inputs={ + "input_data": + TensorConfig(data_gen=partial(generate_input, *args, **kwargs)), + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, (1e-5, 1e-5) + + def add_skip_pass_case(self): + pass + + @given( + mode=st.sampled_from(['all', 'channel', 'element']), + in_shape=st.lists( + st.integers( + min_value=1, max_value=32), min_size=1, max_size=4)) + def test(self, *args, **kwargs): + self.add_skip_pass_case() + self.run_test(quant=False, *args, **kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tile.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tile.py index 
c1a5493fd328a0a554c9ee7d37096945b1e0fa79..cbbd13a7b8003e549ef8b8f3084e6d1c3bc4eb39 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tile.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tile.py @@ -20,6 +20,10 @@ from functools import partial from typing import Optional, List, Callable, Dict, Any, Set import unittest +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + class TrtConvertTileTest(TrtLayerAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: @@ -34,35 +38,34 @@ class TrtConvertTileTest(TrtLayerAutoScanTest): return True - def sample_program_configs(self): + def sample_program_configs(self, *args, **kwargs): def generate_input1(attrs: List[Dict[str, Any]]): return np.ones([1, 2, 3, 4]).astype(np.float32) - for repeat_times in [[100], [1, 2], [0, 3], [1, 2, 100]]: - dics = [{"repeat_times": repeat_times}] - - ops_config = [{ - "op_type": "tile", - "op_inputs": { - "X": ["input_data"] - }, - "op_outputs": { - "Out": ["tile_output_data"] - }, - "op_attrs": dics[0] - }] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={}, - inputs={ - "input_data": TensorConfig(data_gen=partial(generate_input1, - dics)) - }, - outputs=["tile_output_data"]) - - yield program_config + dics = [{"repeat_times": kwargs['repeat_times']}] + + ops_config = [{ + "op_type": "tile", + "op_inputs": { + "X": ["input_data"] + }, + "op_outputs": { + "Out": ["tile_output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input1, + dics)) + }, + outputs=["tile_output_data"]) + + yield program_config def sample_predictor_configs( self, program_config) -> (paddle_infer.Config, List[int], float): @@ -109,8 +112,9 @@ class TrtConvertTileTest(TrtLayerAutoScanTest): yield self.create_inference_config(), generate_trt_nodes_num(attrs, True), 1e-4 - def test(self): - self.run_test() + @given(repeat_times=st.sampled_from([[100], [1, 2], [0, 3], [1, 2, 100]])) + def test(self, *args, **kwargs): + self.run_test(*args, **kwargs) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py index 941641da7a30dc1f2d9d949148b23ea99c827b40..7432101e787c29e2ba9b8cc78beca8140ad13c8f 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py @@ -12,277 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import numpy as np -import unittest -import itertools -import abc -import enum -import sys -import os -import logging -import time -import paddle -import paddle.fluid as fluid -import paddle.fluid.core as core -import paddle.inference as paddle_infer -import shutil - -from paddle import compat as cpt -from typing import Optional, List, Callable, Dict, Any, Set -from program_config import TensorConfig, OpConfig, ProgramConfig, create_fake_model, create_quant_model -from auto_scan_test import AutoScanTest, SkipReasons - -logging.basicConfig(level=logging.INFO, format="%(message)s") - - -class TrtLayerAutoScanTest(AutoScanTest): - class TensorRTParam: - ''' - TensorRT subgraph engine parameters. - ''' - - def __init__(self, workspace_size, max_batch_size, min_subgraph_size, - precision, use_static, use_calib_mode): - self.workspace_size = workspace_size - self.max_batch_size = max_batch_size - self.min_subgraph_size = min_subgraph_size - self.precision = precision - self.use_static = use_static - self.use_calib_mode = use_calib_mode - - class DynamicShapeParam: - ''' - Prepare TensorRT subgraph engine dynamic shape parameters. - ''' - - def __init__(self, min_input_shape, max_input_shape, opt_input_shape, - disable_trt_plugin_fp16): - self.min_input_shape = min_input_shape - self.max_input_shape = max_input_shape - self.opt_input_shape = opt_input_shape - self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16 - - def __init__(self, methodName='runTest'): - super(TrtLayerAutoScanTest, self).__init__(methodName) - self.trt_param = self.TensorRTParam( - workspace_size=1024, - max_batch_size=4, - min_subgraph_size=0, - precision=paddle_infer.PrecisionType.Float32, - use_static=True, - use_calib_mode=False) - self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False) - self.num_percent_cases = float( - os.getenv( - 'TEST_NUM_PERCENT_CASES', default='1.0')) - abs_dir = os.path.abspath(os.path.dirname(__file__)) - cache_dir = str(self.__module__) + '_trt_cache_dir' - self.trt_cache_dir = os.path.join(abs_dir, cache_dir) - - def create_inference_config(self, use_trt=True) -> paddle_infer.Config: - config = paddle_infer.Config() - config.disable_glog_info() - config.enable_use_gpu(100, 0) - config.set_optim_cache_dir(self.trt_cache_dir) - if use_trt: - config.switch_ir_debug() - config.enable_tensorrt_engine( - max_batch_size=self.trt_param.max_batch_size, - workspace_size=self.trt_param.workspace_size, - min_subgraph_size=self.trt_param.min_subgraph_size, - precision_mode=self.trt_param.precision, - use_static=self.trt_param.use_static, - use_calib_mode=self.trt_param.use_calib_mode) - if len(self.dynamic_shape.min_input_shape - ) != 0 and self.dynamic_shape.min_input_shape.keys( - ) == self.dynamic_shape.max_input_shape.keys( - ) and self.dynamic_shape.min_input_shape.keys( - ) == self.dynamic_shape.opt_input_shape.keys(): - config.set_trt_dynamic_shape_info( - self.dynamic_shape.min_input_shape, - self.dynamic_shape.max_input_shape, - self.dynamic_shape.opt_input_shape, - self.dynamic_shape.disable_trt_plugin_fp16) - return config - - def assert_tensors_near(self, - atol: float, - rtol: float, - tensor: Dict[str, np.array], - baseline: Dict[str, np.array]): - for key, arr in tensor.items(): - self.assertTrue( - baseline[key].shape == arr.shape, - "The output shape of GPU and TensorRT are not equal, the baseline shape is " - + str(baseline[key].shape) + ', but the trt shape is ' + - str(arr.shape)) - self.assertTrue( - np.allclose( - baseline[key], arr, atol=atol, rtol=rtol), - "Output has diff 
between GPU and TensorRT. ") - - def assert_op_size(self, trt_engine_num, paddle_op_num): - last_passed_program = os.path.join( - self.trt_cache_dir, 'transpose_flatten_concat_fuse_pass.pdmodel') - model_bytes = paddle.static.load_from_file(last_passed_program) - pg = paddle.static.deserialize_program(model_bytes) - main_block = pg.desc.block(0) - op_size = main_block.op_size() - op_types = [ - main_block.op(i).type() == 'tensorrt_engine' for i in range(op_size) - ] - trt_engine_size = sum(op_types) - paddle_op_size = op_size - trt_engine_size - self.assertTrue(trt_engine_size == trt_engine_num, - 'trt_engine_num is {}, but got {}!'.format( - trt_engine_size, trt_engine_num)) - self.assertTrue(paddle_op_size == paddle_op_num, - 'paddle_op_num is {}, but got {}!'.format( - paddle_op_size, paddle_op_num)) - - def skip_log(self, msg: str): - logging.warning("SKIP: " + msg) - - def fail_log(self, msg: str): - logging.error("FAILE: " + msg) - - def success_log(self, msg: str): - logging.info("SUCCESS: " + msg) - - def validate(self, func: Callable[..., bool]): - pass - - def generate_op_config(self, - ops_config: List[Dict[str, Any]]) -> List[OpConfig]: - ops = [] - for i in range(len(ops_config)): - op_config = ops_config[i] - ops.append( - OpConfig( - type=op_config['op_type'], - inputs=op_config['op_inputs'], - outputs=op_config['op_outputs'], - attrs=op_config['op_attrs'])) - return ops - - def inference_config_str(self, config: paddle_infer.Config): - dic = {} - enable_trt = config.tensorrt_engine_enabled() - trt_precison = config.tensorrt_precision_mode() - trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled() - if enable_trt: - dic['use_trt'] = True - dic['trt_precision'] = trt_precison - dic['use_dynamic_shape'] = trt_dynamic_shape - else: - dic['use_trt'] = False - return str(dic) - - def run_test(self, quant=False): - status = True - # Choose different tests by week - np.random.seed(int(time.strftime("%W"))) - run_flags = [] - for prog_config in self.sample_program_configs(): - # In CI, only run 30% cases - if np.random.rand() < self.num_percent_cases: - run_flags.append(True) - else: - run_flags.append(False) - np.random.seed(1024) - - for prog_config, run_flags in zip(self.sample_program_configs(), - run_flags): - if not run_flags: - continue - - # if program is invalid, we should skip that cases. 
- if not self.is_program_valid(prog_config): - continue - - model, params = create_fake_model(prog_config) - if quant: - model, params = create_quant_model(model, params) - - feed_data = {} - for name, tensor_config in prog_config.inputs.items(): - feed_data[name] = { - 'data': tensor_config.data, - 'lod': tensor_config.lod - } - - results: List[Dict[str, Tensor]] = [] - - # baseline: gpu run - gpu_config = self.create_inference_config(use_trt=False) - results.append( - self.run_test_config(model, params, prog_config, gpu_config, - feed_data)) - self.success_log('RUN_GPU_BASELINE ' + str(prog_config) + ' vs ' + - self.inference_config_str(gpu_config)) - - for pred_config, nodes_num, threshold in self.sample_predictor_configs( - prog_config): - - if os.path.exists(self.trt_cache_dir): - shutil.rmtree(self.trt_cache_dir) - - if isinstance(threshold, float): - atol = threshold - rtol = 1e-8 - elif isinstance(threshold, list) or isinstance(threshold, - tuple): - atol = threshold[0] - rtol = threshold[1] - else: - raise NotImplementedError - - if quant and pred_config.tensorrt_precision_mode( - ) != paddle_infer.PrecisionType.Int8: - continue - if pred_config.tensorrt_precision_mode( - ) == paddle_infer.PrecisionType.Int8 and not quant: - continue - - skip_flag = False - for skip_info in self.skip_cases: - if skip_info[0](prog_config, pred_config): - skip_flag = True - if skip_info[1] == SkipReasons.TRT_NOT_IMPLEMENTED: - self.skip_log("[TRT_NOT_IMPLEMENTED] " + skip_info[ - 2] + ' ' + repr(prog_config) + ' vs ' + self. - inference_config_str(pred_config)) - elif skip_info[1] == SkipReasons.TRT_NOT_SUPPORT: - self.skip_log("[TRT_NOT_SUPPORT] " + skip_info[ - 2] + ' ' + repr(prog_config) + ' vs ' + self. - inference_config_str(pred_config)) - else: - raise NotImplementedError - break - - try: - pred_config_deserialize = paddle_infer.Config(pred_config) - results.append( - self.run_test_config(model, params, prog_config, - pred_config, feed_data)) - self.assert_tensors_near(atol, rtol, results[-1], - results[0]) - if not skip_flag: - self.assert_op_size(nodes_num[0], nodes_num[1]) - # deserialize test - if nodes_num[0] > 0: - self.run_test_config(model, params, prog_config, - pred_config_deserialize, feed_data) - except Exception as e: - self.fail_log( - str(prog_config) + ' vs ' + self.inference_config_str( - pred_config) + - '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))) - if not skip_flag: - status = False - continue - - self.success_log('RUN ' + str(prog_config) + ' vs ' + - self.inference_config_str(pred_config)) - - self.assertTrue(status) +from auto_scan_test import TrtLayerAutoScanTest, SkipReasons diff --git a/python/unittest_py/requirements.txt b/python/unittest_py/requirements.txt index af2203316d8b3b15bac96f94297a3653a4b9ab60..fe8382faa0c34c600c7a228baac34803ecc1492e 100644 --- a/python/unittest_py/requirements.txt +++ b/python/unittest_py/requirements.txt @@ -3,6 +3,7 @@ coverage pycrypto ; platform_system != "Windows" mock gym +hypothesis opencv-python<=4.2.0.32 visualdl paddle2onnx>=0.8.2
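For anyone exercising the new framework locally, profile selection is driven entirely by environment variables. A minimal usage sketch (paths relative to the Paddle source tree; assumes a built and installed Paddle wheel):

```bash
# The auto-scan tests depend on hypothesis (now listed in requirements.txt).
pip install hypothesis

cd python/paddle/fluid/tests/unittests/ir/inference

# 'ci' profile: 100 derandomized examples per test; this is what the CI
# scripts select by exporting HYPOTHESIS_TEST_PROFILE=ci.
HYPOTHESIS_TEST_PROFILE=ci python test_mkldnn_prelu_op.py

# 'dev' profile (the default): 1000 examples for more thorough local fuzzing.
HYPOTHESIS_TEST_PROFILE=dev python test_mkldnn_prelu_op.py
```

Note that auto_scan_test.py also falls back to the 'ci' profile whenever TEST_NUM_PERCENT_CASES is set below 1.0, since a partial-case run is treated as a CI run.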