未验证 提交 8a239ae5 编写于 作者: W Wilber 提交者: GitHub

Destroy the TRT engines when the last predictor is destructed. (#35842)

上级 71f051fb
......@@ -135,7 +135,7 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
}
// Create an FC Node.
OpDesc desc;
OpDesc desc(mul->Op()->Block());
desc.SetType("fc");
// Set inputs of fc
......
......@@ -220,7 +220,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc;
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {matmul_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
......@@ -299,7 +299,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc;
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {squeeze2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
......@@ -441,7 +441,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc;
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {reshape2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
......@@ -526,7 +526,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc;
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {flatten2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
......
......@@ -178,7 +178,7 @@ class OpDesc {
}
proto::OpDesc desc_;
BlockDesc *block_; // not_own
BlockDesc *block_{nullptr}; // not_own
// input arg name => input variable names
VariableNameMap inputs_;
// output arg name => output variable names
......
......@@ -645,7 +645,17 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
VLOG(5) << "to prepare executor";
ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
inference_program_.reset(
new framework::ProgramDesc(argument_.ir_analyzed_program()));
new framework::ProgramDesc(argument_.ir_analyzed_program()),
[](framework::ProgramDesc *prog) {
// NOTE: do NOT use any member variables here, because they may already have
// been destructed when predictors are released from multiple threads.
#if PADDLE_WITH_TENSORRT
paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.DeleteAll();
#endif
delete prog;
});
// The config and argument take a lot of storage,
// when the predictor settings are complete, we release these stores.
argument_.PartiallyRelease();
......
......@@ -159,6 +159,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (op_type == "relu" || op_type == "relu6" || op_type == "tanh" ||
op_type == "sigmoid") {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
......@@ -274,6 +280,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (op_type == "matmul") {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
for (auto& param_name : desc.Inputs()) {
for (auto& var_name : param_name.second) {
auto* var_desc = block->FindVar(var_name);
......@@ -324,6 +336,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (axis[0] == 0 && axis.size() == 2) return false;
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
......@@ -372,6 +390,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
} else {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
......@@ -385,6 +409,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (!with_dynamic_shape) return false;
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto index_var_name = desc.Input("Index")[0];
auto* x_var_desc = block->FindVar(x_var_name);
......@@ -428,6 +458,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (data_layout != framework::DataLayout::kNCHW) return false;
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
......@@ -439,6 +475,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (op_type == "multiclass_nms") {
if (with_dynamic_shape) return false;
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
for (auto& param_name : desc.Inputs()) {
for (auto& var_name : param_name.second) {
auto* var_desc = block->FindVar(var_name);
......@@ -598,6 +640,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
......@@ -657,6 +705,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
......@@ -724,6 +778,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
auto* y_var_desc = block->FindVar(desc.Input("Y")[0]);
const auto x_shape = x_var_desc->GetShape();
......@@ -775,6 +835,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
......@@ -856,6 +922,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
std::vector<int64_t> shape;
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
for (auto& param_name : desc.Inputs()) {
for (auto& var_name : param_name.second) {
auto* var_desc = block->FindVar(var_name);
......@@ -881,6 +953,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (op_type == "scale") {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
......@@ -892,6 +970,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (op_type == "swish") {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
......@@ -916,6 +1000,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto* var_desc = block->FindVar(desc.Input("Alpha")[0]);
if (!var_desc) {
VLOG(3) << "Variable Alpha of prelu TRT converter not found.";
......@@ -1051,6 +1141,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
......
......@@ -1161,8 +1161,8 @@ function parallel_test_base_gpu() {
EOF
set -x
# set trt_convert ut to run 30% cases.
export TEST_NUM_PERCENT_CASES=0.3
# set trt_convert ut to run 15% cases.
export TEST_NUM_PERCENT_CASES=0.15
precison_cases=""
bash $PADDLE_ROOT/tools/check_added_ut.sh
if [ ${PRECISION_TEST:-OFF} == "ON" ]; then
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册