未验证 提交 a2afcace 编写于 作者: W wenbin 提交者: GitHub

add trt error information. (#35277)

* add trt error information.

* rerun ci
上级 3d76d003
......@@ -266,7 +266,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
void RunTrt(const framework::Scope &scope, const platform::Place &dev_place,
TensorRTEngine *engine) const {
int runtime_batch = 1;
int runtime_batch = -1;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
auto stream =
......@@ -297,7 +297,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
auto t_shape = framework::vectorize<int64_t>(t.dims());
runtime_batch = t_shape[0];
const int bind_index = engine->engine()->getBindingIndex(x.c_str());
PADDLE_ENFORCE_LT(
bind_index, num_bindings,
......@@ -317,6 +316,28 @@ class TensorRTEngineOp : public framework::OperatorBase {
std::vector<int64_t> runtime_input_shape(t_shape.begin() + 1,
t_shape.end());
RuntimeStaticShapeCheck(runtime_input_shape, model_input_shape);
if (runtime_batch != -1) {
PADDLE_ENFORCE_EQ(
runtime_batch, t_shape[0],
platform::errors::InvalidArgument(
"Inputs of trt subgraphs has different batchsize. "
"It's not allowed in static shape mode. "
"Check whether the model you are running has multiple trt "
"subgraphs: \n "
"\tIf there are multiple trt subgraphs, you need to ensure "
"that the first dimension of the input tensor of these "
"subgraphs is "
"consistent.\n"
"\tIf there are inconsistent subgraphs, you need to filter "
"them "
"by "
"setting min_subgraph_size using EnableTensorrtEngine "
"interface.\n"
"\tThe min_subgraph_size shouble to be greater than the "
"number "
"of "
"nodes in the inconsistent subgraph.\n"));
}
}
} else {
#if IS_TRT_VERSION_GE(6000)
......@@ -341,6 +362,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
bind_index, inference::tensorrt::Vec2TRT_Dims(t_shape, x, true));
#endif
}
runtime_batch = t_shape[0];
auto type = t.type();
if (type == framework::proto::VarType::FP32) {
buffers[bind_index] = static_cast<void *>(t.data<float>());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册