Unverified commit f72d52e7, authored by Wilber, committed by GitHub

[cherry-pick] trt engine dtor when the last predictor dtor (#35881)

* cherry-pick 32842
Parent edeb0ade
@@ -42,7 +42,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
   set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
   if(NOT LITE_GIT_TAG)
-    set(LITE_GIT_TAG d3a3a6931b6d22d504d21ba32b3ae972770e9204)
+    set(LITE_GIT_TAG 1c4698c6efd9a5f57a4f8369bd5b6374166f5ba4)
   endif()
   if(NOT CUDA_ARCH_NAME)
......
@@ -135,7 +135,7 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
   }
   // Create an FC Node.
-  OpDesc desc;
+  OpDesc desc(mul->Op()->Block());
   desc.SetType("fc");
   // Set inputs of fc
......
@@ -220,7 +220,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
       LOG(WARNING) << "Pass in op compat failed.";
       return;
     }
-    OpDesc desc;
+    OpDesc desc(matmul_op->Op()->Block());
     desc.SetType("mul");
     desc.SetInput("X", {matmul_in_x->Name()});
     desc.SetInput("Y", {matmul_in_y->Name()});
@@ -299,7 +299,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
       LOG(WARNING) << "Pass in op compat failed.";
      return;
     }
-    OpDesc desc;
+    OpDesc desc(matmul_op->Op()->Block());
     desc.SetType("mul");
     desc.SetInput("X", {squeeze2_in_x->Name()});
     desc.SetInput("Y", {matmul_in_y->Name()});
@@ -441,7 +441,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
       LOG(WARNING) << "Pass in op compat failed.";
       return;
     }
-    OpDesc desc;
+    OpDesc desc(matmul_op->Op()->Block());
     desc.SetType("mul");
     desc.SetInput("X", {reshape2_in_x->Name()});
     desc.SetInput("Y", {matmul_in_y->Name()});
@@ -526,7 +526,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
       LOG(WARNING) << "Pass in op compat failed.";
       return;
     }
-    OpDesc desc;
+    OpDesc desc(matmul_op->Op()->Block());
     desc.SetType("mul");
     desc.SetInput("X", {flatten2_in_x->Name()});
     desc.SetInput("Y", {matmul_in_y->Name()});
......
@@ -178,7 +178,7 @@ class OpDesc {
   }
   proto::OpDesc desc_;
-  BlockDesc *block_;  // not_own
+  BlockDesc *block_{nullptr};  // not_own
   // input arg name => input variable names
   VariableNameMap inputs_;
   // output arg name => output variable names
......
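The `block_{nullptr}` default above is what makes the null checks added to op_teller.cc below dependable: an OpDesc built without a block now reports a null block instead of an indeterminate pointer, and the fuse passes above pass the originating op's block into the new descs. A minimal sketch of that contract, using hypothetical stand-in types rather than the real Paddle classes:

#include <iostream>

// Stand-ins for BlockDesc/OpDesc, only to illustrate the pointer contract.
struct BlockDesc {};

class OpDesc {
 public:
  OpDesc() = default;                              // block_ stays nullptr
  explicit OpDesc(BlockDesc *block) : block_(block) {}
  BlockDesc *Block() const { return block_; }

 private:
  BlockDesc *block_{nullptr};  // not owned; default-initialized so callers can test it
};

int main() {
  BlockDesc block;
  OpDesc bound(&block);  // how the fuse passes now construct descs
  OpDesc loose;          // a desc created without a block

  // Consumers such as the TRT op teller can bail out safely.
  for (const OpDesc *d : {&bound, &loose}) {
    if (d->Block() == nullptr) {
      std::cout << "no block attached, skip shape analysis\n";
    } else {
      std::cout << "block available, analysis can proceed\n";
    }
  }
}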
@@ -645,7 +645,17 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   VLOG(5) << "to prepare executor";
   ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
   inference_program_.reset(
-      new framework::ProgramDesc(argument_.ir_analyzed_program()));
+      new framework::ProgramDesc(argument_.ir_analyzed_program()),
+      [](framework::ProgramDesc *prog) {
+// Note, please do NOT use any member variables, because member variables may
+// have been destructed in multiple threads.
+#if PADDLE_WITH_TENSORRT
+        paddle::inference::Singleton<
+            inference::tensorrt::TRTEngineManager>::Global()
+            .DeleteAll();
+#endif
+        delete prog;
+      });
   // The config and argument take a lot of storage,
   // when the predictor settings are complete, we release these stores.
   argument_.PartiallyRelease();
......
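For context, the change above hands `inference_program_` a custom deleter, so the global TRT engine manager is cleared exactly once: when the last predictor sharing the program releases it. A minimal sketch of that shared_ptr pattern follows; `EngineRegistry` and the other names are illustrative stand-ins, not the Paddle API:

#include <iostream>
#include <memory>

// Hypothetical stand-in for a process-wide registry such as TRTEngineManager;
// only the DeleteAll() call mirrors the real code.
struct EngineRegistry {
  static EngineRegistry &Global() {
    static EngineRegistry instance;
    return instance;
  }
  void DeleteAll() { std::cout << "all cached engines released\n"; }
};

struct ProgramDesc {};  // stand-in for framework::ProgramDesc

int main() {
  // The deleter runs only when the last shared owner goes away, which is
  // exactly "destroy the TRT engines when the last predictor is destructed".
  std::shared_ptr<ProgramDesc> program(new ProgramDesc, [](ProgramDesc *prog) {
    EngineRegistry::Global().DeleteAll();
    delete prog;
  });

  auto cloned_view = program;  // a cloned predictor shares the same program
  program.reset();             // first owner gone: deleter does NOT run yet
  cloned_view.reset();         // last owner gone: registry cleared, prog deleted
}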
@@ -159,6 +159,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     if (op_type == "relu" || op_type == "relu6" || op_type == "tanh" ||
         op_type == "sigmoid") {
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
@@ -274,6 +280,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     if (op_type == "matmul") {
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       for (auto& param_name : desc.Inputs()) {
         for (auto& var_name : param_name.second) {
           auto* var_desc = block->FindVar(var_name);
@@ -324,6 +336,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       if (axis[0] == 0 && axis.size() == 2) return false;
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
@@ -372,6 +390,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         return false;
       } else {
         auto* block = desc.Block();
+        if (block == nullptr) {
+          VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                     "Developers need to check whether block_desc is passed in "
+                     "the pass.";
+          return false;
+        }
         auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
         const auto x_shape = x_var_desc->GetShape();
         if (x_shape.size() == 1) {
@@ -385,6 +409,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       if (!with_dynamic_shape) return false;
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto index_var_name = desc.Input("Index")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
@@ -428,6 +458,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       if (data_layout != framework::DataLayout::kNCHW) return false;
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
@@ -439,6 +475,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     if (op_type == "multiclass_nms") {
       if (with_dynamic_shape) return false;
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       for (auto& param_name : desc.Inputs()) {
         for (auto& var_name : param_name.second) {
           auto* var_desc = block->FindVar(var_name);
@@ -598,6 +640,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         return false;
       }
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
@@ -657,6 +705,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         }
       }
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
@@ -724,6 +778,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         return false;
       }
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
       auto* y_var_desc = block->FindVar(desc.Input("Y")[0]);
       const auto x_shape = x_var_desc->GetShape();
@@ -775,6 +835,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       }
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
@@ -856,6 +922,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       }
       std::vector<int64_t> shape;
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       for (auto& param_name : desc.Inputs()) {
         for (auto& var_name : param_name.second) {
           auto* var_desc = block->FindVar(var_name);
@@ -881,6 +953,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     if (op_type == "scale") {
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
@@ -892,6 +970,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     if (op_type == "swish") {
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
@@ -916,6 +1000,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       }
       auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto* var_desc = block->FindVar(desc.Input("Alpha")[0]);
       if (!var_desc) {
         VLOG(3) << "Variable Alpha of prelu TRT converter not found.";
@@ -1051,6 +1141,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
      }
      auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
......
@@ -1161,8 +1161,8 @@ function parallel_test_base_gpu() {
 EOF
     set -x
-        # set trt_convert ut to run 30% cases.
-        export TEST_NUM_PERCENT_CASES=0.3
+        # set trt_convert ut to run 15% cases.
+        export TEST_NUM_PERCENT_CASES=0.15
     precison_cases=""
     bash $PADDLE_ROOT/tools/check_added_ut.sh
     if [ ${PRECISION_TEST:-OFF} == "ON" ]; then
......