未验证 提交 ae241565 编写于 作者: Y Yuanle Liu 提交者: GitHub

Fix inference fp16 io (#54042)

* fix trt inference fp16 io

* fix inference fp16 io
上级 d89e0367
...@@ -359,7 +359,7 @@ class OpConverter { ...@@ -359,7 +359,7 @@ class OpConverter {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The output tensor in TensorRT subgraph should be LoDTensor")); "The output tensor in TensorRT subgraph should be LoDTensor"));
nvinfer1::DataType out_dtype = FluidDataType2TRT(var->GetDataType()); nvinfer1::DataType out_dtype = FluidDataType2TRT(var->GetDataType());
if (engine->WithFp16() && !engine->WithInt8() && if (engine->precision() == phi::DataType::FLOAT16 &&
out_dtype == nvinfer1::DataType::kFLOAT && out_dtype == nvinfer1::DataType::kFLOAT &&
engine->EnableLowPrecisionIO()) { engine->EnableLowPrecisionIO()) {
out_dtype = nvinfer1::DataType::kHALF; out_dtype = nvinfer1::DataType::kHALF;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册