From fcb9c0b55189bf2b285803fbb01a8256d3f3ad66 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 19 Oct 2022 11:02:20 +0800 Subject: [PATCH] [ cherrypick] Construct exec and ctx only once in cond op to speed up (#47012) Construct exec and ctx only once in cond op to speed up --- .../controlflow/conditional_block_infer_op.cc | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc index 2ddcc7eb72..cb52baa7e6 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc @@ -14,6 +14,11 @@ limitations under the License. */ #include "paddle/fluid/operators/controlflow/conditional_block_op.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +DECLARE_bool(use_mkldnn); namespace paddle { namespace framework { class OpDesc; @@ -73,14 +78,33 @@ class ConditionalBlockInferOp : public ConditionalOp { scopes->front() = &scope.NewScope(); auto &cur_scope = *scopes->front(); - framework::Executor exec(dev_place); auto *block = Attr("sub_block"); VLOG(3) << "Conditional block.idx = " << block->ID() << ", scope = " << &cur_scope; - exec.Run(*block->Program(), &cur_scope, block->ID(), false); + + if (!exec_ || !platform::is_same_place(exec_->GetPlace(), dev_place)) { + auto &pdesc = *block->Program(); + exec_.reset(new framework::Executor(dev_place)); +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) exec_->EnableMKLDNN(pdesc); +#endif + ctx_ = exec_->Prepare( + pdesc, block->ID(), std::vector(), false); +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) { + platform::AttachPointerHashToMKLDNNKey(exec_.get(), dev_place); + platform::RegisterModelLayout(ctx_->ops_, dev_place); + } +#endif + } + exec_->RunPreparedContext(ctx_.get(), &cur_scope, false, true, false); scope.DeleteScope(scopes->front()); } } + + private: + mutable std::shared_ptr exec_{nullptr}; + mutable std::unique_ptr ctx_{nullptr}; }; } // namespace operators -- GitLab