From 814f1d655ac45bbe9f37fa4abef0de519681d5de Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Fri, 26 May 2023 15:10:34 +0800 Subject: [PATCH] fix fp16 io (#54129) * fix fp16 io * disable precision test --- .../fluid/framework/ir/auto_mixed_precision_pass.cc | 1 + test/ir/inference/test_trt_inference_fp16_io.py | 11 +++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index ea0f7c47fb6..f30166ad62f 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -509,6 +509,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { // when op_1 only support cpu kernel. if op_2's intput var is op_1's // output var, then op_2 should not run at low precision. if (GetOpOriginalType(op_type) != "feed" && + GetOpOriginalType(op_type) != "tensorrt_engine" && !KernelSupportPrecision( GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) { for (auto* out_var_node : op_node->outputs) { diff --git a/test/ir/inference/test_trt_inference_fp16_io.py b/test/ir/inference/test_trt_inference_fp16_io.py index 8555369e474..0f30090a324 100644 --- a/test/ir/inference/test_trt_inference_fp16_io.py +++ b/test/ir/inference/test_trt_inference_fp16_io.py @@ -66,10 +66,11 @@ class TestEnableLowPrecisionIO: fp32_output = self.get_fp32_output() fp16_output = self.get_fp16_output() - # np.testing.assert_allclose( - # fp32_output.numpy().flatten(), - # fp16_output.numpy().flatten(), - # ) + # if os.name == 'posix': + # np.testing.assert_allclose( + # fp32_output.numpy().flatten(), + # fp16_output.numpy().flatten(), + # ) class TestEnableLowPrecisionIOWithGPU( @@ -105,6 +106,7 @@ class TestEnableLowPrecisionIOWithTRTAllGraph( use_static=False, use_calib_mode=False, ) + config.enable_tuned_tensorrt_dynamic_shape() config.enable_memory_optim() config.enable_low_precision_io(low_precision_io) config.disable_glog_info() @@ -129,6 +131,7 @@ class TestEnableLowPrecisionIOWithTRTSubGraph( use_static=False, use_calib_mode=False, ) + config.enable_tuned_tensorrt_dynamic_shape() config.enable_memory_optim() config.enable_low_precision_io(low_precision_io) config.exp_disable_tensorrt_ops(["flatten_contiguous_range"]) -- GitLab