Unverified · Commit ba653e7b · authored by: WangZhen · committed by: GitHub

Construct exec and ctx only once in cond op to speed up (#45794)

* Construct exec and ctx only once in cond op to speed up

* Fix construct function error
Parent commit: 4bbbed9a
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
DECLARE_bool(use_mkldnn);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -30,6 +32,9 @@ const char ConditionalOp::kCondition[] = "Cond"; ...@@ -30,6 +32,9 @@ const char ConditionalOp::kCondition[] = "Cond";
const char ConditionalOp::kScope[] = "Scope"; const char ConditionalOp::kScope[] = "Scope";
const char ConditionalOp::kSkipEagerDeletionVars[] = "skip_eager_deletion_vars"; const char ConditionalOp::kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
using Executor = framework::Executor;
using ExecutorPrepareContext = framework::ExecutorPrepareContext;
class ConditionalBlockOp : public ConditionalOp { class ConditionalBlockOp : public ConditionalOp {
public: public:
ConditionalBlockOp(const std::string &type, ConditionalBlockOp(const std::string &type,
...@@ -76,22 +81,28 @@ class ConditionalBlockOp : public ConditionalOp { ...@@ -76,22 +81,28 @@ class ConditionalBlockOp : public ConditionalOp {
// Executors (executors declared inside control ops) // Executors (executors declared inside control ops)
platform::DontClearMKLDNNCache(dev_place); platform::DontClearMKLDNNCache(dev_place);
#endif #endif
framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block"); auto *block = Attr<framework::BlockDesc *>("sub_block");
VLOG(3) << "Conditional block.idx = " << block->ID() VLOG(3) << "Conditional block.idx = " << block->ID()
<< ", scope = " << &cur_scope; << ", scope = " << &cur_scope;
auto &skip_vars = auto &skip_vars =
Attr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars); Attr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars);
exec.Run(*block->Program(), if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) {
&cur_scope, auto &pdesc = *block->Program();
block->ID(), exec.reset(new Executor(dev_place));
false, if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc);
true, ctx = exec->Prepare(pdesc, block->ID(), skip_vars, false);
skip_vars, #ifdef PADDLE_WITH_MKLDNN
/* force_disable_gc */ false, platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place);
/* keep_kid_scopes */ true); platform::RegisterModelLayout(ctx->ops_, dev_place);
#endif
}
exec->RunPreparedContext(ctx.get(), &cur_scope, false, true, true);
} }
} }
private:
mutable std::shared_ptr<Executor> exec{nullptr};
mutable std::unique_ptr<ExecutorPrepareContext> ctx{nullptr};
}; };
class ConditionalBlockInferShape : public framework::InferShapeBase { class ConditionalBlockInferShape : public framework::InferShapeBase {
...@@ -152,19 +163,21 @@ class ConditionalBlockGradOp : public ConditionalOp { ...@@ -152,19 +163,21 @@ class ConditionalBlockGradOp : public ConditionalOp {
scopes.size())); scopes.size()));
framework::Scope &cur_scope = *scopes[0]; framework::Scope &cur_scope = *scopes[0];
framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block"); auto *block = Attr<framework::BlockDesc *>("sub_block");
VLOG(3) << "Conditional Grad block.idx = " << block->ID() VLOG(3) << "Conditional Grad block.idx = " << block->ID()
<< ", scope = " << &cur_scope; << ", scope = " << &cur_scope;
exec.Run(*block->Program(), if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) {
&cur_scope, auto &pdesc = *block->Program();
block->ID(), exec.reset(new Executor(dev_place));
false, if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc);
true, ctx = exec->Prepare(pdesc, block->ID(), inside_grads, false);
inside_grads, #ifdef PADDLE_WITH_MKLDNN
/* force_disable_gc */ false, platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place);
/* keep_kid_scopes */ false); platform::RegisterModelLayout(ctx->ops_, dev_place);
#endif
}
exec->RunPreparedContext(ctx.get(), &cur_scope, false, true, false);
AssignLocalGradientToParentScope( AssignLocalGradientToParentScope(
dev_place, cur_scope, scope, inside_grads, outside_grads, inputs); dev_place, cur_scope, scope, inside_grads, outside_grads, inputs);
...@@ -174,6 +187,10 @@ class ConditionalBlockGradOp : public ConditionalOp { ...@@ -174,6 +187,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
AssignZeroToParentScope(dev_place, scope, inputs, outside_grads); AssignZeroToParentScope(dev_place, scope, inputs, outside_grads);
} }
private:
mutable std::shared_ptr<Executor> exec{nullptr};
mutable std::unique_ptr<ExecutorPrepareContext> ctx{nullptr};
private: private:
void AssignLocalGradientToParentScope( void AssignLocalGradientToParentScope(
const platform::Place &place, const platform::Place &place,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.