Unverified · Commit 6dbfbfa5 authored by kangguangli, committed by GitHub

[Control Flow] replace executor in while op with InterpreterCore (#47573)

* fix:add no support for cuda_arch<700

* replace Executor in while op with InterpreterCore

* cache InterpreterCore as the member of WhileOp

* fix bug: tensor place changed because of assign op in while loop

* refine code

* refine code

* refine code

* hot fix

* fix compile

* merge develop

* follow comments

* add log for test

* remove LoDTensor

* set flag control_flow_use_new_executor false
Co-authored-by: fengshuai <fengshuai03@baidu.com>
Co-authored-by: zhiqiu <chenqiuliang@baidu.com>
Parent: 9d4b4be3
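The diffs below apply one recurring pattern: each control-flow op keeps its Executor/InterpreterCore as a mutable member and rebuilds it only when the target place changes; otherwise the cached instance is re-pointed at the current scope and reused. A minimal standalone sketch of that caching idea (illustrative types only, not Paddle APIs):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>

// Illustrative stand-ins for phi::Place and framework::InterpreterCore.
struct Place {
  std::string name;
  bool operator==(const Place& other) const { return name == other.name; }
};

struct Core {
  explicit Core(Place p) : place(std::move(p)) {
    std::cout << "building core for " << place.name << "\n";
  }
  void Run() { std::cout << "run on " << place.name << "\n"; }
  Place place;
};

class CachedRunner {
 public:
  void RunOnce(const Place& dev_place) {
    // Rebuild only when there is no core yet or the place changed, mirroring
    // `if (!core_ || !is_same_place(core_->GetPlace(), dev_place))` in the diff.
    if (!core_ || !(core_->place == dev_place)) {
      core_ = std::make_unique<Core>(dev_place);
    }
    core_->Run();
  }

 private:
  std::unique_ptr<Core> core_;
};

int main() {
  CachedRunner runner;
  Place cpu{"cpu"}, gpu{"gpu:0"};
  runner.RunOnce(cpu);  // builds
  runner.RunOnce(cpu);  // reuses the cached core
  runner.RunOnce(gpu);  // place changed -> rebuilds
}
```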
@@ -638,9 +638,20 @@ void InterpreterCore::Convert(
       if (var_desc && ins.count(item.first) &&
           !info.IsInArgBufferNeeded(var_desc->Name())) {
         continue;
-      } else if (!block_.HasVar(var_scope_.GetNameById(id))) {
-        VLOG(10) << "[gc_check_inputs] skip gc: "
-                 << var_scope_.GetNameById(id);
+      }
+      // skip when this var is not in block and not a data_transferred var,
+      // which means this var is managed by other block
+      const auto& var_name = var_scope_.GetNameById(id);
+      bool not_owned = !block_.HasVar(var_name);
+      const auto& transferred_vars = var_scope_.DataTransferAddedVars();
+      bool not_transferred =
+          std::all_of(transferred_vars.begin(),
+                      transferred_vars.end(),
+                      [&](const std::pair<std::string, int>& elem) {
+                        return elem.first != var_name;
+                      });
+      if (not_owned && not_transferred) {
+        VLOG(10) << "[gc_check_inputs] skip gc: " << var_name;
         continue;
       }
       gc_check_vars.insert(id);
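The new skip condition above only bypasses garbage collection when the variable is neither declared in the current block (`not_owned`) nor created by data transfer (`not_transferred`). A standalone illustration of the `std::all_of` membership test it relies on (plain C++ with made-up data, not Paddle code):

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Stand-in for VariableScope::DataTransferAddedVars(): (name, var type) pairs.
  std::vector<std::pair<std::string, int>> transferred_vars = {
      {"x_cast", 5}, {"y_transfer", 5}};

  const std::string var_name = "z";
  // True when no transferred entry matches var_name, i.e. the variable was
  // not added by data transfer.
  bool not_transferred =
      std::all_of(transferred_vars.begin(),
                  transferred_vars.end(),
                  [&](const std::pair<std::string, int>& elem) {
                    return elem.first != var_name;
                  });
  std::cout << std::boolalpha << not_transferred << "\n";  // prints: true
}
```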
@@ -759,7 +770,7 @@ void InterpreterCore::RunOperator(const Instruction& instr_node) {
   auto place = instr_node.DeviceContext().GetPlace();
   Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
                                        : var_scope_.GetMutableScope();
-  VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);
+  VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
   SetDeviceId(place);
@@ -873,7 +884,7 @@ void InterpreterCore::RunOperator(const Instruction& instr_node) {
       VLOG(4) << "Check nan/inf";
       framework::details::CheckOpHasNanOrInf(
           *op,
-          *local_scope_,
+          *local_scope,
           place);  // TODO(xiongkun03) change it to inner scope.
     }
   }
...
@@ -3,7 +3,8 @@ if(WITH_UNITY_BUILD)
   # Load Unity Build rules for operators in paddle/fluid/operators/controlflow.
   include(unity_build_rule.cmake)
 endif()
-register_operators(EXCLUDES conditional_block_op DEPS naive_executor)
+register_operators(EXCLUDES conditional_block_op DEPS naive_executor
+                   standalone_executor)
 cc_library(
   conditional_block_op
...
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/new_executor/standalone_executor.h"
 #include "paddle/fluid/operators/assign_op.h"
+#include "paddle/fluid/operators/controlflow/control_flow_op_helper.h"
 #include "paddle/fluid/platform/flags.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -39,43 +40,6 @@ using ExecutorPrepareContext = framework::ExecutorPrepareContext;
 using InterpreterCore = framework::InterpreterCore;
-namespace details {
-static void BuildScopeForConditionalBlockOp(
-    const paddle::framework::InterpreterCore &interpreter_core,
-    const paddle::framework::BlockDesc &block,
-    paddle::framework::Scope *scope) {
-  for (auto &var_desc : block.AllVars()) {
-    auto var_name = var_desc->Name();
-    if (var_name == framework::kEmptyVarName) {
-      continue;
-    }
-    VLOG(5) << "[BuildScopeForConditionalBlockOp]"
-            << "start:" << var_name;
-    if (var_desc->Persistable()) {
-      VLOG(5) << "[BuildScopeForConditionalBlockOp]"
-              << "Don't process persistent: " << var_name;
-    } else {
-      auto *ptr = scope->Var(var_name);
-      InitializeVariable(ptr, var_desc->GetType());
-      VLOG(5) << "[BuildScopeForConditionalBlockOp]"
-              << "Not Found locally and created: " << var_name;
-    }
-  }
-  auto &data_transfer_added_vars =
-      interpreter_core.GetVariableScope()->DataTransferAddedVars();
-  for (size_t i = 0; i < data_transfer_added_vars.size(); i++) {
-    auto *ptr = scope->Var(data_transfer_added_vars[i].first);
-    InitializeVariable(ptr,
-                       static_cast<paddle::framework::proto::VarType::Type>(
-                           data_transfer_added_vars[i].second));
-    VLOG(10) << "[BuildScopeForConditionalBlockOp]"
-             << "Initialize Transfer Added Variable "
-             << data_transfer_added_vars[i].first;
-  }
-}
-}  // namespace details
 class ConditionalBlockOp : public ConditionalOp {
  public:
  ConditionalBlockOp(const std::string &type,
@@ -141,39 +105,41 @@ class ConditionalBlockOp : public ConditionalOp {
         Attr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars);
     if (FLAGS_control_flow_use_new_executor) {
-      std::set<std::string> skip_gc_vars(skip_vars.begin(), skip_vars.end());
-      if (!core || !platform::is_same_place(core->GetPlace(), dev_place)) {
-        VLOG(10) << "[interpreterCore cache]" << core.get();
-        VLOG_IF(10, core)
-            << platform::is_same_place(core->GetPlace(), dev_place);
-        core.reset(new InterpreterCore(dev_place,
+      LOG_FIRST_N(INFO, 1)
+          << "[ControlFlow][ConditionalBlock] New Executor is Running.";
+      if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
+        std::set<std::string> skip_gc_vars(skip_vars.begin(),
+                                           skip_vars.end());
+        VLOG(10) << "[interpreterCore cache]" << core_.get();
+        VLOG_IF(10, core_)
+            << platform::is_same_place(core_->GetPlace(), dev_place);
+        core_.reset(new InterpreterCore(dev_place,
                                         *block,
                                         skip_gc_vars,
                                         &cur_scope,
                                         /* used_for_jit */ false,
                                         /* used_for_control_flow_op */ true));
         VLOG(10) << "[interpreterCore cache]"
-                 << "new created:" << core;
+                 << "new created:" << core_;
       } else {
-        details::BuildScopeForConditionalBlockOp(*core, *block, &cur_scope);
-        core->reset_scope(&cur_scope);
+        BuildScopeForControlFlowOp(*core_, *block, &cur_scope);
+        core_->reset_scope(&cur_scope);
       }
-      core->Run({}, false);
+      core_->Run({}, false);
     } else {
-      if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) {
+      if (!exec_ || !platform::is_same_place(exec_->GetPlace(), dev_place)) {
         auto &pdesc = *block->Program();
-        exec.reset(new Executor(dev_place));
-        if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc);
-        ctx = exec->Prepare(pdesc, block->ID(), skip_vars, false);
+        exec_.reset(new Executor(dev_place));
+        if (FLAGS_use_mkldnn) exec_->EnableMKLDNN(pdesc);
+        ctx_ = exec_->Prepare(pdesc, block->ID(), skip_vars, false);
 #ifdef PADDLE_WITH_MKLDNN
-        platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place);
-        platform::RegisterModelLayout(ctx->ops_, dev_place);
+        platform::AttachPointerHashToMKLDNNKey(exec_.get(), dev_place);
+        platform::RegisterModelLayout(ctx_->ops_, dev_place);
 #endif
       }
-      exec->RunPreparedContext(ctx.get(),
+      exec_->RunPreparedContext(ctx_.get(),
                                &cur_scope,
                                /* create_local_scope */ false,
                                /* create_vars */ true,
@@ -183,9 +149,9 @@ class ConditionalBlockOp : public ConditionalOp {
   }
  private:
-  mutable std::shared_ptr<Executor> exec{nullptr};
-  mutable std::unique_ptr<ExecutorPrepareContext> ctx{nullptr};
-  mutable std::shared_ptr<InterpreterCore> core{nullptr};
+  mutable std::shared_ptr<Executor> exec_{nullptr};
+  mutable std::unique_ptr<ExecutorPrepareContext> ctx_{nullptr};
+  mutable std::shared_ptr<InterpreterCore> core_{nullptr};
 };
 class ConditionalBlockInferShape : public framework::InferShapeBase {
@@ -251,39 +217,40 @@ class ConditionalBlockGradOp : public ConditionalOp {
             << ", scope = " << &cur_scope;
     if (FLAGS_control_flow_use_new_executor) {
-      std::set<std::string> skip_gc_vars(inside_grads.begin(),
-                                         inside_grads.end());
-      if (!core || !platform::is_same_place(core->GetPlace(), dev_place)) {
-        VLOG(10) << "[interpreterCore cache]" << core.get();
-        VLOG_IF(10, core)
-            << platform::is_same_place(core->GetPlace(), dev_place);
-        core.reset(new InterpreterCore(dev_place,
+      LOG_FIRST_N(INFO, 1)
+          << "[ControlFlow][ConditionalGradBlock] New Executor is Running.";
+      if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
+        VLOG(10) << "[interpreterCore cache]" << core_.get();
+        VLOG_IF(10, core_)
+            << platform::is_same_place(core_->GetPlace(), dev_place);
+        std::set<std::string> skip_gc_vars(inside_grads.begin(),
+                                           inside_grads.end());
+        core_.reset(new InterpreterCore(dev_place,
                                         *block,
                                         skip_gc_vars,
                                         &cur_scope,
                                         /* used_for_jit */ false,
                                         /* used_for_control_flow_op */ true));
         VLOG(10) << "[interpreterCore cache]"
-                 << "new created:" << core;
+                 << "new created:" << core_;
       } else {
-        details::BuildScopeForConditionalBlockOp(*core, *block, &cur_scope);
-        core->reset_scope(&cur_scope);
+        BuildScopeForControlFlowOp(*core_, *block, &cur_scope);
+        core_->reset_scope(&cur_scope);
       }
-      core->Run({}, false);
+      core_->Run({}, false);
     } else {
-      if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) {
+      if (!exec_ || !platform::is_same_place(exec_->GetPlace(), dev_place)) {
         auto &pdesc = *block->Program();
-        exec.reset(new Executor(dev_place));
-        if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc);
-        ctx = exec->Prepare(pdesc, block->ID(), inside_grads, false);
+        exec_.reset(new Executor(dev_place));
+        if (FLAGS_use_mkldnn) exec_->EnableMKLDNN(pdesc);
+        ctx_ = exec_->Prepare(pdesc, block->ID(), inside_grads, false);
 #ifdef PADDLE_WITH_MKLDNN
-        platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place);
-        platform::RegisterModelLayout(ctx->ops_, dev_place);
+        platform::AttachPointerHashToMKLDNNKey(exec_.get(), dev_place);
+        platform::RegisterModelLayout(ctx_->ops_, dev_place);
 #endif
       }
-      exec->RunPreparedContext(ctx.get(),
+      exec_->RunPreparedContext(ctx_.get(),
                                &cur_scope,
                                /* create_local_scope */ false,
                                /* create_vars */ true,
@@ -299,9 +266,9 @@ class ConditionalBlockGradOp : public ConditionalOp {
   }
  private:
-  mutable std::shared_ptr<Executor> exec{nullptr};
-  mutable std::unique_ptr<ExecutorPrepareContext> ctx{nullptr};
-  mutable std::shared_ptr<InterpreterCore> core{nullptr};
+  mutable std::shared_ptr<Executor> exec_{nullptr};
+  mutable std::unique_ptr<ExecutorPrepareContext> ctx_{nullptr};
+  mutable std::shared_ptr<InterpreterCore> core_{nullptr};
 private:
  void AssignLocalGradientToParentScope(
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
namespace paddle {
namespace operators {
static void BuildScopeForControlFlowOp(
    const framework::InterpreterCore &interpreter_core,
    const framework::BlockDesc &block,
    framework::Scope *scope) {
  for (auto &var_desc : block.AllVars()) {
    auto var_name = var_desc->Name();
    if (var_name == framework::kEmptyVarName) {
      continue;
    }
    VLOG(5) << "[BuildScopeForControlFlowOp]"
            << "start:" << var_name;
    if (var_desc->Persistable()) {
      VLOG(5) << "[BuildScopeForControlFlowOp]"
              << "Don't process persistent: " << var_name;
    } else {
      auto *ptr = scope->Var(var_name);
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(5) << "[BuildScopeForControlFlowOp]"
              << "Not Found locally and created: " << var_name;
    }
  }

  auto &data_transfer_added_vars =
      interpreter_core.GetVariableScope()->DataTransferAddedVars();
  for (size_t i = 0; i < data_transfer_added_vars.size(); i++) {
    auto *ptr = scope->Var(data_transfer_added_vars[i].first);
    InitializeVariable(ptr,
                       static_cast<paddle::framework::proto::VarType::Type>(
                           data_transfer_added_vars[i].second));
    VLOG(5) << "[BuildScopeForControlFlowOp]"
            << "Initialize Transfer Added Variable "
            << data_transfer_added_vars[i].first;
  }
}
} // namespace operators
} // namespace paddle
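As a rough standalone analogue of what this helper does (the real function works on Paddle's `Scope`, `VarDesc` and `VariableScope` types): create local entries for the block's non-persistable variables, then also create entries for variables that data transfer added during an earlier run.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Illustrative stand-in for a declared variable in a block.
struct VarDesc {
  std::string name;
  bool persistable;
};

int main() {
  std::vector<VarDesc> block_vars = {
      {"w", true}, {"tmp_0", false}, {"out", false}};
  // Stand-in for VariableScope::DataTransferAddedVars(): (name, var type).
  std::vector<std::pair<std::string, int>> data_transfer_added = {
      {"tmp_0_cast", 5}};

  std::map<std::string, int> scope;  // name -> var type (placeholder "scope")
  for (const auto& var : block_vars) {
    if (var.persistable) continue;   // persistables live in the parent scope
    scope.emplace(var.name, 0);      // analogous to scope->Var + InitializeVariable
  }
  for (const auto& added : data_transfer_added) {
    scope.emplace(added.first, added.second);
  }

  for (const auto& kv : scope) std::cout << kv.first << "\n";  // out, tmp_0, tmp_0_cast
}
```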
@@ -13,8 +13,10 @@
 // limitations under the License.
 #include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/new_executor/standalone_executor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/controlflow/control_flow_op_helper.h"
 #include "paddle/fluid/operators/controlflow/while_op_helper.h"
 #ifdef PADDLE_WITH_MKLDNN
@@ -44,6 +46,41 @@ static std::string GetSkipEagerDeletionVarsDebugString(
   }
   return str;
 }
+
+static void TransferVariablePlace(const framework::Scope *scope,
+                                  const std::string &var_name,
+                                  const phi::Place &dst_place,
+                                  const platform::DeviceContext &dev_ctx) {
+  framework::Variable *var = scope->FindVar(var_name);
+  if (var == nullptr) {
+    VLOG(4) << "[TransferVariablePlace]"
+            << "lost in_var: " << var_name;
+    return;
+  }
+  if (var->Type() != framework::proto::VarType::LOD_TENSOR) {
+    VLOG(10) << "[TransferVariablePlace]" << var_name << " type changed:"
+             << framework::TransToPhiDataType(
+                    framework::ToVarType(var->Type()));
+    return;
+  }
+  phi::DenseTensor *t = var->GetMutable<phi::DenseTensor>();
+  if (t->place() == dst_place) {
+    VLOG(10) << "[TransferVariablePlace]"
+             << "no need transfer: " << var_name;
+    return;
+  }
+
+  phi::DenseTensor *new_t = new phi::DenseTensor;
+  framework::TensorCopy(*t, dst_place, new_t);
+  dev_ctx.Wait();
+
+  t->set_meta(new_t->meta());
+  t->ResetHolder(new_t->Holder());
+
+  VLOG(4) << "[TransferVariablePlace]" << var_name
+          << " place: " << new_t->place();
+}
 }  // namespace
 class WhileOp : public framework::OperatorBase {
@@ -77,9 +114,12 @@ class WhileOp : public framework::OperatorBase {
     // Executors (executors declared inside control ops)
     platform::DontClearMKLDNNCache(dev_place);
 #endif
-    framework::Executor executor(dev_place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
+    // get device context from pool
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(dev_place);
     auto *program = block->Program();
     bool is_test = Attr<bool>("is_test");
@@ -134,7 +174,53 @@ class WhileOp : public framework::OperatorBase {
     auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
     VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
-    auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
+    // note(lvyongkang): The assign op in while loop may change the place of
+    // variable. However, InterpreterCore fix the kernel of every ops during its
+    // first run. A cpu tensor may become gpu tensor after first run. This will
+    // lead to segmetation fault when it's used in a cpu kernel. Here we record
+    // the place of every inputs and restore their place after
+    // InterpreterCore.run().
+    std::map<std::string, phi::Place> input_var_original_places;
+    for (const auto &in_name : Inputs(kX)) {
+      framework::Variable *var = scope.FindVar(in_name);
+      if (var == nullptr) {
+        VLOG(4) << "[while op]"
+                << "input not found:" << in_name;
+      }
+
+      if (var->Type() == framework::proto::VarType::LOD_TENSOR) {
+        input_var_original_places[in_name] =
+            (var->Get<phi::DenseTensor>()).place();
+      } else {
+        VLOG(10) << "[while op]"
+                 << "skip backup input " << in_name << " type:"
+                 << framework::TransToPhiDataType(
+                        framework::ToVarType(var->Type()));
+      }
+    }
+
+    if (FLAGS_control_flow_use_new_executor) {
+      LOG_FIRST_N(INFO, 1) << "[ControlFlow][WhileOp] New Executor is Running.";
+      if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
+        std::set<std::string> skip_gc_vars(skip_vars.begin(), skip_vars.end());
+        framework::Scope placeholder;  // Don't care if it's valid, just for
+                                       // initialize InterpreterCore
+        core_.reset(new framework::InterpreterCore(
+            dev_place,
+            *block,
+            skip_gc_vars,
+            &placeholder,
+            /* used_for_jit */ false,
+            /* used_for_control_flow_op */ true));
+      }
+    } else {
+      if (!executor_ ||
+          !platform::is_same_place(executor_->GetPlace(), dev_place)) {
+        executor_.reset(new framework::Executor(dev_place));
+        ctx_ = executor_->Prepare(*program, block->ID(), skip_vars);
+      }
+    }
+
     if (!is_test) {
       while (cond_data) {
         auto &current_scope = scope.NewScope();
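The note above gives the motivation for `TransferVariablePlace`: after the first run InterpreterCore has fixed its kernels, but an assign op inside the loop can move an input tensor to another place, so WhileOp records every input's original place up front and moves it back after each iteration. A standalone sketch of that record-and-restore pattern (illustrative types, not Paddle APIs):

```cpp
#include <iostream>
#include <map>
#include <string>

int main() {
  // Stand-ins: variable name -> current "place" of its tensor.
  std::map<std::string, std::string> tensor_place = {{"x", "cpu"}, {"y", "gpu:0"}};

  // 1. Record the original places before running the loop body.
  std::map<std::string, std::string> original_places = tensor_place;

  // 2. The loop body may move tensors (e.g. an assign op copies x to GPU).
  tensor_place["x"] = "gpu:0";

  // 3. Restore any input whose place changed, analogous to TransferVariablePlace.
  for (const auto& entry : original_places) {
    if (tensor_place[entry.first] != entry.second) {
      std::cout << "restoring " << entry.first << " to " << entry.second << "\n";
      tensor_place[entry.first] = entry.second;
    }
  }
}
```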
@@ -158,8 +244,23 @@ class WhileOp : public framework::OperatorBase {
             }
           }
         }
-        executor.RunPreparedContext(
-            ctx.get(), &current_scope, false, true, true);
+        if (FLAGS_control_flow_use_new_executor) {
+          BuildScopeForControlFlowOp(*core_, *block, &current_scope);
+          core_->reset_scope(&current_scope);
+          core_->Run({}, false);
+
+          // restore inputs place
+          for (const auto &n : input_var_original_places) {
+            const std::string &in_name = n.first;
+            const phi::Place &original_place = n.second;
+            // input vars exist in `scope` not `current_scope`
+            TransferVariablePlace(&scope, in_name, original_place, dev_ctx);
+          }
+        } else {
+          executor_->RunPreparedContext(
+              ctx_.get(), &current_scope, false, true, true);
+        }
         for (auto &var_rename : rename_vars) {
           std::string input_var_name =
@@ -171,7 +272,14 @@ class WhileOp : public framework::OperatorBase {
       }
     } else {
       auto &current_scope = scope.NewScope();
-      executor.CreateVariables(*program, &current_scope, block->ID());
+
+      if (FLAGS_control_flow_use_new_executor) {
+        BuildScopeForControlFlowOp(*core_, *block, &current_scope);
+        core_->reset_scope(&current_scope);
+      } else {
+        executor_->CreateVariables(*program, &current_scope, block->ID());
+      }
+
       while (cond_data) {
         for (auto &name : current_scope.LocalVarNames()) {
           auto *var = current_scope.Var(name);
@@ -186,14 +294,25 @@ class WhileOp : public framework::OperatorBase {
             t->clear();
           }
         }
-        executor.RunPreparedContext(
-            ctx.get(), &current_scope, false, false, false);
+
+        if (FLAGS_control_flow_use_new_executor) {
+          core_->Run({}, false);
+        } else {
+          executor_->RunPreparedContext(
+              ctx_.get(), &current_scope, false, false, false);
+        }
+
         cond_data = GetCondData(
             scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
       }
       scope.DeleteScope(&current_scope);
     }
   }
+
+ private:
+  mutable std::shared_ptr<framework::Executor> executor_{nullptr};
+  mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
+  mutable std::shared_ptr<framework::InterpreterCore> core_{nullptr};
 };
 class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -245,13 +364,12 @@ class WhileGradOp : public framework::OperatorBase {
     // get device context from pool
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
-    framework::Executor executor(dev_place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
     auto *program = block->Program();
     auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
     VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
-    auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
     auto *step_scopes =
         scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
@@ -271,6 +389,29 @@ class WhileGradOp : public framework::OperatorBase {
                           outside_og_names.size(),
                           inside_og_names.size()));
+    if (FLAGS_control_flow_use_new_executor) {
+      LOG_FIRST_N(INFO, 1)
+          << "[ControlFlow][WhileGradOp] New Executor is Running.";
+      if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
+        std::set<std::string> skip_gc_vars(skip_vars.begin(), skip_vars.end());
+        framework::Scope placeholder;  // Don't care if it's valid, just for
+                                       // initialize InterpreterCore
+        core_.reset(new framework::InterpreterCore(
+            dev_place,
+            *block,
+            skip_gc_vars,
+            &placeholder,
+            /* used_for_jit */ false,
+            /* used_for_control_flow_op */ true));
+      }
+    } else {
+      if (!executor_ ||
+          !platform::is_same_place(executor_->GetPlace(), dev_place)) {
+        executor_.reset(new framework::Executor(dev_place));
+        ctx_ = executor_->Prepare(*program, block->ID(), skip_vars);
+      }
+    }
+
     for (auto cur_scope_iter = step_scopes->rbegin();
          cur_scope_iter != step_scopes->rend();
          ++cur_scope_iter) {
@@ -330,8 +471,15 @@ class WhileGradOp : public framework::OperatorBase {
                 "WhileGradOp."));
           }
         }
-        executor.RunPreparedContext(
-            ctx.get(), *cur_scope_iter, false, true, true);
+
+        if (FLAGS_control_flow_use_new_executor) {
+          BuildScopeForControlFlowOp(*core_, *block, *cur_scope_iter);
+          core_->reset_scope(*cur_scope_iter);
+          core_->Run({}, false);
+        } else {
+          executor_->RunPreparedContext(
+              ctx_.get(), *cur_scope_iter, false, true, true);
+        }
       // The Outputs(kXGRAD) contains the names of the gradient of parameters
       // and inputs.
@@ -446,6 +594,11 @@ class WhileGradOp : public framework::OperatorBase {
     }
     step_scopes->clear();
   }
+
+ private:
+  mutable std::shared_ptr<framework::Executor> executor_{nullptr};
+  mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
+  mutable std::shared_ptr<framework::InterpreterCore> core_{nullptr};
 };
 template <typename T>
...