From ae24156506cdcc07be43f66b5034c4ee43f396f3 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Tue, 23 May 2023 15:01:43 +0800
Subject: [PATCH] Fix inference fp16 io (#54042)

* fix trt inference fp16 io

* fix inference fp16 io
---
 paddle/fluid/inference/tensorrt/convert/op_converter.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index 25af70f9d9c..018e3d2bb77 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -359,7 +359,7 @@ class OpConverter {
         platform::errors::InvalidArgument(
             "The output tensor in TensorRT subgraph should be LoDTensor"));
     nvinfer1::DataType out_dtype = FluidDataType2TRT(var->GetDataType());
-    if (engine->WithFp16() && !engine->WithInt8() &&
+    if (engine->precision() == phi::DataType::FLOAT16 &&
         out_dtype == nvinfer1::DataType::kFLOAT &&
         engine->EnableLowPrecisionIO()) {
       out_dtype = nvinfer1::DataType::kHALF;
--
GitLab
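
Note on the change: the patch narrows a TensorRT subgraph output from FP32 to FP16 only when the engine precision is exactly FP16, instead of whenever FP16 is enabled and INT8 is not. Below is a minimal, self-contained C++ sketch of the patched decision logic. The enum types and function name here are simplified stand-ins invented for illustration, not Paddle's or TensorRT's actual API.

```cpp
#include <iostream>

// Hypothetical stand-ins for engine precision and nvinfer1 dtypes.
enum class Precision { kFloat32, kFloat16, kInt8 };
enum class DataType { kFLOAT, kHALF };

// Sketch of the patched condition: an FP32 output is narrowed to FP16
// only when the engine precision is exactly FP16 and the user opted
// into low-precision I/O. An INT8 engine no longer triggers narrowing.
DataType ResolveOutputDtype(Precision engine_precision,
                            DataType out_dtype,
                            bool enable_low_precision_io) {
  if (engine_precision == Precision::kFloat16 &&
      out_dtype == DataType::kFLOAT && enable_low_precision_io) {
    return DataType::kHALF;
  }
  return out_dtype;
}

int main() {
  // With an FP16 engine, an FP32 output is narrowed to half precision.
  bool narrowed = ResolveOutputDtype(Precision::kFloat16, DataType::kFLOAT,
                                     /*enable_low_precision_io=*/true) ==
                  DataType::kHALF;
  // With an INT8 engine, the output keeps its FP32 dtype after the fix.
  bool kept = ResolveOutputDtype(Precision::kInt8, DataType::kFLOAT,
                                 /*enable_low_precision_io=*/true) ==
              DataType::kFLOAT;
  std::cout << std::boolalpha << narrowed << ' ' << kept << '\n';  // true true
}
```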