Unverified commit 8e7c3789, authored by zhupengyang, committed by GitHub

cache scope in while (#52628)

Parent cea62c00
@@ -22,6 +22,13 @@
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
+
+PADDLE_DEFINE_EXPORTED_bool(
+    cache_inference_while_scope,
+    false,
+    "Cache the scope of the while op to avoid repeated creation of the scope "
+    "for each iteration and improve inference performance.");
+
 namespace paddle {
 namespace framework {
 class InferShapeContext;
@@ -257,14 +264,23 @@ class WhileOp : public framework::OperatorBase {
           scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
     }
   } else {
-    auto &current_scope = scope.NewScope();
-
-    BuildScopeForControlFlowOp(*core_, *block, &current_scope);
-    core_->reset_scope(&current_scope);
+    framework::Scope *current_scope = nullptr;
+    if (!FLAGS_cache_inference_while_scope) {
+      current_scope = &(scope.NewScope());
+      BuildScopeForControlFlowOp(*core_, *block, current_scope);
+      core_->reset_scope(current_scope);
+    } else {
+      if (cached_inference_scope_ == nullptr) {
+        cached_inference_scope_ = &(scope.NewScope());
+        BuildScopeForControlFlowOp(*core_, *block, cached_inference_scope_);
+        core_->reset_scope(cached_inference_scope_);
+      }
+      current_scope = cached_inference_scope_;
+    }

     while (cond_data) {
-      for (auto &name : current_scope.LocalVarNames()) {
-        auto *var = current_scope.Var(name);
+      for (auto &name : current_scope->LocalVarNames()) {
+        auto *var = current_scope->Var(name);
         if (var->IsType<phi::DenseTensor>()) {
           // Clear all lod information for all lod_tensors.
           auto *t = var->GetMutable<phi::DenseTensor>();
@@ -283,7 +299,9 @@ class WhileOp : public framework::OperatorBase {
           scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
     }

-    scope.DeleteScope(&current_scope);
+    if (!FLAGS_cache_inference_while_scope) {
+      scope.DeleteScope(current_scope);
+    }
   }
 }
@@ -291,6 +309,7 @@
   mutable std::shared_ptr<framework::Executor> executor_{nullptr};
   mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
   mutable std::shared_ptr<framework::InterpreterCore> core_{nullptr};
+  mutable framework::Scope *cached_inference_scope_{nullptr};
 };

 class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
......
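
The diff above makes WhileOp keep one inner scope (cached_inference_scope_) alive across executions when FLAGS_cache_inference_while_scope is enabled, instead of calling scope.NewScope() and scope.DeleteScope() on every run. The sketch below illustrates only that control-flow pattern; Scope, WhileOpSketch, and the plain bool flag are simplified stand-ins for illustration, not Paddle's real classes or API.

#include <algorithm>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for the exported flag; in Paddle it is defined via
// PADDLE_DEFINE_EXPORTED_bool and defaults to false.
static bool FLAGS_cache_inference_while_scope = false;

// Hypothetical, minimal stand-in for framework::Scope: it only tracks child
// scopes so that NewScope/DeleteScope have something to manage.
class Scope {
 public:
  Scope &NewScope() {
    kids_.push_back(std::make_unique<Scope>());
    return *kids_.back();
  }
  void DeleteScope(Scope *s) {
    kids_.erase(std::remove_if(kids_.begin(), kids_.end(),
                               [s](const std::unique_ptr<Scope> &k) {
                                 return k.get() == s;
                               }),
                kids_.end());
  }

 private:
  std::vector<std::unique_ptr<Scope>> kids_;
};

// Sketch of the run path of a while-style op after this commit.
class WhileOpSketch {
 public:
  void Run(Scope &scope) {
    Scope *current_scope = nullptr;
    if (!FLAGS_cache_inference_while_scope) {
      // Default path: build a fresh inner scope for this run.
      current_scope = &scope.NewScope();
      BuildScope(current_scope);
    } else {
      // Cached path: build the inner scope once, reuse it on later runs.
      if (cached_scope_ == nullptr) {
        cached_scope_ = &scope.NewScope();
        BuildScope(cached_scope_);
      }
      current_scope = cached_scope_;
    }

    // ... the loop body would execute against *current_scope here ...
    std::cout << "ran with "
              << (current_scope == cached_scope_ ? "cached" : "fresh")
              << " scope\n";

    // Only the per-run scope is torn down; the cached scope stays alive
    // for the next execution.
    if (!FLAGS_cache_inference_while_scope) {
      scope.DeleteScope(current_scope);
    }
  }

 private:
  void BuildScope(Scope *) { std::cout << "building inner scope\n"; }
  Scope *cached_scope_ = nullptr;  // mirrors cached_inference_scope_
};

int main() {
  Scope root;
  WhileOpSketch op;

  FLAGS_cache_inference_while_scope = true;
  op.Run(root);  // first run builds the inner scope
  op.Run(root);  // later runs reuse it, no rebuild or teardown
  return 0;
}

With the flag left at its default (false), behavior is unchanged: a fresh scope is built and deleted on every run. In Paddle itself the flag is exported, so it can presumably be toggled through the usual flag mechanisms (for example an environment variable of the same name, FLAGS_cache_inference_while_scope); confirm the supported method for your Paddle version.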