diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index 18fba0d19bbdc3c9d07e4c1a8002a010177dc123..3f4d9f6ca4297fe1d21403ec1220a1d6df544264 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -23,6 +23,7 @@ namespace details {
 
 inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
                                     Scope *scope) {
+  VLOG(3) << "NewTempScopeAndInitVars";
   Scope &local_scope = scope->NewScope();
   *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
       &local_scope;
@@ -43,12 +44,15 @@ inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
 // get RpcContext and remote send and recv op
 void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
   using RpcCtxMap = operators::distributed::RpcCtxMap;
+  VLOG(3) << "ProcessGraph";
   RpcCtxMap send_varname_to_ctx;
   RpcCtxMap recv_varname_to_ctx;
   for (auto i = 0; i < graphs.size(); ++i) {
     for (auto &node : graphs[i]->Nodes()) {
-      if (node->IsOp()) {
-        if (node->Op()->Type() == "send") {
+      VLOG(3) << "node name " << node->Name();
+      std::vector<ir::Node *> nodes_to_delete;
+      if (node && node->IsOp()) {
+        if (node->Name() == "send") {
           auto send_var_name = node->Op()->Input("X")[0];
           auto send_varnames = boost::get<std::vector<std::string>>(
               node->Op()->GetNullableAttr("send_varnames"));
@@ -61,8 +65,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
                                                  epmap, height_section);
           VLOG(3) << "find and init an send op: "
                   << send_varname_to_ctx[send_var_name];
-        } else if (node->Op()->Type() == "recv") {
-          auto recv_var_name = node->Op()->Input("X")[0];
+        } else if (node->Name() == "recv") {
+          auto recv_var_name = node->Op()->Output("Out")[0];
           auto recv_varnames = boost::get<std::vector<std::string>>(
               node->Op()->GetNullableAttr("recv_varnames"));
           auto epmap = boost::get<std::vector<std::string>>(
@@ -70,18 +74,23 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
           recv_varname_to_ctx[recv_var_name] =
               operators::distributed::RpcContext(recv_var_name, recv_varnames,
                                                  epmap, {});
-          graphs[i]->RemoveNode(node);
+          nodes_to_delete.push_back(node);
           VLOG(3) << "find and remove an recv op: "
                   << recv_varname_to_ctx[recv_var_name];
         }
+        VLOG(3) << "delete all recv ops";
+        for (auto *node : nodes_to_delete) {
+          graphs[i]->RemoveNode(node);
+        }
       }
     }
   }
 
   // init communicator here
   if (send_varname_to_ctx.size() > 0) {
-    VLOG(3) << "this is distribute mode, will use ";
+    VLOG(3) << "this is distribute mode, will use communicator";
     operators::distributed::Communicator::Init(send_varname_to_ctx,
                                                recv_varname_to_ctx, scope);
+    operators::distributed::Communicator::GetInstance()->Start();
   }
 }
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 6c5f246f95b97715989fea2b838d6a23c9c3bbea..6c710abd7a7e79e06f76baf7e193f4cad6158362 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -277,7 +277,7 @@ ParallelExecutor::ParallelExecutor(
   // ncclOp
   std::vector<ir::Graph *> async_graphs(places.size());
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
+  if (build_strategy.async_mode_) {
     VLOG(3) << "use local async mode";
     temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph),
                                             {member_->places_[0]},
@@ -298,7 +298,7 @@ ParallelExecutor::ParallelExecutor(
                              member_->nccl_ctxs_.get());
   }
 #else
-  if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
+  if (build_strategy.async_mode_) {
     VLOG(3) << "use local async mode";
mode"; temp_owned_graph = build_strategy.Apply( std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name, @@ -358,7 +358,7 @@ ParallelExecutor::ParallelExecutor( } } - if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { + if (build_strategy.async_mode_) { VLOG(3) << "use AsyncSSAGraphExecutor"; member_->executor_.reset(new details::AsyncSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->places_, async_graphs)); diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index e800cd5f417028f05e68bfc39d39f05223704ba4..b2bb8fb403056d4b496c68372d12372ec087c987 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -14,6 +14,9 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/communicator.h" +#include // NOLINT +#include // NOLINT + #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable_helper.h" @@ -28,6 +31,7 @@ namespace distributed { static inline void MergeVars(const std::string &var_name, const std::vector> &vars, Scope *scope) { + VLOG(3) << "merge " << vars.size() << " vars " << var_name << " to one"; PADDLE_ENFORCE(!vars.empty(), "should have value to merge!"); auto cpu_place = platform::CPUPlace(); auto &var0 = vars[0]; @@ -67,29 +71,32 @@ std::unique_ptr Communicator::communicator_(nullptr); std::once_flag Communicator::init_flag_; void Communicator::SendThread() { + VLOG("SendThread start!"); while (running_) { std::vector> task_futures; task_futures.reserve(send_varname_to_ctx_.size()); for (auto &iter : send_varname_to_queue_) { - auto send_task = [this, &iter] { - auto &var_name = iter.first; - VLOG(3) << "merge var " << var_name << " and send"; - auto &var_queue = iter.second; - std::vector> vars; - // TODO(qiao): need to be configurable - const size_t max_merge_var_num = 20; - size_t merged_var_num = 0; - while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) { - vars.push_back(var_queue->Pop()); - merged_var_num++; - } - MergeVars(var_name, vars, send_scope_.get()); - auto send_functor = distributed::ParameterSend(); - auto &ctx = send_varname_to_ctx_.at(var_name); - send_functor(ctx, *send_scope_, true); - }; - task_futures.emplace_back( - send_threadpool_->enqueue(std::move(send_task))); + auto &var_name = iter.first; + auto &var_queue = iter.second; + if (var_queue->NotEmpty()) { // will block if queue is empty + auto send_task = [this, &var_name, &var_queue] { + VLOG(3) << "merge var " << var_name << " and send"; + std::vector> vars; + // TODO(qiao): need to be configurable + const size_t max_merge_var_num = 20; + size_t merged_var_num = 0; + while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) { + vars.push_back(var_queue->Pop()); + merged_var_num++; + } + MergeVars(var_name, vars, send_scope_.get()); + auto send_functor = distributed::ParameterSend(); + auto &ctx = send_varname_to_ctx_.at(var_name); + send_functor(ctx, *send_scope_, true); + }; + task_futures.emplace_back( + send_threadpool_->enqueue(std::move(send_task))); + } } for (auto &task_f : task_futures) { task_f.wait(); @@ -98,6 +105,7 @@ void Communicator::SendThread() { } void Communicator::RecvThread() { + VLOG(3) << "RecvThread start!"; while (running_) { // parallel run recv graph std::vector> task_futures; @@ -115,6 +123,8 @@ void Communicator::RecvThread() { for (auto &task : task_futures) { 
       task.wait();
     }
+    // TODO(qiao) need to be configurable
+    std::this_thread::sleep_for(std::chrono::milliseconds(200));
   }
 }
 
diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h
index bc753bb75ef110ca6142c63f0fb7dae8ebd1b62e..c93ad02555e6c8cd90d6030a46aac430354522b1 100644
--- a/paddle/fluid/operators/distributed/communicator.h
+++ b/paddle/fluid/operators/distributed/communicator.h
@@ -68,6 +68,12 @@ class BlockingQueue {
     return rc;
   }
 
+  bool NotEmpty() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    recv_cv_.wait(lock, [=] { return !queue_.empty(); });
+    return true;
+  }
+
   size_t Cap() const {
     std::lock_guard<std::mutex> lock(mutex_);
     return capacity_;
diff --git a/paddle/fluid/operators/distributed/variable_response.h b/paddle/fluid/operators/distributed/variable_response.h
index 3ecb6960690cbd4679af9ad310f27a878ddb7a23..edc12e2091f851d0f7817f078712d58d7ff1e9b8 100644
--- a/paddle/fluid/operators/distributed/variable_response.h
+++ b/paddle/fluid/operators/distributed/variable_response.h
@@ -60,12 +60,14 @@ class VariableResponse {
                    bool create_scope = false)
       : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
     if (create_scope) {
-      local_scope_ = scope->NewTmpScope();
+      local_scope_ = &scope->NewScope();
     }
   }
 
   virtual ~VariableResponse() {
-    if (local_scope_) delete local_scope_;
+    if (local_scope_) {
+      scope_->DeleteScope(local_scope_);
+    }
   }
 
   int Parse(Source* source, const sendrecv::VariableMessage& meta) {
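For reference, below is a minimal, self-contained sketch (not PaddlePaddle code; class and variable names are illustrative) of the blocking-wait pattern that the new BlockingQueue::NotEmpty() enables in SendThread: the consumer blocks on a condition variable until the queue has work, then drains up to a fixed merge limit, mirroring the max_merge_var_num = 20 bound in the patch above.

// Illustration only: mirrors the NotEmpty()/bounded-merge pattern used by the
// reworked SendThread; all names here are hypothetical.
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <iostream>
#include <mutex>
#include <thread>

template <typename T>
class BlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push_back(std::move(v));
    }
    cv_.notify_one();
  }

  // Blocks until at least one element is queued, like NotEmpty() in the patch.
  bool NotEmpty() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    return true;
  }

  size_t Size() {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

  T Pop() {
    std::lock_guard<std::mutex> lock(mutex_);
    T v = std::move(queue_.front());
    queue_.pop_front();
    return v;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};

int main() {
  BlockingQueue<int> grad_queue;
  std::thread producer([&] {
    for (int i = 0; i < 50; ++i) grad_queue.Push(i);
  });

  const size_t max_merge_var_num = 20;  // same bound as SendThread uses
  size_t received = 0;
  while (received < 50) {
    grad_queue.NotEmpty();  // block until there is something to merge
    size_t merged_var_num = 0;
    while (grad_queue.Size() > 0 && merged_var_num < max_merge_var_num) {
      grad_queue.Pop();
      ++merged_var_num;
      ++received;
    }
    std::cout << "merged " << merged_var_num << " pending vars" << std::endl;
  }
  producer.join();
  return 0;
}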