From dd181238f7e8cb84f3ed66ed270d7bf620dd4117 Mon Sep 17 00:00:00 2001 From: wenbin Date: Fri, 4 Jun 2021 18:43:14 +0800 Subject: [PATCH] fix inference prepare data bug (#33305) * fix inference prepare data bug * rename functions * typo * typo * typo * UT correct * correct condition * correct condition * ci coverage * morelines * fix ci coverage --- paddle/fluid/framework/operator.cc | 7 +++- .../fluid/inference/api/analysis_predictor.cc | 39 +++++++++++++++++++ .../controlflow/conditional_block_infer_op.cc | 2 + .../ir/inference/test_trt_conv_pass.py | 2 +- 4 files changed, 48 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 25d430df45..20cffaa959 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1531,7 +1531,12 @@ Scope* OperatorWithKernel::PrepareData( // the rest iterations to save the elapsed time. // We do not support skipping PrepareData in while block, because the Op's // input may be changed by subsequent Ops, which may cause an error. - if (pre_scope_ == &scope && new_scope == nullptr) { + + // For inference, ops that are behind a conditional branch aren't supported well, + // so disable the prepare-data optimization conservatively.
+ bool force_prepare_data = HasAttr("inference_force_prepare_data") && + Attr<bool>("inference_force_prepare_data"); + if (pre_scope_ == &scope && new_scope == nullptr && !force_prepare_data) { need_prepare_data_ = false; } diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 2733d21b6c..dc075b9f79 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -270,7 +270,46 @@ bool AnalysisPredictor::CreateExecutor() { executor_.reset(new paddle::framework::NaiveExecutor(place_)); return true; } + +static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) { + // here is a bad case for the prepare-data optimization: + // let's assume an op sits behind conditional_block; if conditional_block + // chooses branch 1, the op needs to call prepare data, else the op doesn't need + // to call prepare data. At runtime, if the predictor chooses branch 2 first, the + // optimization takes effect; an issue then follows if the predictor later chooses + // branch 1, because the op has lost its chance to prepare data. 
+ std::vector<std::string> op_type = {"conditional_block_infer", + "select_input"}; + for (const auto &type : op_type) { + if (op->Type() == type) { + return true; + } + } + return false; +} + +static void DisablePrepareDataOpt( + std::shared_ptr<framework::ProgramDesc> inference_program, int block, + bool pre_disable_opt) { + bool disable_opt = false; + auto &infer_block = inference_program->Block(block); + for (auto *op : infer_block.AllOps()) { + if (disable_opt || pre_disable_opt) { + op->SetAttr("inference_force_prepare_data", true); + } + if (op->HasAttr("sub_block")) { + int blockID = op->GetBlockAttrId("sub_block"); + DisablePrepareDataOpt(inference_program, blockID, + disable_opt || pre_disable_opt); + } + // disable prepare data if unfriendly op is found + disable_opt = IsPrepareDataOptTargetOp(op); + } +} + bool AnalysisPredictor::PrepareExecutor() { + DisablePrepareDataOpt(inference_program_, 0, false); + executor_->Prepare(sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_); diff --git a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc index 62019be26c..6705d42bcd 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc @@ -73,6 +73,8 @@ class ConditionalBlockInferOp : public ConditionalOp { framework::Executor exec(dev_place); auto *block = Attr<framework::BlockDesc *>("sub_block"); + VLOG(3) << "Conditional block.idx = " << block->ID() + << ", scope = " << &cur_scope; exec.Run(*block->Program(), &cur_scope, block->ID(), false); scope.DeleteScope(scopes->front()); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py index 269183f144..ebbf724d0b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py @@ -219,7 +219,7 @@ 
class DynamicShapeTensorRTSubgraphPassConvTest(InferencePassTest): }, { "conv2d_0.tmp_0": [16, 6, 16, 16], "data": [16, 6, 16, 16], - "depthwise_conv2d_0.tmp_0": [32, 6, 64, 64] + "depthwise_conv2d_0.tmp_0": [16, 6, 16, 16] }, False) self.fetch_list = [conv_out] -- GitLab