From fbf9564f6bebd3a2c181f6525f11e6bfec79fcbe Mon Sep 17 00:00:00 2001 From: 123malin Date: Tue, 24 Nov 2020 09:24:34 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90paddle.distributed.fleet=E3=80=91Optim?= =?UTF-8?q?ize=20ParameterServer's=20Async=20Mode=20(#28442)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test=develop, optimize global_step --- .../operators/distributed/communicator.cc | 184 +++++++++++++----- .../operators/distributed/communicator.h | 12 +- .../fleet/runtime/parameter_server_runtime.py | 1 + 3 files changed, 138 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 8fa6673e2a2..07427bb69d9 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -65,6 +65,7 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, } else { send_scope_.reset(new Scope()); for (auto &iter : send_varname_to_ctx_) { + if (iter.first == STEP_COUNTER && !need_global_step_) continue; send_varname_to_queue_[iter.first] = std::make_shared>>( send_queue_size_); @@ -108,21 +109,87 @@ void AsyncCommunicator::SendGlobalStep(int batches) { send_functor(ctx, *send_scope_, true, 1); } -void AsyncCommunicator::SendByCommunicator(int batches) { +void AsyncCommunicator::SendByCommunicator() { std::vector> task_futures; task_futures.reserve(send_varname_to_ctx_.size()); VLOG(3) << "run send graph"; + auto before_run_send_graph = GetCurrentUS(); for (auto &iter : send_varname_to_queue_) { auto &var_name = iter.first; auto &var_queue = iter.second; - auto send_task = [this, batches, &var_name, &var_queue] { + auto send_task = [this, &var_name, &var_queue] { + VLOG(3) << var_name << " merge and send; "; + std::vector> vars; + + int merged_var_num = 0; + int wait_times = 0; + while (merged_var_num < max_merge_var_num_) { + if (var_queue->Size() == 0) { + VLOG(4) << "wait_times -> " << wait_times; + if (wait_times >= send_wait_times_) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + wait_times++; + continue; + } else { + wait_times = 0; + + vars.push_back(var_queue->Pop()); + merged_var_num++; + } + } + auto before_merge = GetCurrentUS(); if (var_name == STEP_COUNTER) { + SendGlobalStep(merged_var_num); + auto after_merge = GetCurrentUS(); + VLOG(3) << "merge and send " << merged_var_num << " " << var_name + << " use time " << after_merge - before_merge; return; } - VLOG(3) << var_name << " merge and send"; + auto &ctx = send_varname_to_ctx_.at(var_name); + + MergeVars(var_name, vars, send_scope_.get(), ctx.merge_add); + auto after_merge = GetCurrentUS(); + VLOG(3) << "merge " << merged_var_num << " " << var_name << " use time " + << after_merge - before_merge; + + auto send_functor = distributed::ParameterSend(); + send_functor(ctx, *send_scope_, true, 1); + auto after_send = GetCurrentUS(); + VLOG(3) << "send " << var_name << " use time " + << after_send - after_merge; + }; + task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task))); + } + for (auto &task_f : task_futures) { + task_f.wait(); + } + auto after_run_send_graph = GetCurrentUS(); + + VLOG(3) << "run send graph use time " + << (after_run_send_graph - before_run_send_graph); +} + +void HalfAsyncCommunicator::SendByCommunicator() { + std::vector> task_futures; + task_futures.reserve(send_varname_to_ctx_.size()); + VLOG(3) << "run send graph"; + + int batches = BatchesCounter(); + if (batches <= 0) return; + + auto before_run_send_graph = GetCurrentUS(); + for (auto &iter : send_varname_to_queue_) { + auto &var_name = iter.first; + auto &var_queue = iter.second; + + auto send_task = [this, batches, &var_name, &var_queue] { + VLOG(3) << var_name << " merge and send; "; + auto before_task = GetCurrentUS(); std::vector> vars; vars.reserve(batches); @@ -130,6 +197,14 @@ void AsyncCommunicator::SendByCommunicator(int batches) { vars.push_back(var_queue->Pop()); } + if (var_name == STEP_COUNTER) { + SendGlobalStep(batches); + auto end_task = GetCurrentUS(); + VLOG(3) << "merge " << batches << " " << var_name << " use time " + << end_task - before_task; + return; + } + auto &ctx = send_varname_to_ctx_.at(var_name); auto before_merge = GetCurrentUS(); @@ -142,7 +217,20 @@ void AsyncCommunicator::SendByCommunicator(int batches) { send_functor(ctx, *send_scope_, true, 1); auto after_send = GetCurrentUS(); VLOG(3) << "send " << var_name << " use time " - << after_send - after_merge; + << after_send - before_task; + + if (var_name.rfind("@GRAD") != var_name.size() - 5) return; + + auto recv_param = var_name.substr(0, var_name.size() - 5); + if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end()) + return; + + auto recv_functor = distributed::ParameterRecv(); + recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_); + auto after_recv = GetCurrentUS(); + VLOG(3) << "recv " << recv_param << " use time " + << after_recv - after_send; + return; }; task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task))); } @@ -152,7 +240,7 @@ void AsyncCommunicator::SendByCommunicator(int batches) { auto after_run_send_graph = GetCurrentUS(); VLOG(3) << "run send graph use time " - << after_run_send_graph - before_run_send_graph; + << (after_run_send_graph - before_run_send_graph); } void AsyncCommunicator::MainThread() { @@ -164,20 +252,28 @@ void AsyncCommunicator::MainThread() { } while (running_) { - int batches = BatchesCounter(); - - if (batches > 0) { - SendGlobalStep(batches); - SendByCommunicator(batches); - BarrierSend(); - RecvByCommunicator(); - BarrierRecv(); - BarrierWeakUp(); - } else { - VLOG(1) << "get nothing from sending queue, will skip send/recv"; - } + SendByCommunicator(); + BarrierSend(); } - VLOG(1) << "communicator stopped, send thread exit"; + VLOG(3) << "communicator stopped, send thread exit"; +} + +void HalfAsyncCommunicator::MainThread() { + VLOG(3) << "MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + SendByCommunicator(); + BarrierSend(); + RecvByCommunicator(); + BarrierRecv(); + BarrierWeakUp(); + } + VLOG(3) << "communicator stopped, send thread exit"; } void AsyncCommunicator::RecvByCommunicator() { @@ -193,10 +289,13 @@ void AsyncCommunicator::RecvNoBarrier() { for (auto &iter : recv_varname_to_ctx_) { auto recv_task = [this, &iter] { + auto before_task = GetCurrentUS(); auto &var_name = iter.first; - VLOG(4) << "recv var " << var_name; auto recv_functor = distributed::ParameterRecv(); recv_functor(iter.second, *recv_scope_); + auto end_task = GetCurrentUS(); + VLOG(1) << "recv var " << var_name << " use time " + << (end_task - before_task); }; task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task))); } @@ -206,37 +305,12 @@ void AsyncCommunicator::RecvNoBarrier() { } } -int AsyncCommunicator::BatchesCounter() { - auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER); - - size_t merged_var_num = 0; - size_t wait_times = 0; - - while (merged_var_num < static_cast(max_merge_var_num_)) { - if (step_queue->Size() == 0) { - VLOG(3) << "wait_times -> " << wait_times; - if (wait_times >= static_cast(send_wait_times_)) { - break; - } - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - wait_times++; - continue; - } else { - step_queue->Pop(); - wait_times = 0; - merged_var_num++; - } - } - - return merged_var_num; -} - void AsyncCommunicator::Start() { - VLOG(1) << "Communicator start"; + VLOG(3) << "Communicator start"; if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; } else { - VLOG(1) << "start send thread and recv thread"; + VLOG(3) << "start send thread and recv thread"; waiting_ = true; running_ = true; BarrierTriggerReset(max_merge_var_num_); @@ -247,18 +321,18 @@ void AsyncCommunicator::Start() { } void AsyncCommunicator::Stop() { - VLOG(1) << "Communicator stop"; + VLOG(3) << "Communicator stop"; running_ = false; if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; } else { if (main_thread_) { - VLOG(1) << "stop send thread"; + VLOG(3) << "stop send thread"; main_thread_->join(); main_thread_.reset(nullptr); } } - VLOG(1) << "Communicator stop done"; + VLOG(3) << "Communicator stop done"; } void AsyncCommunicator::Send(const std::vector &var_names, @@ -271,6 +345,10 @@ void AsyncCommunicator::Send(const std::vector &var_names, platform::errors::InvalidArgument("var_tables.size() == 1 is permitted")); auto table_name = var_tables[0]; + + if (table_name == STEP_COUNTER && !need_global_step_) return; + + auto before_send_op = GetCurrentUS(); auto &queue = send_varname_to_queue_.at(table_name); if (table_name == STEP_COUNTER) { @@ -279,7 +357,6 @@ void AsyncCommunicator::Send(const std::vector &var_names, tensor->Resize(framework::make_ddim({1})); auto *out_d = tensor->mutable_data(platform::CPUPlace()); out_d[0] = 1; - VLOG(3) << "send to " << table_name << " with queue size " << queue->Size(); queue->Push(tmp_var); } else { PADDLE_ENFORCE_GE(var_names.size(), 1, @@ -295,21 +372,20 @@ void AsyncCommunicator::Send(const std::vector &var_names, auto tmp_var = std::make_shared(); if (var->IsType()) { framework::CopyVariable(*var, tmp_var.get()); - VLOG(3) << "send to " << table_name << " with queue size " - << queue->Size(); queue->Push(tmp_var); } else if (var->IsType()) { // push var into send queue by var_name auto var_name = var_names[0]; framework::CopyVariable(*var, tmp_var.get()); - VLOG(3) << "send to " << table_name << " with queue size " - << queue->Size(); queue->Push(tmp_var); } else { PADDLE_THROW(platform::errors::InvalidArgument( "unknown var type to copy, only support LoDTensor/SelectedRows")); } } + auto after_send_op = GetCurrentUS(); + VLOG(3) << "send to " << table_name << " with queue size " << queue->Size() + << ", use time " << (after_send_op - before_send_op); } void HalfAsyncCommunicator::Clean() { diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 7c4910421f8..4be3253d392 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -302,16 +302,13 @@ class AsyncCommunicator : public Communicator { const std::vector &var_tables, const framework::Scope &scope) override; - virtual void SendByCommunicator(int batches); - + virtual void SendByCommunicator(); virtual void SendGlobalStep(int batches); virtual void RecvByCommunicator(); virtual void RecvNoBarrier(); - virtual int BatchesCounter(); - virtual void BarrierSend() {} virtual void BarrierRecv() {} @@ -359,6 +356,10 @@ class HalfAsyncCommunicator : public AsyncCommunicator { VLOG(0) << "HalfAsyncCommunicator Initialized"; } + void MainThread() override; + + void SendByCommunicator() override; + void Clean() override; void Barrier() override; @@ -438,7 +439,7 @@ class GeoCommunicator : public AsyncCommunicator { const std::vector &var_tables, const framework::Scope &scope) override; - void SendByCommunicator(int batches) { return; } + void SendByCommunicator() { return; } std::vector MergeSparseIds(const std::string &send_varname); @@ -475,6 +476,7 @@ class GeoCommunicator : public AsyncCommunicator { std::shared_ptr pserver_scope_; int send_var_nums_ = 0; + std::unordered_map> old_sparses_; std::unordered_map< diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py index 887209d9de2..782ba87e079 100644 --- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py +++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py @@ -207,6 +207,7 @@ class ParameterServerRuntime(RuntimeBase): SyncStrategy, GeoStrategy trainer_config = self.async_strategy.get_trainer_runtime_config() + print(trainer_config) dist_strategy = self.context["valid_strategy"] launch_barrier = dist_strategy.a_sync_configs["launch_barrier"] -- GitLab