From 5303b66b10653482d752fa90421d56f8c4d1ef95 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 12 Oct 2022 16:24:31 +0800 Subject: [PATCH] clean code of interpretercore (#46891) * refactor * refine code --- .../framework/new_executor/data_transfer.cc | 4 +- .../framework/new_executor/interpretercore.cc | 201 ++++++++---------- .../framework/new_executor/interpretercore.h | 60 +++--- .../new_executor/interpretercore_util.cc | 76 +++---- .../new_executor/interpretercore_util.h | 31 ++- .../new_executor/new_executor_defs.h | 2 +- 6 files changed, 170 insertions(+), 204 deletions(-) diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc index 60cfcdf443a..d4402697309 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/data_transfer.cc @@ -378,7 +378,7 @@ std::shared_ptr TransferDevice(const std::string& var_name, "Required src_place shall be different with dst_place, " "but received same place: %s", src_place)); - if (IsSupportedHetePlace(dst_place)) { + if (IsSupportedHeterPlace(dst_place)) { op_type = kMemcpyH2D; int dst_place_type = platform::is_gpu_place(dst_place) ? 0 : platform::is_npu_place(dst_place) ? 1 @@ -387,7 +387,7 @@ std::shared_ptr TransferDevice(const std::string& var_name, : platform::is_custom_place(dst_place) ? 6 : -1; attr_map = {{"dst_place_type", dst_place_type}}; - } else if (IsSupportedHetePlace(src_place)) { + } else if (IsSupportedHeterPlace(src_place)) { op_type = kMemcpyD2H; int dst_place_type = platform::is_cpu_place(dst_place) ? 0 : platform::is_cuda_pinned_place(dst_place) ? 1 diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 9779f140941..cdb5fcc58aa 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -57,6 +57,50 @@ constexpr const char* kTaskCompletion = "TaskCompletion"; namespace paddle { namespace framework { +inline void SetDeviceId(const platform::Place& place) { + // TODO(zhiqiu): reduce the cost + if (platform::is_gpu_place(place)) { +#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with CUDA support.", + place)); +#else + auto dev_id = place.device; + platform::SetDeviceId(dev_id); +#endif + } else if (platform::is_xpu_place(place)) { +#ifndef PADDLE_WITH_XPU + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with XPU support.", + place)); +#else + auto dev_id = place.device; + platform::SetXPUDeviceId(dev_id); +#endif + } else if (platform::is_npu_place(place)) { +#ifndef PADDLE_WITH_ASCEND_CL + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with NPU support.", + place)); +#else + auto dev_id = place.device; + platform::SetNPUDeviceId(dev_id); +#endif + } else if (platform::is_custom_place(place)) { +#ifndef PADDLE_WITH_CUSTOM_DEVICE + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with CustomDevice support.", + place)); +#else + phi::DeviceManager::SetDevice(place); +#endif + } +} + // TODO(Ruibia): Pass skip_gc_vars, used_for_jit, and other config messages by // constructing an interpreter::ExecutionConfig InterpreterCore::InterpreterCore(const platform::Place& place, @@ -71,8 +115,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place, stream_analyzer_(place) { VLOG(4) << "InterpreterCore(): " << this << " on " << place_; - is_build_ = false; - exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught); completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion); @@ -87,12 +129,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place, local_scope_ = local_scope; } var_scope_.SetLocalScope(local_scope_); - - // prune - - // optmize graph pass - - // convert to run graph } InterpreterCore::~InterpreterCore() { @@ -111,11 +147,8 @@ InterpreterCore::~InterpreterCore() { interpreter::CostInfo InterpreterCore::DryRun( const std::vector& feed_names, const std::vector& feed_tensors) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(place_)) { - platform::SetDeviceId(place_.device); - } -#endif + SetDeviceId(place_); + Prepare(feed_names, feed_tensors, true); interpreter::CostInfo cost_info; { @@ -135,7 +168,7 @@ interpreter::CostInfo InterpreterCore::DryRun( platform::DeviceContextPool::Instance().Get(place_)->Wait(); } - if (execution_config_.create_local_scope) { + if (HasLocalScope()) { ClearLoDTensorArrayInLocalScope(); } @@ -145,11 +178,7 @@ interpreter::CostInfo InterpreterCore::DryRun( paddle::framework::FetchList InterpreterCore::Run( const std::vector& feed_names, const std::vector& feed_tensors) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(place_)) { - platform::SetDeviceId(place_.device); - } -#endif + SetDeviceId(place_); #ifdef PADDLE_WITH_MKLDNN platform::AttachPointerHashToMKLDNNKey(this, place_); @@ -181,7 +210,7 @@ paddle::framework::FetchList InterpreterCore::Run( } #endif } - if (execution_config_.create_local_scope) { + if (HasLocalScope()) { ClearLoDTensorArrayInLocalScope(); } @@ -196,11 +225,7 @@ paddle::framework::FetchList InterpreterCore::Run( paddle::framework::FetchList InterpreterCore::Run( const std::vector& feed_names) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(place_)) { - platform::SetDeviceId(place_.device); - } -#endif + SetDeviceId(place_); #ifdef PADDLE_WITH_MKLDNN platform::AttachPointerHashToMKLDNNKey(this, place_); @@ -208,17 +233,17 @@ paddle::framework::FetchList InterpreterCore::Run( if (!is_build_) { LOG_FIRST_N(INFO, 1) << "New Executor is Running."; - paddle::framework::interpreter::build_variable_scope( - block_, &var_scope_, execution_config_.create_local_scope); + paddle::framework::interpreter::BuildVariableScope( + block_, &var_scope_, HasLocalScope()); std::vector op_func_nodes; - paddle::framework::interpreter::build_op_func_list( + paddle::framework::interpreter::BuildOpFuncList( place_, block_, execution_config_.skip_gc_vars, &op_func_nodes, &var_scope_, - execution_config_.create_local_scope, + HasLocalScope(), execution_config_.used_for_jit); is_build_ = true; SetFeedVarsInplaceSkip(feed_names); @@ -248,13 +273,13 @@ paddle::framework::FetchList InterpreterCore::Run( #endif } - if (execution_config_.create_local_scope) { + if (HasLocalScope()) { ClearLoDTensorArrayInLocalScope(); } + // return Fetch Tensors - Scope* inner_scope = execution_config_.create_local_scope - ? local_scope_ - : var_scope_.GetMutableScope(); + Scope* inner_scope = + HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope(); auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); if (fetch_var) { return std::move(*fetch_var->GetMutable()); @@ -327,9 +352,8 @@ std::shared_ptr InterpreterCore::GetWorkQueue() { } void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { - Scope* inner_scope = execution_config_.create_local_scope - ? local_scope_ - : var_scope_.GetMutableScope(); + Scope* inner_scope = + HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope(); VariableValueMap ins_map; for (auto& var_name_item : instr_node->Inputs()) { std::vector input_vars; @@ -355,9 +379,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { // set runtime_ctx and infershape_ctx_ if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in // kernel - Scope* local_scope = execution_config_.create_local_scope - ? var_scope_.GetMutableLocalScope() - : var_scope_.GetMutableScope(); + Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope() + : var_scope_.GetMutableScope(); instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope); } else { instr_node->ResetContext(ins_map, outs_map); @@ -387,9 +410,8 @@ void InterpreterCore::BuildInplace() { } } - Scope* local_scope = execution_config_.create_local_scope - ? var_scope_.GetMutableLocalScope() - : var_scope_.GetMutableScope(); + Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope() + : var_scope_.GetMutableScope(); std::vector> input_var2op(var_scope_.VarSize()); for (Instruction& instr : vec_instruction_) { for (auto& item : instr.Inputs()) { @@ -524,9 +546,8 @@ void InterpreterCore::Convert( } for (auto var_id : gc_check_vars) { - Scope* inner_scope = execution_config_.create_local_scope - ? local_scope_ - : var_scope_.GetMutableScope(); + Scope* inner_scope = + HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope(); paddle::framework::Variable* var = inner_scope->FindVar(var_scope_.GetNameById(var_id)); if (var->IsType() || var->IsType() || @@ -629,56 +650,11 @@ void InterpreterCore::BuildSkipShareLoDInfo() { } } -inline void SetDeviceId(const platform::Place& place) { - // TODO(zhiqiu): reduce the cost - if (platform::is_gpu_place(place)) { -#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) - PADDLE_THROW(platform::errors::Unavailable( - "Cannot run operator on place %s, please recompile paddle or " - "reinstall Paddle with CUDA support.", - place)); -#else - auto dev_id = place.device; - platform::SetDeviceId(dev_id); -#endif - } else if (platform::is_xpu_place(place)) { -#ifndef PADDLE_WITH_XPU - PADDLE_THROW(platform::errors::Unavailable( - "Cannot run operator on place %s, please recompile paddle or " - "reinstall Paddle with XPU support.", - place)); -#else - auto dev_id = place.device; - platform::SetXPUDeviceId(dev_id); -#endif - } else if (platform::is_npu_place(place)) { -#ifndef PADDLE_WITH_ASCEND_CL - PADDLE_THROW(platform::errors::Unavailable( - "Cannot run operator on place %s, please recompile paddle or " - "reinstall Paddle with NPU support.", - place)); -#else - auto dev_id = place.device; - platform::SetNPUDeviceId(dev_id); -#endif - } else if (platform::is_custom_place(place)) { -#ifndef PADDLE_WITH_CUSTOM_DEVICE - PADDLE_THROW(platform::errors::Unavailable( - "Cannot run operator on place %s, please recompile paddle or " - "reinstall Paddle with CustomDevice support.", - place)); -#else - phi::DeviceManager::SetDevice(place); -#endif - } -} - void InterpreterCore::RunInstruction(const Instruction& instr_node) { auto* op = instr_node.OpBase(); auto place = instr_node.DeviceContext().GetPlace(); - Scope* local_scope = execution_config_.create_local_scope - ? var_scope_.GetMutableLocalScope() - : var_scope_.GetMutableScope(); + Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope() + : var_scope_.GetMutableScope(); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_); SetDeviceId(place); @@ -800,8 +776,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { void InterpreterCore::ExecuteInstructionList( const std::vector& vec_instr) { interpreter::ResetAtomicGuard guard(&deps_, &refs_); - unfinished_op_numer_ = vec_instr.size(); - if (unfinished_op_numer_ == 0) { + unfinished_op_number_ = vec_instr.size(); + if (unfinished_op_number_ == 0) { VLOG(4) << "No op to run, return"; return; } @@ -878,8 +854,12 @@ void InterpreterCore::RunNextInstructions( [this, next_id] { RunInstructionAsync(next_id); }); } } - auto direct_run_ops = interpreter::merge_vector(next_instr.SyncRunIds(), - next_instr.DirectRunIds()); + + std::vector direct_run_ops = next_instr.SyncRunIds(); + direct_run_ops.insert(direct_run_ops.end(), + next_instr.DirectRunIds().begin(), + next_instr.DirectRunIds().end()); + int64_t first_op = -1; for (auto next_id : direct_run_ops) { if (IsReady(next_id)) { @@ -949,9 +929,9 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { return; } - VLOG(4) << "unfinished_op_numer_: " << unfinished_op_numer_; - if (UNLIKELY(unfinished_op_numer_.fetch_sub(1, std::memory_order_relaxed) == - 1)) { + VLOG(4) << "unfinished_op_number_: " << unfinished_op_number_; + if (UNLIKELY(unfinished_op_number_.fetch_sub( + 1, std::memory_order_relaxed) == 1)) { if (completion_notifier_ != nullptr) { completion_notifier_->NotifyEvent(); } @@ -961,8 +941,11 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { } } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void InterpreterCore::RecordStreamForGC(const Instruction& instr) { +#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) + PADDLE_THROW(platform::errors::Unimplemented( + "RecordStreamForGC is only implemented when compiled with GPU.")); +#else if (!IsInterpretercoreFastGCEnabled() || instr.KernelType() != OpFuncType::kQueueAsync) { return; @@ -1053,8 +1036,8 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) { framework::ToTypeName(var->Type()))); } } -} #endif +} void InterpreterCore::CheckGC(const Instruction& instr) { platform::RecordEvent record( @@ -1106,17 +1089,17 @@ void InterpreterCore::Prepare(const std::vector& feed_names, }; if (!is_build_) { - paddle::framework::interpreter::build_variable_scope( - block_, &var_scope_, execution_config_.create_local_scope); + paddle::framework::interpreter::BuildVariableScope( + block_, &var_scope_, HasLocalScope()); FeedInput(); std::vector op_func_nodes; - paddle::framework::interpreter::build_op_func_list( + paddle::framework::interpreter::BuildOpFuncList( place_, block_, execution_config_.skip_gc_vars, &op_func_nodes, &var_scope_, - execution_config_.create_local_scope, + HasLocalScope(), execution_config_.used_for_jit); is_build_ = true; SetFeedVarsInplaceSkip(feed_names); @@ -1124,7 +1107,7 @@ void InterpreterCore::Prepare(const std::vector& feed_names, Convert(&op_func_nodes); } // NOTE: Because feed_tensor will be GC after - // paddle::framework::build_op_func_list, so we should + // paddle::framework::BuildOpFuncList, so we should // call FeedInput again. if (prepare_feed) { FeedInput(); @@ -1138,6 +1121,8 @@ void InterpreterCore::SetFeedVarsInplaceSkip( } } +bool InterpreterCore::HasLocalScope() const { return local_scope_ != nullptr; } + std::shared_ptr CreateInterpreterCore( const platform::Place& place, const ProgramDesc& prog, @@ -1145,11 +1130,11 @@ std::shared_ptr CreateInterpreterCore( const std::vector& fetch_names, const std::set& skip_gc_vars) { std::shared_ptr core = nullptr; - // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy + // NOTE(Aurelius84): `AddFetch` will modify BlockDesc, so we should copy // a new program. auto new_prog = std::make_shared(prog); auto* block = new_prog->MutableBlock(0); - interpreter::add_fetch(fetch_names, block); + interpreter::AddFetch(fetch_names, block); core = std::make_shared(place, *block, skip_gc_vars, scope); core->SetCopyProgram(new_prog); diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index fa097a566b1..8e63c970e1e 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -68,45 +68,42 @@ class InterpreterCore { void reset_scope(Scope* new_scope); private: - bool BuildInplaceCheckVarIsOnlyInput( - const std::vector>& input_var2op, size_t var_index); - - std::shared_ptr GetWorkQueue(); - + // build graph + void Convert(std::vector* op_func_nodes); + void BuildOperatorDependences(); void BuildAndCacheInstructionCtx(Instruction* instr_node); + void BuildSkipShareLoDInfo(); + // inplace void BuildInplace(); + bool BuildInplaceCheckVarIsOnlyInput( + const std::vector>& input_var2op, size_t var_index); + void SetFeedVarsInplaceSkip(const std::vector& feed_names); - void BuildOperatorDependences(); - - void ClearLoDTensorArrayInLocalScope(); - - void Convert(std::vector* op_func_nodes); - - void RunInstruction(const Instruction& instr_node); - + // execution void ExecuteInstructionList(const std::vector& vec_instr); - + void RunInstructionAsync(size_t instr_id); + void RunInstruction(const Instruction& instr_node); + void RunNextInstructions(const Instruction& instr_id, + std::queue* reserved_next_ops); + // only used when program contains no feed op void Prepare(const std::vector& feed_names, const std::vector& feed_tensors, bool prepare_feed); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + // gc void RecordStreamForGC(const Instruction& instr); -#endif - void CheckGC(const Instruction& instr); + void ClearLoDTensorArrayInLocalScope(); - void RunInstructionAsync(size_t instr_id); - void RunNextInstructions(const Instruction& instr_id, - std::queue* reserved_next_ops); - - void BuildSkipShareLoDInfo(); + // workqueue + std::shared_ptr GetWorkQueue(); - void SetFeedVarsInplaceSkip(const std::vector& feed_names); + // scope + bool HasLocalScope() const; private: - bool is_build_; + bool is_build_{false}; platform::Place place_; const BlockDesc& block_; // not owned @@ -127,11 +124,7 @@ class InterpreterCore { std::vector vec_instruction_; // deconstruct before OpFuncNode - // last_live_ops_[i] contains the id of operators that last access var[i] - std::map> last_live_ops_; - - std::vector dependecy_count_; - std::atomic unfinished_op_numer_{0}; + std::atomic unfinished_op_number_{0}; VariableScope var_scope_; Scope* local_scope_{nullptr}; // not owned @@ -145,8 +138,13 @@ class InterpreterCore { std::unique_ptr gc_; - std::future> atomic_deps_; - std::future> atomic_var_ref_; + // last_live_ops_[i] contains the id of operators that last access the i-th + // var + std::map> last_live_ops_; + + // dependecy_count_[i] contains the number of dependencies that the i-th op + // need to wait + std::vector dependecy_count_; std::vector> deps_; std::vector> refs_; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 1d70bb6039b..273a7ee8bc4 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -122,8 +122,8 @@ bool var_can_be_deleted(const std::string& name, const BlockDesc& block) { std::unordered_map> -get_unused_vars(const BlockDesc& block, - const std::vector>& ops) { +GetUnusedVars(const BlockDesc& block, + const std::vector>& ops) { std::unordered_map var_op_idx_map; for (size_t i = 0; i < ops.size(); ++i) { @@ -166,17 +166,17 @@ get_unused_vars(const BlockDesc& block, for (auto& name_op_idx_pair : var_op_idx_map) { auto& name = name_op_idx_pair.first; size_t op_idx = name_op_idx_pair.second; - - result[ops[op_idx].get()].emplace_back(name); - VLOG(4) << ops[op_idx].get()->Type() << " " << name; + auto op = ops[op_idx].get(); + result[op].emplace_back(name); + VLOG(4) << op->Type() << " " << name; } VLOG(4) << "gc map size:" << result.size(); return result; } -void build_variable_scope(const framework::BlockDesc& block, - VariableScope* var_scope, - bool use_local_scope) { +void BuildVariableScope(const framework::BlockDesc& block, + VariableScope* var_scope, + bool use_local_scope) { VLOG(3) << "Creating Variables"; auto inner_scope = var_scope->GetMutableScope(); @@ -214,8 +214,8 @@ void build_variable_scope(const framework::BlockDesc& block, } } -void create_all_ops(const framework::BlockDesc& block, - std::vector>* ops) { +void CreateAllOps(const framework::BlockDesc& block, + std::vector>* ops) { for (auto& op : block.AllOps()) { auto op_type = op->Type(); VLOG(8) << "CreateOp from : " << op_type; @@ -289,9 +289,9 @@ std::tuple BuildVariableMap( return std::make_tuple(name2var, name2id); } -void apply_device_guard(const OperatorBase* op_base, - const platform::Place& place, - OpKernelType* expected_kernel_key) { +void ApplyDeviceGuard(const OperatorBase* op_base, + const platform::Place& place, + OpKernelType* expected_kernel_key) { bool need_change_place = (op_base->HasAttr("op_device") && (op_base->Attr("op_device").length() > 0)); @@ -352,7 +352,7 @@ void apply_device_guard(const OperatorBase* op_base, } } -void deal_operator_base(const platform::Place& place, +void HandleOperatorBase(const platform::Place& place, const VariableScope* var_scope, std::shared_ptr op_base, OpFuncNode* op_func_node, @@ -361,7 +361,7 @@ void deal_operator_base(const platform::Place& place, auto* dev_ctx = pool.Get(place); // input, output is prepared. set the other attributes. op_func_node->operator_base_ = op_base; - if (IsSupportedHetePlace(place)) { + if (IsSupportedHeterPlace(place)) { op_func_node->type_ = OpFuncType::kQueueAsync; } else if (platform::is_cpu_place(place)) { op_func_node->type_ = OpFuncType::kQueueSync; @@ -382,19 +382,19 @@ void deal_operator_base(const platform::Place& place, op_func_node->dev_ctx_ = dev_ctx; } -void build_op_func_list(const platform::Place& place, - const framework::BlockDesc& block, - const std::set& skip_gc_vars, - std::vector* vec_func_list, - VariableScope* var_scope, - bool use_local_scope, - bool used_for_jit) { +void BuildOpFuncList(const platform::Place& place, + const framework::BlockDesc& block, + const std::set& skip_gc_vars, + std::vector* vec_func_list, + VariableScope* var_scope, + bool use_local_scope, + bool used_for_jit) { Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() : var_scope->GetMutableScope(); std::vector> ops_unique; // its elements will be moved to vec_func_list // Step 1: create all ops for current block. - create_all_ops(block, &ops_unique); + CreateAllOps(block, &ops_unique); if (!used_for_jit) { // If gc is enabled and block size > 1 @@ -415,7 +415,7 @@ void build_op_func_list(const platform::Place& place, for (auto& op_unique : ops_unique) { ops.emplace_back(std::move(op_unique)); } - auto unused_var_map = get_unused_vars(block, ops); + auto unused_var_map = GetUnusedVars(block, ops); bool flag_log_is_printed = false; for (size_t i = 0; i < ops.size(); ++i) { @@ -485,10 +485,10 @@ void build_op_func_list(const platform::Place& place, try { if (dynamic_cast(op) == nullptr) { + VLOG(4) << "HandleOperatorBase"; // op is not a operatorwithkernel, so direcly run OperatorBase::Run() - deal_operator_base( + HandleOperatorBase( place, var_scope, ops[i], &op_func_node, local_scope); - VLOG(4) << "deal_operator_base"; } else { VLOG(4) << "OP is not null"; auto op_with_kernel = const_cast( @@ -522,7 +522,7 @@ void build_op_func_list(const platform::Place& place, op_with_kernel->GetExpectedKernelType(exec_ctx); VLOG(4) << "get expected_kernel_key"; // change device by the device_guard() - apply_device_guard(op, place, &expected_kernel_key); + ApplyDeviceGuard(op, place, &expected_kernel_key); VLOG(4) << "expected_kernel_key : " << expected_kernel_key; // step 2. select op kernel @@ -565,7 +565,7 @@ void build_op_func_list(const platform::Place& place, dev_ctx = pool.Get(kernel_type.place_); } op_func_node.dev_ctx_ = dev_ctx; - if (IsSupportedHetePlace(kernel_type.place_)) { + if (IsSupportedHeterPlace(kernel_type.place_)) { op_func_node.type_ = OpFuncType::kQueueAsync; } else if (platform::is_cpu_place(kernel_type.place_)) { op_func_node.type_ = OpFuncType::kQueueSync; @@ -667,7 +667,7 @@ void build_op_func_list(const platform::Place& place, vec_func_list->emplace_back(op_func_node); - // gc--------------------------------------------------------------------------- + // gc--------------------------------------------- auto iter = unused_var_map.find(op); if (iter == unused_var_map.end()) { interpreter::LogDeviceMemoryStats(place); @@ -702,8 +702,8 @@ void build_op_func_list(const platform::Place& place, memory::Release(place); } -void add_fetch(const std::vector& fetch_names, - framework::BlockDesc* block) { +void AddFetch(const std::vector& fetch_names, + framework::BlockDesc* block) { auto* fetch_holder = block->Var(kFetchVarName); fetch_holder->SetType(proto::VarType::FETCH_LIST); fetch_holder->SetPersistable(true); @@ -721,20 +721,6 @@ void add_fetch(const std::vector& fetch_names, } } -std::vector merge_vector(const std::vector& first, - const std::vector& second) { - std::vector out(first.size() + second.size()); - std::merge( - first.begin(), first.end(), second.begin(), second.end(), out.begin()); - - std::vector::iterator it; - it = std::unique(out.begin(), out.end()); - - out.resize(std::distance(out.begin(), it)); - - return out; -} - } // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index 438ae5eebf6..3e96262407f 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -66,23 +66,20 @@ class AsyncWorkQueue { void LogDeviceMemoryStats(const platform::Place& place); -void build_variable_scope(const framework::BlockDesc& block, - VariableScope* var_scope, - bool use_local_scope = true); - -void build_op_func_list(const platform::Place& place, - const framework::BlockDesc& block, - const std::set& skip_gc_vars, - std::vector* vec_func_list, - VariableScope* scope, - bool use_local_scope = true, - bool used_for_jit = false); - -void add_fetch(const std::vector& fetch_names, - framework::BlockDesc* block); - -std::vector merge_vector(const std::vector& first, - const std::vector& second); +void BuildVariableScope(const framework::BlockDesc& block, + VariableScope* var_scope, + bool use_local_scope = true); + +void BuildOpFuncList(const platform::Place& place, + const framework::BlockDesc& block, + const std::set& skip_gc_vars, + std::vector* vec_func_list, + VariableScope* scope, + bool use_local_scope = true, + bool used_for_jit = false); + +void AddFetch(const std::vector& fetch_names, + framework::BlockDesc* block); } // namespace interpreter } // namespace framework diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index ed534c8c6c4..7fced920a77 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -392,7 +392,7 @@ static bool IsCpuOp(const Instruction& instr) { } // is supported heterogeneous place -static bool IsSupportedHetePlace(const phi::Place& place) { +static bool IsSupportedHeterPlace(const phi::Place& place) { return platform::is_gpu_place(place) || platform::is_npu_place(place) || platform::is_xpu_place(place) || platform::is_ipu_place(place) || platform::is_custom_place(place); -- GitLab