Unverified · Commit 3b219e5e authored by kangguangli, committed by GitHub

[ControlFlow] replace executor in run method of control flow ops with standalone_executor (#45696)

* replace executor in conditional_block_op.run with standalone_executor

* add block_id as the argument of standalone executor's method run; add print for program

* fix scope bug about conditional block op

* fix bug: unnecessary return of fetch value

* fix typo

* fix: quantization sets variables persistable, and these variables must exist in the global scope

* add InterpreterCore cache for conditional block op, but do not activate it by default

* fix bug: local scope reuse for conditional block op

* reset scope when conditional block op runs

* fix typo

* fix typo and code style

* add build scope for conditional block op

* add skip for transfer_layout kernel

* refine code

* fix reset_scope

* fix reset_scope

* refine code

* refine code

* refine code

1. remove flag use in conditional_block_op
2. pass execution_config to BuildOpFuncList instead of individual parameters

* refine code

* remove the use of FLAGS_control_flow_use_new_executor_cache

* change FLAGS_control_flow_use_new_executor to false
Parent commit: bb6356e8
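In outline: when the new flag is on, conditional_block_op caches an InterpreterCore and reuses it across runs instead of re-preparing an Executor every time. A condensed, illustrative sketch of that path, distilled from the hunks below (the free-standing helper `RunWithNewExecutor` is an editorial device, not code from this PR):

```cpp
#include <memory>
#include <set>
#include "paddle/fluid/framework/new_executor/standalone_executor.h"

// Sketch of ConditionalBlockOp::RunImpl's new-executor branch, assembled
// from the diff below; error handling and MKLDNN details are omitted.
void RunWithNewExecutor(const platform::Place &dev_place,
                        const framework::BlockDesc &block,
                        const std::set<std::string> &skip_gc_vars,
                        framework::Scope *cur_scope,
                        std::shared_ptr<framework::InterpreterCore> &core) {
  if (!core || !platform::is_same_place(core->GetPlace(), dev_place)) {
    // Cache miss: build a fresh InterpreterCore bound to this place/scope.
    core.reset(new framework::InterpreterCore(
        dev_place, block, skip_gc_vars, cur_scope,
        /*used_for_jit=*/false,
        /*used_for_control_flow_op=*/true));
  } else {
    // Cache hit: re-create sub-block variables, then rebind the cached core.
    details::BuildScopeForConditionalBlockOp(*core, block, cur_scope);
    core->reset_scope(cur_scope);
  }
  core->Run({}, /*need_fetch=*/false);  // control flow never fetches
}
```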
@@ -26,6 +26,7 @@ namespace interpreter {
 struct ExecutionConfig {
   bool used_for_jit{false};
   bool create_local_scope{true};
+  bool used_for_control_flow_op{false};
   size_t host_num_threads;
   size_t deivce_num_threads;
......
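For orientation, here is how these fields interact once InterpreterCore wires them up — a minimal sketch assuming the two-argument ExecutionConfig constructor seen in the member initializer later in this diff (namespaces abbreviated as in the surrounding sources):

```cpp
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"

DECLARE_bool(new_executor_use_local_scope);  // declared in interpretercore.h

// Sketch: how a control-flow caller populates ExecutionConfig, and what the
// InterpreterCore constructor later in this diff derives from the two flags.
void ConfigureForControlFlow(const platform::Place &place, size_t op_size) {
  interpreter::ExecutionConfig config(place, op_size);
  config.used_for_jit = false;
  config.used_for_control_flow_op = true;
  // A control-flow op runs directly in the scope handed to it by its parent,
  // so no local scope is created even if the global flag is on:
  config.create_local_scope = !config.used_for_jit &&
                              FLAGS_new_executor_use_local_scope &&
                              !config.used_for_control_flow_op;  // => false
}
```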
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/framework/executor_gc_helper.h"
 #include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h"
+#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
 #include "paddle/fluid/memory/stats.h"
 #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
 #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
@@ -227,7 +228,14 @@ void BuildVariableScope(const framework::BlockDesc& block,
     }
     if (var_desc->Persistable()) {
-      auto* ptr = inner_scope->Var(var_name);
+      // In principle, we should put all trainable parameters in global scope,
+      // which means the root of the scope tree. Some cases like quantization
+      // will look up these parameters in global scope.
+      const Scope* ancestor_scope = inner_scope;
+      while (ancestor_scope->parent()) {
+        ancestor_scope = ancestor_scope->parent();
+      }
+      auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var_name);
       VLOG(3) << "Initialize Variable " << var_name;
       // NOTE(zhiqiu): if var exists in scope and the type is right,
@@ -291,7 +299,7 @@ std::tuple<VariableValueMap, VariableIdMap> BuildVariableMap(
     const VariableNameMap& var_name_map,
     VariableScope* var_scope,
     Scope* local_scope,
-    bool allow_var_not_in_program = false,
+    bool find_var_recursively = false,
     bool allow_var_not_in_scope = false) {
   VariableValueMap name2var;
   VariableIdMap name2id;
@@ -301,8 +309,10 @@
     vars.reserve(item.second.size());
     for (auto& var_name : item.second) {
+      auto* var = local_scope->FindVar(var_name);
       if (!var_scope->HasVar(var_name)) {
-        if (allow_var_not_in_program && local_scope->FindVar(var_name)) {
+        if (find_var_recursively && var) {
           VLOG(3) << "Add " << var_name << " to var_scope";
           var_scope->AddVar(var_name, nullptr);
         } else if (allow_var_not_in_scope) {
@@ -310,7 +320,6 @@
           continue;
         }
       }
-      auto* var = local_scope->FindVar(var_name);
       auto var_id = var_scope->VarId(var_name);
       vars.push_back(var);
       ids.push_back(var_id);
@@ -419,8 +428,8 @@ void BuildOpFuncList(const platform::Place& place,
                      const std::set<std::string>& skip_gc_vars,
                      std::vector<OpFuncNode>* vec_func_list,
                      VariableScope* var_scope,
-                     bool use_local_scope,
-                     bool used_for_jit) {
+                     const ExecutionConfig& execution_config,
+                     bool use_local_scope) {
   Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                        : var_scope->GetMutableScope();
   std::vector<std::unique_ptr<OperatorBase>>
@@ -428,7 +437,7 @@
   // Step 1: create all ops for current block.
   CreateAllOps(block, &ops_unique);
-  if (!used_for_jit) {
+  if (!execution_config.used_for_jit) {
     // If gc is enabled and block size > 1
     const ProgramDesc& main_program = *block.Program();
     operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
@@ -479,14 +488,18 @@
     bool allow_var_not_in_program = ops_with_var_not_in_program.count(op_type);
     bool allow_var_not_in_scope = ops_with_var_not_in_scope.count(op_type);
+    // Ops in a control-flow block may not find their inputs or outputs in the
+    // VariableScope of the sub-block, so we need to search the parent scope.
     framework::VariableNameMap& input_name_map = op->Inputs();
     VariableValueMap ins_map;
     VariableIdMap ins_name2id;
-    std::tie(ins_map, ins_name2id) = BuildVariableMap(input_name_map,
-                                                      var_scope,
-                                                      local_scope,
-                                                      allow_var_not_in_program,
-                                                      allow_var_not_in_scope);
+    std::tie(ins_map, ins_name2id) = BuildVariableMap(
+        input_name_map,
+        var_scope,
+        local_scope,
+        execution_config.used_for_control_flow_op || allow_var_not_in_program,
+        allow_var_not_in_scope);
     framework::VariableNameMap& output_name_map = op->Outputs();
     VariableValueMap outs_map;
@@ -495,7 +508,7 @@
         BuildVariableMap(output_name_map,
                          var_scope,
                          local_scope,
-                         /*allow_var_not_in_program=*/false,
+                         execution_config.used_for_control_flow_op,
                          allow_var_not_in_scope);
     // step 1: build OpFuncNode
......
@@ -24,6 +24,7 @@
 #include "paddle/fluid/framework/executor_gc_helper.h"
 #include "paddle/fluid/framework/garbage_collector.h"
+#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
 #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
 #include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
 #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
@@ -75,8 +76,8 @@ void BuildOpFuncList(const platform::Place& place,
                      const std::set<std::string>& skip_gc_vars,
                      std::vector<OpFuncNode>* vec_func_list,
                      VariableScope* scope,
-                     bool use_local_scope = true,
-                     bool used_for_jit = false);
+                     const ExecutionConfig& execution_config,
+                     bool use_local_scope = true);
 void AddFetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block);
......
@@ -16,6 +16,8 @@
 #include <unordered_set>
+#include "gflags/gflags.h"
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
 #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
@@ -47,6 +49,9 @@ PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope,
                             true,
                             "Use local_scope in new executor(especially used "
                             "in UT), can turn off for better performance");
+PADDLE_DEFINE_EXPORTED_bool(control_flow_use_new_executor,
+                            false,
+                            "Use new executor in control flow op");
 DECLARE_bool(check_nan_inf);
 DECLARE_bool(benchmark);
@@ -107,7 +112,8 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
                                  const BlockDesc& block,
                                  const std::set<std::string>& skip_gc_vars,
                                  framework::Scope* scope,
-                                 bool used_for_jit)
+                                 bool used_for_jit,
+                                 bool used_for_control_flow_op)
     : place_(place),
       block_(block),
       execution_config_(place, block.OpSize()),
@@ -119,8 +125,10 @@
   completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
   execution_config_.used_for_jit = used_for_jit;
-  execution_config_.create_local_scope =
-      !used_for_jit && FLAGS_new_executor_use_local_scope;
+  execution_config_.used_for_control_flow_op = used_for_control_flow_op;
+  execution_config_.create_local_scope = !used_for_jit &&
+                                         FLAGS_new_executor_use_local_scope &&
+                                         !used_for_control_flow_op;
   execution_config_.skip_gc_vars = skip_gc_vars;
   execution_config_.Log(/*log_level=*/8);
@@ -224,7 +232,7 @@ paddle::framework::FetchList InterpreterCore::Run(
 }
 paddle::framework::FetchList InterpreterCore::Run(
-    const std::vector<std::string>& feed_names) {
+    const std::vector<std::string>& feed_names, bool need_fetch) {
   SetDeviceId(place_);
 #ifdef PADDLE_WITH_MKLDNN
@@ -243,12 +251,12 @@
                                     execution_config_.skip_gc_vars,
                                     &op_func_nodes,
                                     &var_scope_,
-                                    HasLocalScope(),
-                                    execution_config_.used_for_jit);
-    is_build_ = true;
+                                    execution_config_,
+                                    HasLocalScope());
     SetFeedVarsInplaceSkip(feed_names);
     // convert vec func_list to graph
     Convert(&op_func_nodes);
+    is_build_ = true;
   } else {
     // For the program that only run once, it is no need to
     // create work_queue, so the async_work_queue_ is created
@@ -281,7 +289,7 @@
     Scope* inner_scope =
         HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
     auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-    if (fetch_var) {
+    if (fetch_var && need_fetch) {
       return std::move(*fetch_var->GetMutable<framework::FetchList>());
     } else {
       return {};
@@ -311,9 +319,18 @@ void InterpreterCore::reset_scope(Scope* new_scope) {
   var_scope_.SetScope(new_scope);
   auto& var_list = var_scope_.MutableVarList();
   for (size_t i = 0; i < var_list.size(); i++) {
-    var_list[i] = new_scope->FindVar(var_scope_.GetNameById(i));
+    const auto& var_name = var_scope_.GetNameById(i);
+    var_list[i] = new_scope->FindVar(var_name);
   }
-  for (size_t i = 0; i < vec_instruction_.size(); ++i) {
+  // The index must be checked for validity, because the InterpreterCore may
+  // not be fully built yet still be cached and reused. For example, the unit
+  // test `test_assert.py` may exit before `InterpreterCore::Convert`, while
+  // the core is still cached and used by later tests.
+  for (size_t i = 0; i < std::min(refs_.size(), var_list.size()); i++) {
+    refs_[i]->ResetVariable(var_list[i]);
+  }
+  for (size_t i = 0; i < vec_instruction_.size(); i++) {
     BuildAndCacheInstructionCtx(&vec_instruction_[i]);
   }
 }
@@ -540,6 +557,10 @@ void InterpreterCore::Convert(
       if (var_desc && ins.count(item.first) &&
           !info.IsInArgBufferNeeded(var_desc->Name())) {
         continue;
+      } else if (!block_.HasVar(var_scope_.GetNameById(id))) {
+        VLOG(10) << "[gc_check_inputs] skip gc: "
+                 << var_scope_.GetNameById(id);
+        continue;
       }
       gc_check_vars.insert(id);
     }
@@ -661,9 +682,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
 #ifdef PADDLE_WITH_ASCEND_CL
   if (platform::is_npu_place(place)) {
-    // NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable
-    // values, but only through special `float_status` to checks whether
-    // the operation is overflow. More about `float_status`, see:
+    // NOTE(wangxi): nan/inf cannot be detected on NPU by checking the
+    // variable values, but only through the special `float_status`, which
+    // checks whether the operation overflows. More about `float_status`, see:
     // https://gitee.com/ascend/modelzoo/issues/I3NF8V?from=project-issue
     if (FLAGS_check_nan_inf) {
       framework::details::NPUAllocAndClearFloatStatus(*op, *local_scope, place);
@@ -734,7 +755,7 @@
     }
   }
-  VLOG(4) << "End run " << place << " " << op->DebugStringEx(local_scope_);
+  VLOG(4) << "End run " << place << " " << op->DebugStringEx(local_scope);
   if (!instr_node.InplaceBackMap().empty()) {
     platform::RecordEvent inplaceback_event(
@@ -965,9 +986,9 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
     if (platform::is_gpu_place(place)) {
       memory::RecordStream(allocation, stream);
     } else if (platform::is_cuda_pinned_place(place)) {
-      // TODO(Ruibiao): Here should do something to make sure that the tensor is
-      // not freed until the H2D copies done. However, simplely launch a CUDA
-      // runtime callback to the H2D stream may lead a high performance
+      // TODO(Ruibiao): Here we should make sure that the tensor is not freed
+      // until the H2D copies are done. However, simply launching a CUDA
+      // runtime callback to the H2D stream may lead to a high performance
       // overhead. As all the cases we meet in H2D are copies from CPUPlace at
       // present, we just log a WARNING here. A better design is required.
       LOG(WARNING) << "Copy data from a CUDAPinned tensor in an asynchronous "
@@ -984,8 +1005,8 @@
  * instr.GCCheckVars.
  * 2. The stream which initializes this tensor is different from the stream
  * which the instruction run in.
- * 3. The tensor is the instruction's input, cause we assume that instruction
- * will initialize all output tensors with its running stream.
+ * 3. The tensor is the instruction's input, because we assume that the
+ * instruction will initialize all output tensors with its running stream.
  * 4. In the OP function of this instruction, the tensor is an input of a
  * async CUDA kernel.
  *
@@ -995,8 +1016,8 @@
  * initialized this tensor has less time overhead. Conversely, it may take
  * more time if we try to extract those cross-stream input vars from
  * instr.GCCheckVars.
- * 2. Now the instruction has no idea of which vars involving async running in
- * OP function, and thus we can not recognize condition 4. It should be
+ * 2. Now the instruction has no idea which vars are involved in async runs
+ * in the OP function, and thus we cannot recognize condition 4. It should be
  * supported later.
  */
 for (int var_id : instr.GCCheckVars()) {
@@ -1099,12 +1120,12 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
                                     execution_config_.skip_gc_vars,
                                     &op_func_nodes,
                                     &var_scope_,
-                                    HasLocalScope(),
-                                    execution_config_.used_for_jit);
-    is_build_ = true;
+                                    execution_config_,
+                                    HasLocalScope());
     SetFeedVarsInplaceSkip(feed_names);
     // convert vec func_list to graph
     Convert(&op_func_nodes);
+    is_build_ = true;
   }
   // NOTE: Because feed_tensor will be GC after
   // paddle::framework::BuildOpFuncList, so we should
......
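Since FLAGS_control_flow_use_new_executor defaults to false, nothing changes unless a caller opts in. A minimal sketch of flipping the flag from C++ (the DECLARE_bool mirrors the interpretercore.h hunk below; exported Paddle flags can also typically be set via a `FLAGS_...` environment variable, though that mechanism is an assumption here, not shown in this diff):

```cpp
#include "gflags/gflags.h"

DECLARE_bool(control_flow_use_new_executor);  // defined in interpretercore.cc

// Sketch: opt in to the new control-flow executor before running a program.
void EnableNewControlFlowExecutor() {
  FLAGS_control_flow_use_new_executor = true;  // default is false (see above)
}
```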
@@ -34,6 +34,9 @@
 #include "paddle/fluid/memory/allocation/spin_lock.h"
 #include "paddle/fluid/platform/device_event.h"
+DECLARE_bool(new_executor_use_local_scope);
+DECLARE_bool(control_flow_use_new_executor);
 namespace paddle {
 namespace framework {
@@ -43,7 +46,8 @@ class InterpreterCore {
                   const BlockDesc& block,
                   const std::set<std::string>& skip_gc_vars,
                   Scope* scope,
-                  bool used_for_jit = false);
+                  bool used_for_jit = false,
+                  bool used_for_control_flow_op = false);
   ~InterpreterCore();
@@ -55,7 +59,8 @@
       const std::vector<std::string>& feed_names,
       const std::vector<phi::DenseTensor>& feed_tensors);
-  paddle::framework::FetchList Run(const std::vector<std::string>& feed_names);
+  paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
+                                   bool need_fetch = true);
   void ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src);
@@ -67,6 +72,8 @@
   void reset_scope(Scope* new_scope);
+  const platform::Place& GetPlace() const { return place_; }
  private:
   // build graph
   void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
......
@@ -418,6 +418,7 @@ class VarRefInfo {
       dynamic_ref_ = static_ref_;
     }
   }
+  void ResetVariable(Variable* new_var) { var_ = new_var; }
   bool CheckAndDecrease() {
     return static_ref_ == 1 || (dynamic_ref_.fetch_sub(1) == 1);
   }
......
@@ -28,8 +28,8 @@ paddle::framework::FetchList StandaloneExecutor::Run(
     const std::vector<std::string>& fetch_names) {
   platform::RecordEvent record_event(
       "StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
   auto core = GetInterpreterCore(scope, prog_, feed_names, fetch_names, false);
   VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
   return core->Run(feed_names);
 }
......
@@ -8,7 +8,7 @@ register_operators(EXCLUDES conditional_block_op DEPS naive_executor)
 cc_library(
   conditional_block_op
   SRCS conditional_block_op.cc
-  DEPS executor)
+  DEPS standalone_executor executor)
 cc_library(
   op_variant
   SRCS op_variant.cc
@@ -29,7 +29,7 @@
 cc_test(
   conditional_block_op_test
   SRCS conditional_block_op_test.cc
-  DEPS conditional_block_op executor)
+  DEPS conditional_block_op standalone_executor executor)
 if(WITH_UNITY_BUILD)
   target_link_libraries(paddle_operators_controlflow_unity conditional_block_op)
......
@@ -14,7 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/controlflow/conditional_block_op.h"
+#include "paddle/fluid/framework/new_executor/standalone_executor.h"
 #include "paddle/fluid/operators/assign_op.h"
+#include "paddle/fluid/platform/flags.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #ifdef PADDLE_WITH_MKLDNN
@@ -35,6 +37,45 @@ const char ConditionalOp::kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
 using Executor = framework::Executor;
 using ExecutorPrepareContext = framework::ExecutorPrepareContext;
+using InterpreterCore = framework::InterpreterCore;
+
+namespace details {
+static void BuildScopeForConditionalBlockOp(
+    const paddle::framework::InterpreterCore &interpreter_core,
+    const paddle::framework::BlockDesc &block,
+    paddle::framework::Scope *scope) {
+  for (auto &var_desc : block.AllVars()) {
+    auto var_name = var_desc->Name();
+    if (var_name == framework::kEmptyVarName) {
+      continue;
+    }
+    VLOG(5) << "[BuildScopeForConditionalBlockOp] start: " << var_name;
+    if (var_desc->Persistable()) {
+      VLOG(5) << "[BuildScopeForConditionalBlockOp] "
+              << "Don't process persistable: " << var_name;
+    } else {
+      auto *ptr = scope->Var(var_name);
+      InitializeVariable(ptr, var_desc->GetType());
+      VLOG(5) << "[BuildScopeForConditionalBlockOp] "
+              << "Not found locally, created: " << var_name;
+    }
+  }
+
+  auto &data_transfer_added_vars =
+      interpreter_core.GetVariableScope()->DataTransferAddedVars();
+  for (size_t i = 0; i < data_transfer_added_vars.size(); i++) {
+    auto *ptr = scope->Var(data_transfer_added_vars[i].first);
+    InitializeVariable(ptr,
+                       static_cast<paddle::framework::proto::VarType::Type>(
+                           data_transfer_added_vars[i].second));
+    VLOG(10) << "[BuildScopeForConditionalBlockOp] "
+             << "Initialize data-transfer-added variable "
+             << data_transfer_added_vars[i].first;
+  }
+}
+}  // namespace details
+
 class ConditionalBlockOp : public ConditionalOp {
  public:
   ConditionalBlockOp(const std::string &type,
@@ -71,9 +112,20 @@ class ConditionalBlockOp : public ConditionalOp {
         platform::errors::PreconditionNotMet(
             "Expect Scope variable to be set in conditional_block_op, but "
             "got a null Scope variable. Please set the Scope variable."));
     auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
-    scopes->resize(1);
-    scopes->front() = &scope.NewScope();
+
+    if (scopes->size() == 0 || !FLAGS_control_flow_use_new_executor) {
+      scopes->resize(1);
+      scopes->front() = &scope.NewScope();
+    }
+
+    // We need to know whether the scope we cached is still valid.
+    // If not, we need to create a new one.
+    if (scope.kids().size() == 0) {
+      scopes->front() = &scope.NewScope();
+    }
+
     auto &cur_scope = *scopes->front();
 #ifdef PADDLE_WITH_MKLDNN
     // (jczaja) Executor on being destroyed clears oneDNN cache and
@@ -84,25 +136,56 @@
     auto *block = Attr<framework::BlockDesc *>("sub_block");
     VLOG(3) << "Conditional block.idx = " << block->ID()
             << ", scope = " << &cur_scope;
     auto &skip_vars =
         Attr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars);
-    if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) {
-      auto &pdesc = *block->Program();
-      exec.reset(new Executor(dev_place));
-      if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc);
-      ctx = exec->Prepare(pdesc, block->ID(), skip_vars, false);
+
+    if (FLAGS_control_flow_use_new_executor) {
+      std::set<std::string> skip_gc_vars(skip_vars.begin(), skip_vars.end());
+
+      if (!core || !platform::is_same_place(core->GetPlace(), dev_place)) {
+        VLOG(10) << "[interpreterCore cache]" << core.get();
+        VLOG_IF(10, core)
+            << platform::is_same_place(core->GetPlace(), dev_place);
+        core.reset(new InterpreterCore(dev_place,
+                                       *block,
+                                       skip_gc_vars,
+                                       &cur_scope,
+                                       /* used_for_jit */ false,
+                                       /* used_for_control_flow_op */ true));
+        VLOG(10) << "[interpreterCore cache] new created: " << core;
+      } else {
+        details::BuildScopeForConditionalBlockOp(*core, *block, &cur_scope);
+        core->reset_scope(&cur_scope);
+      }
+
+      core->Run({}, false);
+    } else {
+      if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) {
+        auto &pdesc = *block->Program();
+        exec.reset(new Executor(dev_place));
+        if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc);
+        ctx = exec->Prepare(pdesc, block->ID(), skip_vars, false);
 #ifdef PADDLE_WITH_MKLDNN
-      platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place);
-      platform::RegisterModelLayout(ctx->ops_, dev_place);
+        platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place);
+        platform::RegisterModelLayout(ctx->ops_, dev_place);
 #endif
+      }
+      exec->RunPreparedContext(ctx.get(),
+                               &cur_scope,
+                               /* create_local_scope */ false,
+                               /* create_vars */ true,
+                               /* keep_kids */ true);
     }
-    exec->RunPreparedContext(ctx.get(), &cur_scope, false, true, true);
   }
 }

 private:
  mutable std::shared_ptr<Executor> exec{nullptr};
  mutable std::unique_ptr<ExecutorPrepareContext> ctx{nullptr};
+ mutable std::shared_ptr<InterpreterCore> core{nullptr};
 };

 class ConditionalBlockInferShape : public framework::InferShapeBase {
@@ -161,23 +244,51 @@ class ConditionalBlockGradOp : public ConditionalOp {
         platform::errors::InvalidArgument(
             "Expect Scope variable contains at least 1 scope, but got: %d",
             scopes.size()));
-    framework::Scope &cur_scope = *scopes[0];
+    framework::Scope &cur_scope = *(scopes[0]);

     auto *block = Attr<framework::BlockDesc *>("sub_block");
     VLOG(3) << "Conditional Grad block.idx = " << block->ID()
             << ", scope = " << &cur_scope;
-    if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) {
-      auto &pdesc = *block->Program();
-      exec.reset(new Executor(dev_place));
-      if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc);
-      ctx = exec->Prepare(pdesc, block->ID(), inside_grads, false);
+
+    if (FLAGS_control_flow_use_new_executor) {
+      std::set<std::string> skip_gc_vars(inside_grads.begin(),
+                                         inside_grads.end());
+
+      if (!core || !platform::is_same_place(core->GetPlace(), dev_place)) {
+        VLOG(10) << "[interpreterCore cache]" << core.get();
+        VLOG_IF(10, core)
+            << platform::is_same_place(core->GetPlace(), dev_place);
+        core.reset(new InterpreterCore(dev_place,
+                                       *block,
+                                       skip_gc_vars,
+                                       &cur_scope,
+                                       /* used_for_jit */ false,
+                                       /* used_for_control_flow_op */ true));
+        VLOG(10) << "[interpreterCore cache] new created: " << core;
+      } else {
+        details::BuildScopeForConditionalBlockOp(*core, *block, &cur_scope);
+        core->reset_scope(&cur_scope);
+      }
+
+      core->Run({}, false);
+    } else {
+      if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) {
+        auto &pdesc = *block->Program();
+        exec.reset(new Executor(dev_place));
+        if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc);
+        ctx = exec->Prepare(pdesc, block->ID(), inside_grads, false);
 #ifdef PADDLE_WITH_MKLDNN
-      platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place);
-      platform::RegisterModelLayout(ctx->ops_, dev_place);
+        platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place);
+        platform::RegisterModelLayout(ctx->ops_, dev_place);
 #endif
+      }
+      exec->RunPreparedContext(ctx.get(),
+                               &cur_scope,
+                               /* create_local_scope */ false,
+                               /* create_vars */ true,
+                               /* keep_kids */ true);
     }
-    exec->RunPreparedContext(ctx.get(), &cur_scope, false, true, false);

     AssignLocalGradientToParentScope(
         dev_place, cur_scope, scope, inside_grads, outside_grads, inputs);
@@ -190,6 +301,7 @@
  private:
   mutable std::shared_ptr<Executor> exec{nullptr};
   mutable std::unique_ptr<ExecutorPrepareContext> ctx{nullptr};
+  mutable std::shared_ptr<InterpreterCore> core{nullptr};

  private:
   void AssignLocalGradientToParentScope(
@@ -204,7 +316,8 @@
     for (size_t i = 0; i < outside_grads.size(); ++i) {
       const std::string &outside_grad_name = outside_grads[i];
       const std::string &inside_grad_name = inside_grads[i];
-      VLOG(4) << "inside_grad_name = " << inside_grad_name
+      VLOG(4) << "[assign local] "
+              << "inside_grad_name = " << inside_grad_name
               << ", outside_grad_name = " << outside_grad_name;
       framework::Variable *outside_var =
           parent_scope.FindVar(outside_grad_name);
@@ -237,7 +350,8 @@
     for (size_t i = 0; i < outside_grads.size(); ++i) {
       const std::string &outside_grad_name = outside_grads[i];
       const std::string &input_name = inputs[i];
-      VLOG(4) << "input_name = " << input_name
+      VLOG(4) << "[assign zero] "
+              << "input_name = " << input_name
               << ", outside_grad_name = " << outside_grad_name;
       framework::Variable *input_var = scope.FindVar(input_name);
       if (input_var == nullptr) {
......
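The scope-validity check added above ("We need to know whether the scope we cached is still valid") guards against a concrete hazard: anything that drops the parent scope's children invalidates the cached sub-scope pointer. A self-contained illustration of that hazard (editorial sketch, not PR code; `DropKids` is the standard framework::Scope API for destroying child scopes):

```cpp
#include "paddle/fluid/framework/scope.h"

// Illustrative only: why ConditionalBlockOp re-checks scope.kids() before
// reusing the sub-scope pointer it cached in the Scope output variable.
void ScopeReuseHazardDemo() {
  paddle::framework::Scope parent;
  paddle::framework::Scope *cached = &parent.NewScope();  // cached by the op
  parent.DropKids();  // some other component destroys all child scopes
  // `cached` now dangles; dereferencing it would be undefined behavior.
  // The op detects this via `scope.kids().size() == 0` and allocates a
  // fresh child scope before handing it to the InterpreterCore.
  (void)cached;
}
```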
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/data_layout_transform.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/memcpy_kernel.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/phi/backends/onednn/onednn_helper.h"
 #endif
@@ -157,6 +158,13 @@ void TransferLayoutKernel(const Context& dev_ctx,
   VLOG(10) << "TransDataLayout from " << static_cast<DataLayout>(src_layout)
            << " -> " << static_cast<DataLayout>(dst_layout);
+  VLOG_IF(10, x.initialized()) << "TransDataLayout from " << x.layout();
+
+  if (x.layout() == static_cast<DataLayout>(dst_layout)) {
+    VLOG(10) << "No need to transform, already is " << x.layout();
+    Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    return;
+  }
+
 #ifdef PADDLE_WITH_MKLDNN
   TransferLayoutMKLDNN<Context>(dev_ctx,
                                 x,
......
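One consequence of the cached-scope reuse earlier in this PR is that transfer_layout can be asked to run on a tensor that is already in the requested layout; the hunk above short-circuits that case into a plain copy. The added predicate, restated in isolation (phi type names assumed from the surrounding kernel code):

```cpp
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/dense_tensor.h"

// Restates the skip condition added above: a layout transform is needed only
// when the source tensor's layout differs from the requested destination.
bool NeedsLayoutTransform(const phi::DenseTensor& x, int dst_layout) {
  return x.layout() != static_cast<phi::DataLayout>(dst_layout);
}
```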