未验证 提交 baccd7e2 编写于 作者: P Pei Yang 提交者: GitHub

Add TRT input shape check between model and runtime (#19864)

* add TRT shape check, test=develop

* model_input_shape == runtime_input_shape, refine message, test=develop
上级 74812d1c
......@@ -75,6 +75,18 @@ void SetAttr<std::vector<int>>(framework::proto::OpDesc *op,
}
}
template <>
void SetAttr<std::vector<int64_t>>(framework::proto::OpDesc *op,
const std::string &name,
const std::vector<int64_t> &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::LONGS);
for (const auto i : data) {
attr->add_longs(i);
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -208,6 +208,15 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
SetAttr(op_desc->Proto(), "parameters", params);
// we record all inputs' shapes in attr to check if they are consistent
// with the real inputs' shapes retrieved from scope when trt runs.
for (auto *x : node->inputs) {
if (x->IsVar() && x->Var()) {
framework::VarDesc *var = x->Var();
SetAttr(op_desc->Proto(), var->Name() + "_shape", var->GetShape());
}
}
auto use_static_engine = Get<bool>("use_static_engine");
// TODO(NHZlX)
// There are models with the same structure but the different parameters,
......
......@@ -183,7 +183,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs");
PADDLE_ENFORCE_EQ(input_names_.empty(), false,
"should pass at least one input");
std::vector<std::string> output_maps =
Attr<std::vector<std::string>>("output_name_mapping");
......@@ -203,7 +204,21 @@ class TensorRTEngineOp : public framework::OperatorBase {
// convert input and copy to TRT engine's buffer
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
auto t_shape = framework::vectorize(t.dims());
auto t_shape = framework::vectorize<int64_t>(t.dims());
// check if the input shapes are consistent with model.
if (HasAttr(x + "_shape")) {
std::vector<int64_t> i_shape = Attr<std::vector<int64_t>>(x + "_shape");
std::vector<int64_t> model_input_shape(i_shape.begin() + 1,
i_shape.end());
std::vector<int64_t> runtime_input_shape(t_shape.begin() + 1,
t_shape.end());
PADDLE_ENFORCE_EQ(model_input_shape == runtime_input_shape, true,
"Input shapes are inconsistent with the model. TRT 5 "
"or lower version "
"does not support dynamic input shapes. Please check "
"your input shapes.");
}
runtime_batch = t_shape[0];
const int bind_index = engine->engine()->getBindingIndex(x.c_str());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册