From 5d5cb256c1392c955bde74b8833ae6a5b088daea Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Thu, 9 Feb 2023 14:28:19 +0800 Subject: [PATCH] fix gc bug and start interceptor (#50344) --- .../fleet_executor/fleet_executor.cc | 22 ++++++++++--------- .../fleet_executor/start_interceptor.cc | 5 +++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index ec95f2a146..915b1f8280 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -111,7 +111,15 @@ void PreventVarsDelete( std::vector GetUnusedVarsAfterWhile( const framework::ProgramDesc& program_desc, TaskNode* cond_task, - const std::vector vars_not_gc) { + const std::vector& vars_not_gc) { + // NOTE: Since while op won't appear in task node, in order to analyze + // the vars which should be free after calling while op, we rebuild the + // whole program and get the unused vars after calling while op. + // The vars in while block should not be free until the while op is finished. + // In a word, the vars need to be free after while op is: + // 1. Vars in parent block and being used in while block. + // 2. Local vars only defined in while block. + // The unused vars above will be free in cond interceptor. std::vector while_block_vars; std::vector> ops; for (const auto& desc : program_desc.Block(0).AllOps()) { @@ -124,6 +132,9 @@ std::vector GetUnusedVarsAfterWhile( for (const auto& var_name : pair.second) { while_block_vars.emplace_back(var_name); } + for (auto& var : program_desc.Block(1).AllVars()) { + while_block_vars.emplace_back(var->Name()); + } } } return while_block_vars; @@ -178,13 +189,6 @@ void FleetExecutor::Init( auto global_unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); - // Analyse the unused vars in block 1. - std::unordered_map> - sub_unused_vars; - if (program_desc.Size() > 1) { - sub_unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {}); - PreventVarsDelete(&sub_unused_vars, while_block_vars); - } for (auto& unique_op : ops) { unique_op.release(); } @@ -199,8 +203,6 @@ void FleetExecutor::Init( for (auto task_node : task_nodes) { if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) { task_node->SetUnusedVars(global_unused_vars); - } else { - task_node->SetUnusedVars(sub_unused_vars); } int64_t interceptor_id = task_node->task_id(); interceptor_id_to_task.emplace(interceptor_id, task_node); diff --git a/paddle/fluid/distributed/fleet_executor/start_interceptor.cc b/paddle/fluid/distributed/fleet_executor/start_interceptor.cc index b5f3bcb240..b9ce4fabed 100644 --- a/paddle/fluid/distributed/fleet_executor/start_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/start_interceptor.cc @@ -68,14 +68,15 @@ void StartInterceptor::SendDataReadyToDownStream() { } if (finish_count_ == batch_size_) { for (int64_t i = 0; i < batch_size_; ++i) { + int64_t scope_id = step_ % node_->max_run_times(); for (auto& outs : out_buffs_) { auto down_id = outs.first; InterceptorMessage ready_msg; ready_msg.set_message_type(DATA_IS_READY); - ready_msg.set_scope_idx(step_); + ready_msg.set_scope_idx(scope_id); VLOG(3) << "StartInterceptor " << interceptor_id_ << " Send data_is_ready msg to " << down_id - << " in scope: " << step_; + << " in scope: " << scope_id; Send(down_id, ready_msg); } step_++; -- GitLab