diff --git a/paddle/fluid/inference/analysis/helper.cc b/paddle/fluid/inference/analysis/helper.cc index 008608c14c75cbc0ee37baa57c6f0cbebc5bc064..368ef2e5583fe2f6fcb24c98ded02f4e5325f7a4 100644 --- a/paddle/fluid/inference/analysis/helper.cc +++ b/paddle/fluid/inference/analysis/helper.cc @@ -75,6 +75,18 @@ void SetAttr>(framework::proto::OpDesc *op, } } +template <> +void SetAttr>(framework::proto::OpDesc *op, + const std::string &name, + const std::vector &data) { + auto *attr = op->add_attrs(); + attr->set_name(name); + attr->set_type(paddle::framework::proto::AttrType::LONGS); + for (const auto i : data) { + attr->add_longs(i); + } +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index bbe68a7bab5c64f6c89d466a4a83ea326adcb371..8d696e448e2e131ba77936afa593cc4a5d91007e 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -208,6 +208,15 @@ void TensorRtSubgraphPass::CreateTensorRTOp( SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping); SetAttr(op_desc->Proto(), "parameters", params); + // we record all inputs' shapes in attr to check if they are consistent + // with the real inputs' shapes retrieved from scope when trt runs. + for (auto *x : node->inputs) { + if (x->IsVar() && x->Var()) { + framework::VarDesc *var = x->Var(); + SetAttr(op_desc->Proto(), var->Name() + "_shape", var->GetShape()); + } + } + auto use_static_engine = Get("use_static_engine"); // TODO(NHZlX) // There are models with the same structure but the different parameters, diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 41492979cd8b912bb5851724e5c71b4989871c20..22c0c9e9d4af3233e3cc712d26d13eb38b302abf 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -183,7 +183,8 @@ class TensorRTEngineOp : public framework::OperatorBase { auto stream = reinterpret_cast(dev_ctx).stream(); - PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs"); + PADDLE_ENFORCE_EQ(input_names_.empty(), false, + "should pass at least one input"); std::vector output_maps = Attr>("output_name_mapping"); @@ -203,7 +204,21 @@ class TensorRTEngineOp : public framework::OperatorBase { // convert input and copy to TRT engine's buffer auto &t = inference::analysis::GetFromScope(scope, x); - auto t_shape = framework::vectorize(t.dims()); + auto t_shape = framework::vectorize(t.dims()); + // check if the input shapes are consistent with model. + if (HasAttr(x + "_shape")) { + std::vector i_shape = Attr>(x + "_shape"); + std::vector model_input_shape(i_shape.begin() + 1, + i_shape.end()); + std::vector runtime_input_shape(t_shape.begin() + 1, + t_shape.end()); + PADDLE_ENFORCE_EQ(model_input_shape == runtime_input_shape, true, + "Input shapes are inconsistent with the model. TRT 5 " + "or lower version " + "does not support dynamic input shapes. Please check " + "your input shapes."); + } + runtime_batch = t_shape[0]; const int bind_index = engine->engine()->getBindingIndex(x.c_str());