From d52251450bcfb04c1f6fdb2b0b14c46d6f2814f7 Mon Sep 17 00:00:00 2001 From: wenbin Date: Mon, 7 Jun 2021 20:55:02 +0800 Subject: [PATCH] Fix inference prepare data (#33370) --- paddle/fluid/framework/operator.cc | 7 +++- .../fluid/inference/api/analysis_predictor.cc | 39 +++++++++++++++++++ .../ir/inference/test_trt_conv_pass.py | 2 +- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 1e26dab629..ac4d5a97cf 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1525,7 +1525,12 @@ Scope* OperatorWithKernel::PrepareData( // the rest iterations to save the elapsed time. // We do not support skipping PrepareData in while block, because the Op's // input may be changed by subsequent Ops, which may cause an error. - if (pre_scope_ == &scope && new_scope == nullptr) { + + // For inference, ops that behind conditional branch aren't supported well, + // so disable prepare optimization conservatively. + bool force_prepare_data = HasAttr("inference_force_prepare_data") && + Attr("inference_force_prepare_data"); + if (pre_scope_ == &scope && new_scope == nullptr && !force_prepare_data) { need_prepare_data_ = false; } diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 89c8c7902b..e49b33da9c 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -270,7 +270,46 @@ bool AnalysisPredictor::CreateExecutor() { executor_.reset(new paddle::framework::NaiveExecutor(place_)); return true; } + +static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) { + // here is prepare data optimization related bad cases: + // let's assume an op behind conditional_block and if conditional_block + // chooses branch 1, the op need to call prepare data. else the op don't need + // to call prepare data. In running, if predictor chooses branch 2, then + // optimization takes effect, later issue is followed if predictor chooses + // branch 1, because the op lost chance to prepare data. + std::vector op_type = {"conditional_block_infer", + "select_input"}; + for (const auto &type : op_type) { + if (op->Type() == type) { + return true; + } + } + return false; +} + +static void DisablePrepareDataOpt( + std::shared_ptr inference_program, int block, + bool pre_disable_opt) { + bool disable_opt = false; + auto &infer_block = inference_program->Block(block); + for (auto *op : infer_block.AllOps()) { + if (disable_opt || pre_disable_opt) { + op->SetAttr("inference_force_prepare_data", true); + } + if (op->HasAttr("sub_block")) { + int blockID = op->GetBlockAttrId("sub_block"); + DisablePrepareDataOpt(inference_program, blockID, + disable_opt || pre_disable_opt); + } + // disable prepare data if unfriendly op is found + disable_opt = IsPrepareDataOptTargetOp(op); + } +} + bool AnalysisPredictor::PrepareExecutor() { + DisablePrepareDataOpt(inference_program_, 0, false); + executor_->Prepare(sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py index 7f613c4765..adbb89523a 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py @@ -195,7 +195,7 @@ class DynamicShapeTensorRTSubgraphPassConvTest(InferencePassTest): }, { "conv2d_0.tmp_0": [16, 6, 16, 16], "data": [16, 6, 16, 16], - "depthwise_conv2d_0.tmp_0": [32, 6, 64, 64] + "depthwise_conv2d_0.tmp_0": [16, 6, 16, 16] }, False) self.fetch_list = [conv_out] -- GitLab