diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 25d430df4582550946a3bae95f89aabbe979ef38..20cffaa9590196c5c54ae4f4448f04185ad0c276 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1531,7 +1531,12 @@ Scope* OperatorWithKernel::PrepareData(
   // the rest iterations to save the elapsed time.
   // We do not support skipping PrepareData in while block, because the Op's
   // input may be changed by subsequent Ops, which may cause an error.
-  if (pre_scope_ == &scope && new_scope == nullptr) {
+
+  // For inference, ops that sit behind a conditional branch are not
+  // supported well by this optimization, so disable it conservatively.
+  bool force_prepare_data = HasAttr("inference_force_prepare_data") &&
+                            Attr<bool>("inference_force_prepare_data");
+  if (pre_scope_ == &scope && new_scope == nullptr && !force_prepare_data) {
     need_prepare_data_ = false;
   }

diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 2733d21b6cba3af0a0d13ca95ff94a6898a5fa2e..dc075b9f79a92fe086bd2b653a0c597377340c0a 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -270,7 +270,49 @@ bool AnalysisPredictor::CreateExecutor() {
   executor_.reset(new paddle::framework::NaiveExecutor(place_));
   return true;
 }
+
+static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
+  // A bad case for the prepare-data optimization: assume an op sits behind
+  // a conditional_block. If the conditional_block takes branch 1, the op
+  // must call PrepareData; if it takes branch 2, it does not. If the
+  // predictor takes branch 2 first, the optimization takes effect, and a
+  // later run that takes branch 1 then fails, because the op has lost its
+  // chance to prepare data.
+  std::vector<std::string> op_type = {"conditional_block_infer",
+                                      "select_input"};
+  for (const auto &type : op_type) {
+    if (op->Type() == type) {
+      return true;
+    }
+  }
+  return false;
+}
+
+static void DisablePrepareDataOpt(
+    std::shared_ptr<framework::ProgramDesc> inference_program, int block,
+    bool pre_disable_opt) {
+  bool disable_opt = false;
+  auto &infer_block = inference_program->Block(block);
+  for (auto *op : infer_block.AllOps()) {
+    if (disable_opt || pre_disable_opt) {
+      op->SetAttr("inference_force_prepare_data", true);
+    }
+    if (op->HasAttr("sub_block")) {
+      int blockID = op->GetBlockAttrId("sub_block");
+      DisablePrepareDataOpt(inference_program, blockID,
+                            disable_opt || pre_disable_opt);
+    }
+    // once an unfriendly op is found, keep the optimization disabled for
+    // every op that follows it
+    if (!disable_opt) {
+      disable_opt = IsPrepareDataOptTargetOp(op);
+    }
+  }
+}
+
 bool AnalysisPredictor::PrepareExecutor() {
+  DisablePrepareDataOpt(inference_program_, 0, false);
+
   executor_->Prepare(sub_scope_, *inference_program_, 0,
                      config_.use_feed_fetch_ops_);
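Note on the hunk above: `DisablePrepareDataOpt` walks the program recursively and marks every op that executes after a `conditional_block_infer` or `select_input`, including ops inside sub-blocks entered from that point on, so those ops always re-run `PrepareData`. Below is a minimal standalone sketch of the same traversal, using hypothetical simplified `Op`/`Block` stand-ins rather than Paddle's `framework::OpDesc`/`BlockDesc`:

```cpp
// Hypothetical standalone sketch (not Paddle code) of the marking pass:
// once an unfriendly op is seen, every later op in the block -- and every
// op in sub-blocks reached from then on -- gets the force flag.
#include <iostream>
#include <string>
#include <vector>

struct Op {
  std::string type;
  int sub_block = -1;  // index of a nested block, -1 if none
  bool force_prepare = false;
};

struct Block {
  std::vector<Op> ops;
};

void MarkForcePrepare(std::vector<Block> &prog, int block_id,
                      bool pre_disable) {
  bool disable = false;
  for (auto &op : prog[block_id].ops) {
    if (disable || pre_disable) op.force_prepare = true;
    if (op.sub_block >= 0)
      MarkForcePrepare(prog, op.sub_block, disable || pre_disable);
    if (!disable)  // stays true once an unfriendly op has been seen
      disable = (op.type == "conditional_block_infer" ||
                 op.type == "select_input");
  }
}

int main() {
  // block 0: feed -> conditional_block_infer(sub-block 1) -> relu -> fetch
  std::vector<Block> prog(2);
  prog[0].ops = {{"feed"}, {"conditional_block_infer", 1}, {"relu"}, {"fetch"}};
  prog[1].ops = {{"elementwise_add"}};  // the conditional's own body
  MarkForcePrepare(prog, 0, false);
  for (const auto &op : prog[0].ops)
    std::cout << op.type << " force_prepare=" << op.force_prepare << "\n";
  // relu and fetch get flagged; the conditional's own sub-block does not,
  // because disable is only raised after that op's recursion has run.
}
```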
diff --git a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc
index 62019be26cdef8214fe0e7c3e063c9387a30c91a..6705d42bcd74086e327d54fa44b9daf03efcba40 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc
@@ -73,6 +73,8 @@ class ConditionalBlockInferOp : public ConditionalOp {
     framework::Executor exec(dev_place);
     auto *block = Attr<framework::BlockDesc *>("sub_block");
+    VLOG(3) << "Conditional block.idx = " << block->ID()
+            << ", scope = " << &cur_scope;
     exec.Run(*block->Program(), &cur_scope, block->ID(), false);
     scope.DeleteScope(scopes->front());
   }

diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py
index 269183f1441d22514e309fb458b3c1801a4c7b0c..ebbf724d0b4eadb3b1a2b81d71e7126b2ecd3f4d 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py
@@ -219,7 +219,7 @@ class DynamicShapeTensorRTSubgraphPassConvTest(InferencePassTest):
             }, {
                 "conv2d_0.tmp_0": [16, 6, 16, 16],
                 "data": [16, 6, 16, 16],
-                "depthwise_conv2d_0.tmp_0": [32, 6, 64, 64]
+                "depthwise_conv2d_0.tmp_0": [16, 6, 16, 16]
             }, False)
         self.fetch_list = [conv_out]
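On the final hunk: in TensorRT dynamic-shape mode each tensor gets a (min, max, opt) shape profile, and entries within one map should be mutually consistent; the fix brings the `depthwise_conv2d_0.tmp_0` entry in line with the 16x6x16x16 shapes of the other tensors in the same map. For reference, a hedged sketch of setting such a profile through the C++ inference API of this era (AnalysisConfig method names and signatures assumed; verify against your Paddle release):

```cpp
// Hedged sketch: configuring a TensorRT dynamic-shape profile with
// paddle::AnalysisConfig. "model_dir" and the tensor name "data" are
// placeholders; SetTRTDynamicShapeInfo takes (min, max, opt) shape maps.
#include <map>
#include <string>
#include <vector>

#include "paddle_inference_api.h"  // install path varies by release

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("model_dir");  // placeholder model directory
  config.EnableUseGpu(100 /*MB*/, 0 /*device id*/);
  config.EnableTensorRtEngine(1 << 30 /*workspace*/, 16 /*max batch*/,
                              3 /*min subgraph size*/,
                              paddle::AnalysisConfig::Precision::kFloat32,
                              false /*use_static*/, false /*use_calib_mode*/);

  // The maps must agree: min <= opt <= max elementwise for every tensor,
  // and each runtime shape must fall inside [min, max].
  std::map<std::string, std::vector<int>> min_shape = {
      {"data", {1, 6, 8, 8}}};
  std::map<std::string, std::vector<int>> max_shape = {
      {"data", {16, 6, 16, 16}}};
  std::map<std::string, std::vector<int>> opt_shape = {
      {"data", {16, 6, 16, 16}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);

  auto predictor = paddle::CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}
```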