Unverified commit 7c1bc000, authored by Yuanle Liu, committed by GitHub

fix trt inference fp16 io (#54032)

Parent 07223e34
...@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/inference/tensorrt/op_teller.h"
 #include "paddle/fluid/inference/utils/singleton.h"
+#include "paddle/phi/common/data_type.h"
 namespace paddle {
 namespace inference {
...@@ -305,7 +306,7 @@ class OpConverter {
           platform::errors::InvalidArgument("TensorRT engine only takes "
                                             "LoDTensor as input"));
       nvinfer1::DataType in_dtype = FluidDataType2TRT(var->GetDataType());
-      if (engine->WithFp16() && !engine->WithInt8() &&
+      if (engine->precision() == phi::DataType::FLOAT16 &&
           in_dtype == nvinfer1::DataType::kFLOAT &&
           engine->EnableLowPrecisionIO()) {
        in_dtype = nvinfer1::DataType::kHALF;
......
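
For context, the change makes the op converter declare an fp32 network input as fp16 only when the engine's precision is fp16 and low-precision I/O is enabled, replacing the earlier WithFp16()/WithInt8() check. Below is a minimal standalone sketch of that dtype-selection logic; the enums and the EngineConfig struct are hypothetical stand-ins for nvinfer1::DataType, phi::DataType, and the TensorRT engine interface, not the real Paddle types.

// Standalone sketch of the I/O dtype selection logic changed in this diff.
#include <iostream>

enum class TrtDataType { kFLOAT, kHALF, kINT8 };
enum class PhiDataType { FLOAT32, FLOAT16, INT8 };

// Hypothetical stand-in for the TensorRT engine configuration.
struct EngineConfig {
  PhiDataType precision;         // models engine->precision()
  bool enable_low_precision_io;  // models engine->EnableLowPrecisionIO()
};

// An fp32 input is declared as fp16 only when the engine precision itself
// is fp16 and low-precision I/O was requested; otherwise it is kept as-is.
TrtDataType SelectInputDType(const EngineConfig& engine, TrtDataType in_dtype) {
  if (engine.precision == PhiDataType::FLOAT16 &&
      in_dtype == TrtDataType::kFLOAT &&
      engine.enable_low_precision_io) {
    return TrtDataType::kHALF;
  }
  return in_dtype;
}

int main() {
  EngineConfig fp16_engine{PhiDataType::FLOAT16, true};
  EngineConfig int8_engine{PhiDataType::INT8, true};

  // fp16 engine with low-precision I/O: fp32 input is promoted to half.
  std::cout << (SelectInputDType(fp16_engine, TrtDataType::kFLOAT) ==
                TrtDataType::kHALF) << '\n';
  // int8 engine: fp32 input keeps its original dtype.
  std::cout << (SelectInputDType(int8_engine, TrtDataType::kFLOAT) ==
                TrtDataType::kFLOAT) << '\n';
  return 0;
}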