未验证 提交 a3eb341e 编写于 作者: Y Yuanle Liu 提交者: GitHub

trt engine input data type should be consistent with trt input bindings type (#45103)

* trt engine input data type should be consistent with trt input bindings type

* fix some bugs

* fix some bugs

* fix some bugs
上级 34234282
......@@ -24,6 +24,7 @@
#include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace inference {
......@@ -183,6 +184,35 @@ inline std::string Vec2Str(const std::vector<T>& vec) {
os << vec[vec.size() - 1] << ")";
return os.str();
}
static inline nvinfer1::DataType PhiType2NvType(phi::DataType type) {
nvinfer1::DataType nv_type = nvinfer1::DataType::kFLOAT;
switch (type) {
case phi::DataType::FLOAT32:
nv_type = nvinfer1::DataType::kFLOAT;
break;
case phi::DataType::FLOAT16:
nv_type = nvinfer1::DataType::kHALF;
break;
case phi::DataType::INT32:
case phi::DataType::INT64:
nv_type = nvinfer1::DataType::kINT32;
break;
case phi::DataType::INT8:
nv_type = nvinfer1::DataType::kINT8;
break;
#if IS_TRT_VERSION_GE(7000)
case phi::DataType::BOOL:
nv_type = nvinfer1::DataType::kBOOL;
break;
#endif
default:
paddle::platform::errors::InvalidArgument(
"phi::DataType not supported data type %s.", type);
break;
}
return nv_type;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -562,6 +562,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
runtime_batch = t_shape[0];
VLOG(1) << "trt input [" << x << "] dtype is " << t.dtype();
auto indata_type = inference::tensorrt::PhiType2NvType(t.dtype());
auto intrt_index = engine->engine()->getBindingIndex(x.c_str());
auto intrt_type = engine->engine()->getBindingDataType(intrt_index);
PADDLE_ENFORCE_EQ(indata_type,
intrt_type,
platform::errors::InvalidArgument(
"The TRT Engine OP's input type should equal "
"to the input data type"));
auto type = framework::TransToProtoVarType(t.dtype());
if (type == framework::proto::VarType::FP32) {
buffers[bind_index] = static_cast<void *>(t.data<float>());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册