diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h index 69b96bca43a5e37e28b8427b6534e3e505989351..984d06efa5ca2bb14da1f9a505cb18610e963958 100644 --- a/paddle/fluid/inference/tensorrt/helper.h +++ b/paddle/fluid/inference/tensorrt/helper.h @@ -24,6 +24,7 @@ #include "paddle/fluid/platform/dynload/tensorrt.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace inference { @@ -183,6 +184,35 @@ inline std::string Vec2Str(const std::vector& vec) { os << vec[vec.size() - 1] << ")"; return os.str(); } + +static inline nvinfer1::DataType PhiType2NvType(phi::DataType type) { + nvinfer1::DataType nv_type = nvinfer1::DataType::kFLOAT; + switch (type) { + case phi::DataType::FLOAT32: + nv_type = nvinfer1::DataType::kFLOAT; + break; + case phi::DataType::FLOAT16: + nv_type = nvinfer1::DataType::kHALF; + break; + case phi::DataType::INT32: + case phi::DataType::INT64: + nv_type = nvinfer1::DataType::kINT32; + break; + case phi::DataType::INT8: + nv_type = nvinfer1::DataType::kINT8; + break; +#if IS_TRT_VERSION_GE(7000) + case phi::DataType::BOOL: + nv_type = nvinfer1::DataType::kBOOL; + break; +#endif + default: + paddle::platform::errors::InvalidArgument( + "phi::DataType not supported data type %s.", type); + break; + } + return nv_type; +} } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 9b05faf8df47af016ff79cfe8092da935a489d93..03971a3aa7c3ba09e693fd2cce495b994e14a0db 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -562,6 +562,14 @@ class TensorRTEngineOp : public framework::OperatorBase { } runtime_batch = t_shape[0]; VLOG(1) << "trt input [" << x << "] dtype is " << t.dtype(); + auto indata_type = inference::tensorrt::PhiType2NvType(t.dtype()); + auto intrt_index = engine->engine()->getBindingIndex(x.c_str()); + auto intrt_type = engine->engine()->getBindingDataType(intrt_index); + PADDLE_ENFORCE_EQ(indata_type, + intrt_type, + platform::errors::InvalidArgument( + "The TRT Engine OP's input type should equal " + "to the input data type")); auto type = framework::TransToProtoVarType(t.dtype()); if (type == framework::proto::VarType::FP32) { buffers[bind_index] = static_cast(t.data());