diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
index 1f3029d94b940fb514cc04cabe5f41b443b096dd..3b27f525b552699c3d6cc96d388cd2772346ca86 100644
--- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
@@ -266,7 +266,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
 
   void RunTrt(const framework::Scope &scope, const platform::Place &dev_place,
               TensorRTEngine *engine) const {
-    int runtime_batch = 1;
+    int runtime_batch = -1;
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
     auto stream =
@@ -297,7 +297,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
       auto &t =
           inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
       auto t_shape = framework::vectorize(t.dims());
-      runtime_batch = t_shape[0];
       const int bind_index = engine->engine()->getBindingIndex(x.c_str());
       PADDLE_ENFORCE_LT(
           bind_index, num_bindings,
@@ -317,6 +316,28 @@ class TensorRTEngineOp : public framework::OperatorBase {
           std::vector<int64_t> runtime_input_shape(t_shape.begin() + 1,
                                                    t_shape.end());
           RuntimeStaticShapeCheck(runtime_input_shape, model_input_shape);
+          if (runtime_batch != -1) {
+            PADDLE_ENFORCE_EQ(
+                runtime_batch, t_shape[0],
+                platform::errors::InvalidArgument(
+                    "Inputs of trt subgraphs have different batch sizes. "
+                    "This is not allowed in static shape mode. "
+                    "Check whether the model you are running has multiple trt "
+                    "subgraphs: \n "
+                    "\tIf there are multiple trt subgraphs, you need to ensure "
+                    "that the first dimension of the input tensor of these "
+                    "subgraphs is "
+                    "consistent.\n"
+                    "\tIf there are inconsistent subgraphs, you need to filter "
+                    "them "
+                    "by "
+                    "setting min_subgraph_size using EnableTensorrtEngine "
+                    "interface.\n"
+                    "\tThe min_subgraph_size should be greater than the "
+                    "number "
+                    "of "
+                    "nodes in the inconsistent subgraph.\n"));
+          }
         }
       } else {
 #if IS_TRT_VERSION_GE(6000)
@@ -341,6 +362,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
             bind_index, inference::tensorrt::Vec2TRT_Dims(t_shape, x, true));
 #endif
       }
+      runtime_batch = t_shape[0];
       auto type = t.type();
       if (type == framework::proto::VarType::FP32) {
         buffers[bind_index] = static_cast<void *>(t.data<float>());
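
Note: the sketch below is a minimal, standalone illustration (hypothetical names, not Paddle code) of the sentinel pattern this diff introduces: runtime_batch starts at -1, each input's leading dimension is compared against it before being accepted as the runtime batch, and a mismatch is reported as an error, mirroring the static-shape check added above.

    // Standalone sketch; DeduceRuntimeBatch and input_shapes are hypothetical names.
    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    int DeduceRuntimeBatch(
        const std::vector<std::vector<int64_t>> &input_shapes) {
      int runtime_batch = -1;  // -1 acts as "no input processed yet"
      for (const auto &shape : input_shapes) {
        if (shape.empty()) {
          throw std::invalid_argument("input shape must not be empty");
        }
        const int batch = static_cast<int>(shape[0]);
        // Mirrors the PADDLE_ENFORCE_EQ added in static shape mode:
        // every input of the subgraph must share one batch size.
        if (runtime_batch != -1 && runtime_batch != batch) {
          throw std::runtime_error(
              "Inputs of the trt subgraph have different batch sizes; "
              "this is not allowed in static shape mode.");
        }
        runtime_batch = batch;  // corresponds to runtime_batch = t_shape[0];
      }
      return runtime_batch;
    }

    int main() {
      // Consistent inputs: both tensors have batch size 8.
      return DeduceRuntimeBatch({{8, 3, 224, 224}, {8, 128}}) == 8 ? 0 : 1;
    }

In the operator itself the check runs only on the static shape path; with dynamic shape enabled the leading dimension may legitimately differ, so the sketch models the static case only.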