diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index 4aa400afbf810f5c6c5861a95c01c91b0ff61ab6..bdc07efbc0f8c45532ea67dfaaa3f812679018bf 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/fluid/platform/mkldnn_helper.h" #endif +DECLARE_bool(use_mkldnn); + namespace paddle { namespace operators { @@ -30,6 +32,9 @@ const char ConditionalOp::kCondition[] = "Cond"; const char ConditionalOp::kScope[] = "Scope"; const char ConditionalOp::kSkipEagerDeletionVars[] = "skip_eager_deletion_vars"; +using Executor = framework::Executor; +using ExecutorPrepareContext = framework::ExecutorPrepareContext; + class ConditionalBlockOp : public ConditionalOp { public: ConditionalBlockOp(const std::string &type, @@ -76,22 +81,28 @@ class ConditionalBlockOp : public ConditionalOp { // Executors (executors declared inside control ops) platform::DontClearMKLDNNCache(dev_place); #endif - framework::Executor exec(dev_place); auto *block = Attr<framework::BlockDesc *>("sub_block"); VLOG(3) << "Conditional block.idx = " << block->ID() << ", scope = " << &cur_scope; auto &skip_vars = Attr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars); - exec.Run(*block->Program(), - &cur_scope, - block->ID(), - false, - true, - skip_vars, - /* force_disable_gc */ false, - /* keep_kid_scopes */ true); + if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) { + auto &pdesc = *block->Program(); + exec.reset(new Executor(dev_place)); + if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc); + ctx = exec->Prepare(pdesc, block->ID(), skip_vars, false); +#ifdef PADDLE_WITH_MKLDNN + platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place); + platform::RegisterModelLayout(ctx->ops_, dev_place); +#endif + } + exec->RunPreparedContext(ctx.get(), &cur_scope, false, true, true); } } + + private: + mutable 
std::shared_ptr<Executor> exec{nullptr}; + mutable std::unique_ptr<ExecutorPrepareContext> ctx{nullptr}; }; class ConditionalBlockInferShape : public framework::InferShapeBase { @@ -152,19 +163,21 @@ class ConditionalBlockGradOp : public ConditionalOp { scopes.size())); framework::Scope &cur_scope = *scopes[0]; - framework::Executor exec(dev_place); auto *block = Attr<framework::BlockDesc *>("sub_block"); VLOG(3) << "Conditional Grad block.idx = " << block->ID() << ", scope = " << &cur_scope; - exec.Run(*block->Program(), - &cur_scope, - block->ID(), - false, - true, - inside_grads, - /* force_disable_gc */ false, - /* keep_kid_scopes */ false); + if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) { + auto &pdesc = *block->Program(); + exec.reset(new Executor(dev_place)); + if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc); + ctx = exec->Prepare(pdesc, block->ID(), inside_grads, false); +#ifdef PADDLE_WITH_MKLDNN + platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place); + platform::RegisterModelLayout(ctx->ops_, dev_place); +#endif + } + exec->RunPreparedContext(ctx.get(), &cur_scope, false, true, false); AssignLocalGradientToParentScope( dev_place, cur_scope, scope, inside_grads, outside_grads, inputs); @@ -174,6 +187,10 @@ class ConditionalBlockGradOp : public ConditionalOp { AssignZeroToParentScope(dev_place, scope, inputs, outside_grads); } + private: + mutable std::shared_ptr<Executor> exec{nullptr}; + mutable std::unique_ptr<ExecutorPrepareContext> ctx{nullptr}; + private: void AssignLocalGradientToParentScope( const platform::Place &place,