From 225a9c4ed869a87949e6a6e94ab56473cc8d9e03 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Mon, 7 Dec 2020 16:41:52 +0800 Subject: [PATCH] Fix unittest (#29412) * fix tensorrt unittest precision error * fix unittest precision error. test_trt_subgraph_pass && test_trt_dynamic_shape_transformer_prune --- .../tests/api/trt_dynamic_shape_transformer_prune_test.cc | 2 +- .../tests/unittests/ir/inference/test_trt_subgraph_pass.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc index 3916cf361c..965e233b68 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc @@ -126,7 +126,7 @@ void trt_ernie(bool with_fp16, std::vector result) { run(config, &out_data); for (size_t i = 0; i < out_data.size(); i++) { - EXPECT_NEAR(result[i], out_data[i], 1e-4); + EXPECT_NEAR(result[i], out_data[i], 2e-3); } } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py index 73fec1f771..77457efa39 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py @@ -308,7 +308,10 @@ class TensorRTSubgraphPassActivationTest(InferencePassTest): use_gpu = True if os.path.exists(self.path + "_opt_cache"): shutil.rmtree(self.path + "_opt_cache") - self.check_output_with_option(use_gpu) + if self.trt_parameters.precision == AnalysisConfig.Precision.Float32: + self.check_output_with_option(use_gpu) + else: + self.check_output_with_option(use_gpu, 1e-3) self.assertTrue( PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) @@ -572,7 +575,7 @@ class TensorRTSubgraphPassDynamicSplitFp16SerializeTest(InferencePassTest): use_gpu = True if os.path.exists(self.path + "_opt_cache"): shutil.rmtree(self.path + "_opt_cache") - self.check_output_with_option(use_gpu) + self.check_output_with_option(use_gpu, 1e-3) self.assertTrue( PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) -- GitLab