diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index cb6fb8c1a7f8f72b827584f369fc9daa873388ca..dcc6bb210ceda670ae08694803c31dd44dc90316 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -33,13 +33,11 @@ namespace framework { // NOTE(Aurelius84): Need a better strategy to determine it. static constexpr size_t kHostNumThreads = 4; -InterpreterCore::InterpreterCore(const platform::Place& place, - const ProgramDesc& main_prog, +InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block, VariableScope* global_scope, - const std::vector& feed_names, - const std::vector& fetch_names) + const std::vector& feed_names) : place_(place), - main_program_(main_prog), + block_(block), global_scope_(global_scope), stream_analyzer_(place), async_work_queue_(kHostNumThreads, &main_thread_blocker_) { @@ -50,9 +48,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place, exception_notifier_ = main_thread_blocker_.RegisterEvent( kExceptionCaught, [this]() { return exception_holder_.IsCaught(); }); - // Step1: add feedop and fetchop to main_program - AddFetch(fetch_names); - // prune // optmize graph pass @@ -60,24 +55,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place, // convert to run graph } -void InterpreterCore::AddFetch(const std::vector& fetch_names) { - auto* fetch_holder = main_program_.MutableBlock(0)->Var("fetch_vars"); - fetch_holder->SetType(proto::VarType::FETCH_LIST); - fetch_holder->SetPersistable(true); - - int i = 0; - for (auto& fetch_name : fetch_names) { - // append fetch op - auto* op = main_program_.MutableBlock(0)->AppendOp(); - op->SetType("fetch_v2"); - op->SetInput("X", {fetch_name}); - op->SetOutput("Out", {"fetch_vars"}); - op->SetAttr("col", {static_cast(i)}); - op->CheckAttrs(); - i++; - } -} - paddle::framework::FetchList InterpreterCore::Run( const 
std::vector& feed_tensors) { auto FeedInput = [&] { @@ -90,11 +67,11 @@ paddle::framework::FetchList InterpreterCore::Run( }; if (is_build_ == false) { - paddle::framework::interpretercore::build_variable_scope(main_program_, - global_scope_); + paddle::framework::interpreter::build_variable_scope(*block_, + global_scope_); FeedInput(); - paddle::framework::interpretercore::build_op_func_list( - place_, main_program_, &vec_func_list_, global_scope_); + paddle::framework::interpreter::build_op_func_list( + place_, *block_, &vec_func_list_, global_scope_); is_build_ = true; // convert vec func_list to graph Convert(); @@ -104,7 +81,7 @@ paddle::framework::FetchList InterpreterCore::Run( } // return Fetch Tensors - auto* fetch_var = global_scope_->Var("fetch_vars"); + auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName); return *(fetch_var->GetMutable()); } @@ -172,8 +149,7 @@ void InterpreterCore::Convert() { std::vector vec_temp; for (auto& item : vec_instruction_[i].Outputs()) { for (auto id : item.second) { - vec_temp = - interpretercore::merge_vector(vec_temp, input_var2op_info_[id]); + vec_temp = interpreter::merge_vector(vec_temp, input_var2op_info_[id]); } } @@ -438,8 +414,8 @@ void InterpreterCore::RunNextInstructions( [&, next_id] { RunInstructionAsync(next_id); }); } } - auto direct_run_ops = interpretercore::merge_vector( - next_instr.SyncRunIds(), next_instr.DirectRunIds()); + auto direct_run_ops = interpreter::merge_vector(next_instr.SyncRunIds(), + next_instr.DirectRunIds()); size_t first_op = 0; for (auto next_id : direct_run_ops) { if (IsReady(next_id)) { @@ -538,11 +514,11 @@ void InterpreterCore::DryRunPrepare( }; if (is_build_ == false) { - paddle::framework::interpretercore::build_variable_scope(main_program_, - global_scope_); + paddle::framework::interpreter::build_variable_scope(*block_, + global_scope_); FeedInput(); - paddle::framework::interpretercore::build_op_func_list( - place_, main_program_, &vec_func_list_, global_scope_); 
+ paddle::framework::interpreter::build_op_func_list( + place_, *block_, &vec_func_list_, global_scope_); is_build_ = true; // convert vec func_list to graph Convert(); diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index c91acb7827da89b77a250a9adc0fe3e059dbe2aa..e6f623cb72831b61f24d7646a00925381ea28e89 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -40,10 +40,9 @@ using AtomicVectorSizeT = std::vector>>; class InterpreterCore { public: - InterpreterCore(const platform::Place& place, const ProgramDesc& main_prog, + InterpreterCore(const platform::Place& place, BlockDesc* block, VariableScope* global_scope, - const std::vector& feed_names, - const std::vector& fetch_names); + const std::vector& feed_names); paddle::framework::FetchList Run( const std::vector& feed_tensors); @@ -72,15 +71,14 @@ class InterpreterCore { void RunInstructionAsync(size_t instr_id); void RunNextInstructions(const Instruction& instr_id, std::queue* reserved_next_ops); - void AddFetch(const std::vector& fetch_names); void BuildSkipShareLoDInfo(); bool is_build_; const platform::Place& place_; - ProgramDesc main_program_; - VariableScope* global_scope_; + BlockDesc* block_; // not owned + VariableScope* global_scope_; // not owned std::vector vec_func_list_; std::vector vec_instruction_; // deconstruct before OpFuncNode @@ -88,7 +86,6 @@ class InterpreterCore { InstructionInfo instruction_info_; std::vector dependecy_count_; std::vector> input_var2op_info_; - std::vector ref_coun_info_; std::vector vec_meta_info_; std::vector feed_names_; @@ -97,7 +94,7 @@ class InterpreterCore { StreamAnalyzer stream_analyzer_; EventManager event_manager_; EventsWaiter main_thread_blocker_; - interpretercore::AsyncWorkQueue async_work_queue_; + interpreter::AsyncWorkQueue async_work_queue_; details::ExceptionHolder exception_holder_; std::shared_ptr 
exception_notifier_{nullptr}; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index cc5ed6eff162a111d6423adc64fb7e4660002e7a..7b52cc991e14b275e8fd34439720bdb688d86c7f 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -18,7 +18,7 @@ namespace paddle { namespace framework { -namespace interpretercore { +namespace interpreter { using VariableIdMap = std::map>; AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps( @@ -129,11 +129,9 @@ std::string get_memcpy_type(const platform::Place& src_place, } } -void build_variable_scope(const framework::ProgramDesc& pdesc, +void build_variable_scope(const framework::BlockDesc& block, VariableScope* var_scope) { - auto& global_block = pdesc.Block(0); - - for (auto& var_desc : global_block.AllVars()) { + for (auto& var_desc : block.AllVars()) { auto var_name = var_desc->Name(); if (var_name == framework::kEmptyVarName) { continue; @@ -360,9 +358,9 @@ std::tuple apply_place_transform_for_var( std::vector apply_data_transform( const OpKernelType& expected_kernel_key, const platform::Place& place, - VariableValueMap& ins_map_temp, VariableScope* var_scope, - OpFuncNode& op_func_node) { - auto& op_base = op_func_node.operator_base_; + VariableValueMap* ins_map_temp, VariableScope* var_scope, + OpFuncNode* op_func_node) { + auto& op_base = op_func_node->operator_base_; PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet( "op_base is null, please pass a valid " "op_base in apply_data_transform.")); @@ -372,7 +370,7 @@ std::vector apply_data_transform( no_data_transform_index; // record the no need transform variable index. std::vector copy_func_nodes; // return all the copy opfuncnode. 
- for (auto& var_name_item : ins_map_temp) { + for (auto& var_name_item : *ins_map_temp) { for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto var = var_name_item.second[i]; auto& var_name = inputs_names[var_name_item.first].at(i); @@ -394,8 +392,8 @@ std::vector apply_data_transform( std::tie(new_var_name, copy_op_func_node) = apply_place_transform_for_var( kernel_type_for_var, expected_kernel_key, place, var_name, - var_name_item.first, op_func_node, var, var_scope); - op_func_node.input_index[var_name_item.first][i] = + var_name_item.first, *op_func_node, var, var_scope); + op_func_node->input_index[var_name_item.first][i] = var_scope->VarId(new_var_name); copy_func_nodes.push_back(copy_op_func_node); var_name_item.second[i] = var_scope->Var(new_var_name); @@ -414,23 +412,22 @@ std::vector apply_data_transform( } } } - op_func_node.no_data_transform_index = std::move(no_data_transform_index); + op_func_node->no_data_transform_index = std::move(no_data_transform_index); return copy_func_nodes; } void build_op_func_list(const platform::Place& place, - const framework::ProgramDesc& pdesc, + const framework::BlockDesc& block, std::vector* vec_func_list, VariableScope* var_scope) { - auto& global_block = pdesc.Block(0); auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); - // Step 1: create all ops for global block. - auto ops = create_all_ops(global_block); - auto unused_var_map = get_unused_vars(global_block, ops); + // Step 1: create all ops for current block. + auto ops = create_all_ops(block); + auto unused_var_map = get_unused_vars(block, ops); size_t ops_index = 0; - for (auto& op : global_block.AllOps()) { + for (auto& op : block.AllOps()) { VLOG(6) << "Build OpFuncNode from : " << op->Type(); auto op_base = ops[ops_index++]; @@ -498,7 +495,7 @@ void build_op_func_list(const platform::Place& place, // apply_data_transform. 
op_func_node.operator_base_ = op_base; copy_op_to_insert = apply_data_transform( - expected_kernel_key, place, ins_map_temp, var_scope, op_func_node); + expected_kernel_key, place, &ins_map_temp, var_scope, &op_func_node); for (auto& item : copy_op_to_insert) { vec_func_list->push_back(item); } @@ -576,6 +573,25 @@ void build_op_func_list(const platform::Place& place, } } +void add_fetch(const std::vector<std::string>& fetch_names, + framework::BlockDesc* block) { + auto* fetch_holder = block->Var(kFetchVarName); + fetch_holder->SetType(proto::VarType::FETCH_LIST); + fetch_holder->SetPersistable(true); + + int i = 0; + for (auto& fetch_name : fetch_names) { + // append fetch op + auto* op = block->AppendOp(); + op->SetType("fetch_v2"); + op->SetInput("X", {fetch_name}); + op->SetOutput("Out", {kFetchVarName}); + op->SetAttr("col", {static_cast<int>(i)}); + op->CheckAttrs(); + i++; + } +} + std::vector<size_t> merge_vector(const std::vector<size_t>& first, const std::vector<size_t>& second) { std::vector<size_t> out(first.size() + second.size()); @@ -590,6 +606,6 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first, return out; } -} // namespace interpretercore +} // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index 976826800f4a47af97e0bc9a255871758bbe38b4..375fed2356a0100d70661439b7b3d406d6dc59a2 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -48,9 +48,10 @@ namespace paddle { namespace framework { -namespace interpretercore { +namespace interpreter { using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>; +static constexpr char kFetchVarName[] = "fetch_vars"; class AsyncWorkQueue { public: @@ -96,17 +97,20 @@ std::string get_memcpy_type(const platform::Place& src_place, const platform::Place& dst_place); -void build_variable_scope(const 
framework::ProgramDesc& pdesc, +void build_variable_scope(const framework::BlockDesc& block, VariableScope* var_scope); void build_op_func_list(const platform::Place& place, - const framework::ProgramDesc& pdesc, + const framework::BlockDesc& block, std::vector<OpFuncNode>* vec_func_list, VariableScope* var_scope); +void add_fetch(const std::vector<std::string>& fetch_names, + framework::BlockDesc* block); + std::vector<size_t> merge_vector(const std::vector<size_t>& first, const std::vector<size_t>& second); -} // namespace interpretercore +} // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 0432aa33d7dcbaabaf620b17c68e687eaf098ef3..f4832cb2835400f9868bfd56d122f90fcb666e0d 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -776,7 +776,7 @@ class Instruction { std::vector<std::pair<std::string, std::string>> vec_inplace_in_to_out_; }; -namespace interpretercore { +namespace interpreter { static constexpr char kMemcpyH2D[] = "memcpy_h2d"; static constexpr char kMemcpyD2H[] = "memcpy_d2h"; @@ -787,7 +787,7 @@ static bool IsMemcpyH2D(const Instruction& instr) { static bool IsMemcpyD2H(const Instruction& instr) { return instr.OpBase()->Type() == kMemcpyD2H; } -} // namespace interpretercore +} // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 474be9e889d2af446ca8335e08b68328ebfe02eb..6d3b531c50b2e77b1ec66c5e73d2a94274fa5ca4 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -41,8 +41,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, // run startup program std::vector<OpFuncNode> vec_func_list; - paddle::framework::interpretercore::build_op_func_list( - place_, 
startup_prog, &vec_func_list, &global_scope_); + paddle::framework::interpreter::build_op_func_list( + place_, startup_prog.Block(0), &vec_func_list, &global_scope_); } paddle::framework::FetchList StandaloneExecutor::Run( @@ -96,8 +96,15 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore( if (iter == interpretercores_.end()) { VLOG(3) << "create interpreter_core for " << oss.str(); - auto core = std::make_shared<InterpreterCore>( - place_, main_prog_, &global_scope_, feed_names, fetch_names); + // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a + // new program. + auto new_prog = std::make_shared<ProgramDesc>(main_prog_); + auto* block = new_prog->MutableBlock(0); + interpreter::add_fetch(fetch_names, block); + + auto core = std::make_shared<InterpreterCore>(place_, block, &global_scope_, + feed_names); + programs_.emplace(oss.str(), new_prog); interpretercores_.emplace(oss.str(), core); return core; } else { diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h index ba1c7df45c9d2f6a8823590fe2a8c3f61b6770e2..6d48a8514274231da00a91035cee55cc9ecaba22 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.h +++ b/paddle/fluid/framework/new_executor/standalone_executor.h @@ -62,6 +62,7 @@ class StandaloneExecutor : public ExecutorBase { Scope* outer_scope_; VariableScope global_scope_; + std::unordered_map<std::string, std::shared_ptr<ProgramDesc>> programs_; std::unordered_map<std::string, std::shared_ptr<InterpreterCore>> interpretercores_; }; diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc index d30f27169cc43d16e237eaf42637e2ad82a638ac..23b61dd3d5ee7b7f84e12b8ad32b56604b2944e5 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -101,10 +101,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( const OpFuncNode& op_func_node) { auto& op_type = op_func_node.operator_base_->Type(); auto* dev_ctx = 
op_func_node.dev_ctx_; - if (op_type == interpretercore::kMemcpyH2D) { + if (op_type == interpreter::kMemcpyH2D) { VLOG(3) << "Get dev_ctx from d2h_context_pool_"; dev_ctx = d2h_ctx_pool_.Get(place_); - } else if (op_type == interpretercore::kMemcpyD2H) { + } else if (op_type == interpreter::kMemcpyD2H) { VLOG(3) << "Get dev_ctx from h2d_context_pool_"; dev_ctx = h2d_ctx_pool_.Get(place_); } @@ -122,8 +122,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, const Instruction& next_instr) { return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() || - interpretercore::IsMemcpyD2H(cur_instr) || - interpretercore::IsMemcpyH2D(next_instr)); + interpreter::IsMemcpyD2H(cur_instr) || + interpreter::IsMemcpyH2D(next_instr)); } platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {