diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index ea0f7c47fb68c3cde1429d4275336a6ea84689aa..f30166ad62ff8fec20cb6bd51efa1405fd3a44f8 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -509,6 +509,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { // when op_1 only support cpu kernel. if op_2's intput var is op_1's // output var, then op_2 should not run at low precision. if (GetOpOriginalType(op_type) != "feed" && + GetOpOriginalType(op_type) != "tensorrt_engine" && !KernelSupportPrecision( GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) { for (auto* out_var_node : op_node->outputs) { diff --git a/test/ir/inference/test_trt_inference_fp16_io.py b/test/ir/inference/test_trt_inference_fp16_io.py index 8555369e47459091078a9f2e16de1adcaa309903..0f30090a3249344373f250f2d52edddde52f1060 100644 --- a/test/ir/inference/test_trt_inference_fp16_io.py +++ b/test/ir/inference/test_trt_inference_fp16_io.py @@ -66,10 +66,11 @@ class TestEnableLowPrecisionIO: fp32_output = self.get_fp32_output() fp16_output = self.get_fp16_output() - # np.testing.assert_allclose( - # fp32_output.numpy().flatten(), - # fp16_output.numpy().flatten(), - # ) + # if os.name == 'posix': + # np.testing.assert_allclose( + # fp32_output.numpy().flatten(), + # fp16_output.numpy().flatten(), + # ) class TestEnableLowPrecisionIOWithGPU( @@ -105,6 +106,7 @@ class TestEnableLowPrecisionIOWithTRTAllGraph( use_static=False, use_calib_mode=False, ) + config.enable_tuned_tensorrt_dynamic_shape() config.enable_memory_optim() config.enable_low_precision_io(low_precision_io) config.disable_glog_info() @@ -129,6 +131,7 @@ class TestEnableLowPrecisionIOWithTRTSubGraph( use_static=False, use_calib_mode=False, ) + config.enable_tuned_tensorrt_dynamic_shape() config.enable_memory_optim() config.enable_low_precision_io(low_precision_io) config.exp_disable_tensorrt_ops(["flatten_contiguous_range"])