From 2814d7f678642d51eb8bbc42102122d82c0fe6f8 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 19 Oct 2022 10:55:06 +0800 Subject: [PATCH] Construct exec and ctx only once in cond op to speed up (#47092) * cond infer apply exec seprate * fix bugs * fix as comment --- .../controlflow/conditional_block_infer_op.cc | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc index e83b0007f0..44996e7e4b 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc @@ -82,25 +82,29 @@ class ConditionalBlockInferOp : public ConditionalOp { VLOG(3) << "Conditional block.idx = " << block->ID() << ", scope = " << &cur_scope; - if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) { + if (!exec_ || !platform::is_same_place(exec_->GetPlace(), dev_place)) { auto &pdesc = *block->Program(); - exec.reset(new framework::Executor(dev_place)); - if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc); - ctx = exec->Prepare( + exec_.reset(new framework::Executor(dev_place)); +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) exec_->EnableMKLDNN(pdesc); +#endif + ctx_ = exec_->Prepare( pdesc, block->ID(), std::vector(), false); #ifdef PADDLE_WITH_MKLDNN - platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place); - platform::RegisterModelLayout(ctx->ops_, dev_place); + if (FLAGS_use_mkldnn) { + platform::AttachPointerHashToMKLDNNKey(exec_.get(), dev_place); + platform::RegisterModelLayout(ctx_->ops_, dev_place); + } #endif } - exec->RunPreparedContext(ctx.get(), &cur_scope, false, true, false); + exec_->RunPreparedContext(ctx_.get(), &cur_scope, false, true, false); scope.DeleteScope(scopes->front()); } } private: - mutable std::shared_ptr exec{nullptr}; - mutable std::unique_ptr ctx{nullptr}; + mutable std::shared_ptr exec_{nullptr}; + mutable std::unique_ptr ctx_{nullptr}; }; } // namespace operators -- GitLab