Unverified commit 7c1bc000, authored by Yuanle Liu, committed via GitHub

fix trt inference fp16 io (#54032)

上级 07223e34
......@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace inference {
......@@ -305,7 +306,7 @@ class OpConverter {
platform::errors::InvalidArgument("TensorRT engine only takes "
"LoDTensor as input"));
nvinfer1::DataType in_dtype = FluidDataType2TRT(var->GetDataType());
if (engine->WithFp16() && !engine->WithInt8() &&
if (engine->precision() == phi::DataType::FLOAT16 &&
in_dtype == nvinfer1::DataType::kFLOAT &&
engine->EnableLowPrecisionIO()) {
in_dtype = nvinfer1::DataType::kHALF;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册