未验证 提交 8e7c3789 编写于 作者: Z zhupengyang 提交者: GitHub

cache scope in while (#52628)

上级 cea62c00
......@@ -22,6 +22,13 @@
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
PADDLE_DEFINE_EXPORTED_bool(
cache_inference_while_scope,
false,
"Cache the scope of the while op to avoid repeated creation of the scope "
"for each iteration and improve inference performance.");
namespace paddle {
namespace framework {
class InferShapeContext;
......@@ -257,14 +264,23 @@ class WhileOp : public framework::OperatorBase {
scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
}
} else {
auto &current_scope = scope.NewScope();
BuildScopeForControlFlowOp(*core_, *block, &current_scope);
core_->reset_scope(&current_scope);
framework::Scope *current_scope = nullptr;
if (!FLAGS_cache_inference_while_scope) {
current_scope = &(scope.NewScope());
BuildScopeForControlFlowOp(*core_, *block, current_scope);
core_->reset_scope(current_scope);
} else {
if (cached_inference_scope_ == nullptr) {
cached_inference_scope_ = &(scope.NewScope());
BuildScopeForControlFlowOp(*core_, *block, cached_inference_scope_);
core_->reset_scope(cached_inference_scope_);
}
current_scope = cached_inference_scope_;
}
while (cond_data) {
for (auto &name : current_scope.LocalVarNames()) {
auto *var = current_scope.Var(name);
for (auto &name : current_scope->LocalVarNames()) {
auto *var = current_scope->Var(name);
if (var->IsType<phi::DenseTensor>()) {
// Clear all lod information for all lod_tensors.
auto *t = var->GetMutable<phi::DenseTensor>();
......@@ -283,7 +299,9 @@ class WhileOp : public framework::OperatorBase {
scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
}
scope.DeleteScope(&current_scope);
if (!FLAGS_cache_inference_while_scope) {
scope.DeleteScope(current_scope);
}
}
}
......@@ -291,6 +309,7 @@ class WhileOp : public framework::OperatorBase {
mutable std::shared_ptr<framework::Executor> executor_{nullptr};
mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
mutable std::shared_ptr<framework::InterpreterCore> core_{nullptr};
mutable framework::Scope *cached_inference_scope_{nullptr};
};
class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册