From baccd7e2cac7b7e344527170b66b628460eeac9b Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Sat, 21 Sep 2019 09:18:46 +0800 Subject: [PATCH] Add TRT input shape check between model and runtime (#19864) * add TRT shape check, test=develop * model_input_shape == runtime_input_shape, refine message, test=develop --- paddle/fluid/inference/analysis/helper.cc | 12 ++++++++++++ .../ir_passes/tensorrt_subgraph_pass.cc | 9 +++++++++ .../operators/tensorrt/tensorrt_engine_op.h | 19 +++++++++++++++++-- 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/analysis/helper.cc b/paddle/fluid/inference/analysis/helper.cc index 008608c14c7..368ef2e5583 100644 --- a/paddle/fluid/inference/analysis/helper.cc +++ b/paddle/fluid/inference/analysis/helper.cc @@ -75,6 +75,18 @@ void SetAttr>(framework::proto::OpDesc *op, } } +template <> +void SetAttr>(framework::proto::OpDesc *op, + const std::string &name, + const std::vector &data) { + auto *attr = op->add_attrs(); + attr->set_name(name); + attr->set_type(paddle::framework::proto::AttrType::LONGS); + for (const auto i : data) { + attr->add_longs(i); + } +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index bbe68a7bab5..8d696e448e2 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -208,6 +208,15 @@ void TensorRtSubgraphPass::CreateTensorRTOp( SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping); SetAttr(op_desc->Proto(), "parameters", params); + // we record all inputs' shapes in attr to check if they are consistent + // with the real inputs' shapes retrieved from scope when trt runs. + for (auto *x : node->inputs) { + if (x->IsVar() && x->Var()) { + framework::VarDesc *var = x->Var(); + SetAttr(op_desc->Proto(), var->Name() + "_shape", var->GetShape()); + } + } + auto use_static_engine = Get("use_static_engine"); // TODO(NHZlX) // There are models with the same structure but the different parameters, diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 41492979cd8..22c0c9e9d4a 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -183,7 +183,8 @@ class TensorRTEngineOp : public framework::OperatorBase { auto stream = reinterpret_cast(dev_ctx).stream(); - PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs"); + PADDLE_ENFORCE_EQ(input_names_.empty(), false, + "should pass at least one input"); std::vector output_maps = Attr>("output_name_mapping"); @@ -203,7 +204,21 @@ class TensorRTEngineOp : public framework::OperatorBase { // convert input and copy to TRT engine's buffer auto &t = inference::analysis::GetFromScope(scope, x); - auto t_shape = framework::vectorize(t.dims()); + auto t_shape = framework::vectorize(t.dims()); + // check if the input shapes are consistent with model. + if (HasAttr(x + "_shape")) { + std::vector i_shape = Attr>(x + "_shape"); + std::vector model_input_shape(i_shape.begin() + 1, + i_shape.end()); + std::vector runtime_input_shape(t_shape.begin() + 1, + t_shape.end()); + PADDLE_ENFORCE_EQ(model_input_shape == runtime_input_shape, true, + "Input shapes are inconsistent with the model. TRT 5 " + "or lower version " + "does not support dynamic input shapes. Please check " + "your input shapes."); + } + runtime_batch = t_shape[0]; const int bind_index = engine->engine()->getBindingIndex(x.c_str()); -- GitLab