From a87a958b733d547244da334ca1590ac71123a90d Mon Sep 17 00:00:00 2001
From: malin10
Date: Fri, 18 Sep 2020 16:21:07 +0800
Subject: [PATCH] test=develop, bug fix

---
 .../operators/distributed/communicator.cc | 153 +++++++++++-------
 .../operators/distributed/communicator.h  |  15 +-
 2 files changed, 109 insertions(+), 59 deletions(-)

diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc
index b2cc9390fa2..a398fa8d126 100644
--- a/paddle/fluid/operators/distributed/communicator.cc
+++ b/paddle/fluid/operators/distributed/communicator.cc
@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/distributed/communicator.h"
+
 #include <gflags/gflags.h>
 #include <paddle/fluid/framework/program_desc.h>
+
 #include <algorithm>
 #include <chrono>  // NOLINT
 #include <map>
 #include <thread>  // NOLINT
 #include <unordered_set>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor_util.h"
@@ -374,8 +377,9 @@ void SyncCommunicator::BarrierSend() {
   }
 
   for (size_t i = 0; i < rets.size(); i++) {
-    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
-                                               "internal error in RPCClient"));
+    PADDLE_ENFORCE_NE(
+        rets[i]->Wait(), 0U,
+        platform::errors::External("internal error in RPCClient"));
   }
 
   VLOG(4) << "BarrierSend with SyncCommunicator";
@@ -393,8 +397,9 @@ void SyncCommunicator::BarrierRecv() {
   }
 
   for (size_t i = 0; i < rets.size(); i++) {
-    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
-                                               "internal error in RPCClient"));
+    PADDLE_ENFORCE_NE(
+        rets[i]->Wait(), 0U,
+        platform::errors::External("internal error in RPCClient"));
   }
 
   VLOG(4) << "BarrierRecv with SyncCommunicator";
@@ -484,13 +489,36 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
         "Only LodTensor can be send in GeoCommunicator::Send"));
   }
 
-  std::vector<int64_t> ids;
   auto &rows = var->Get<framework::SelectedRows>().rows();
-  ids.assign(rows.begin(), rows.end());
+  auto pserver_num = send_varname_to_ctx_.at(table_name).epmap.size();
+  auto ids = std::make_shared<SplitedSparseIds>(pserver_num);
+  // split rows index into output sparse vars
+  for (size_t i = 0; i < rows.size(); ++i) {
+    auto ep_idx = rows[i] % pserver_num;
+    (*ids)[ep_idx].insert(rows[i]);
+  }
   queue->Push(ids);
 }
 
+void GeoCommunicator::MainThread() {
+  VLOG(3) << "MainThread start and wait";
+
+  while (waiting_ && running_) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    VLOG(3) << "wait for running";
+  }
+
+  while (running_) {
+    int meet = Meet();
+
+    VLOG(1) << "async_meet: " << meet;
+
+    SendGlobalStep(meet);
+    SendByCommunicator(meet);
+  }
+  VLOG(1) << "geo-communicator stopped, send thread exit";
+}
+
 void GeoCommunicator::SendByCommunicator(int batches) {
   std::vector<std::future<void>> tasks;
   tasks.reserve(send_varname_to_ctx_.size());
 
@@ -498,21 +526,39 @@ void GeoCommunicator::SendByCommunicator(int batches) {
   for (auto &iter : send_varname_to_ctx_) {
     auto &var_name = iter.first;
     auto &send_ctx = iter.second;
+    auto pserver_num = send_ctx.epmap.size();
 
-    auto send_task = [this, batches, &var_name, &send_ctx] {
-      if (var_name == STEP_COUNTER) {
-        return;
-      }
+    if (send_ctx.is_sparse) {
+      if (var_name == STEP_COUNTER) {
+        continue;
+      }
+      auto &ids_queue = send_ids_to_queue_.at(var_name);
+      splited_ids_vec_.clear();
+      for (int i = 0; i < batches; ++i) {
+        splited_ids_vec_.push_back(*(ids_queue->Pop()));
+      }
 
-      if (send_ctx.is_sparse) {
-        SendSparse(var_name, batches);
-      } else {
+      for (size_t ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
+        auto send_recv_task = [this, ep_idx, &var_name] {
+          SendSparse(var_name, ep_idx);
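+          // pull the freshly updated shard back from the same pserver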
+          RecvSparse(var_name, ep_idx);
+        };
+        tasks.emplace_back(
+            send_threadpool_->enqueue(std::move(send_recv_task)));
+      }
+    } else {
+      auto send_recv_task = [this, &var_name, &send_ctx] {
+        if (var_name == STEP_COUNTER) {
+          return;
+        }
         VLOG(1) << "send dense " << var_name << " begin";
         SendDense(var_name);
         VLOG(1) << "send dense " << var_name << " done";
-      }
-    };
-    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
+        VLOG(1) << "recv dense " << var_name << " begin";
+        RecvDense(var_name);
+        VLOG(1) << "recv dense " << var_name << " done";
+      };
+      tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
+    }
   }
 
   for (auto &task : tasks) {
@@ -520,13 +566,17 @@ void GeoCommunicator::SendByCommunicator(int batches) {
     task.wait();
   }
 }
 
-void GeoCommunicator::SendSparse(const std::string &varname, int batches) {
+void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
   std::vector<int64_t> ids;
-  auto &ids_queue = send_ids_to_queue_.at(varname);
-  for (int i = 0; i < batches; ++i) {
-    auto pop_ids = ids_queue->Pop();
-    std::copy(pop_ids.begin(), pop_ids.end(), back_inserter(ids));
+  auto send_varname = send_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
+  auto trainer_id = send_varname_to_ctx_.at(varname).trainer_id;
+  auto endpoint = send_varname_to_ctx_.at(varname).epmap[ep_idx];
+
+  for (size_t i = 0; i < splited_ids_vec_.size(); ++i) {
+    std::copy(splited_ids_vec_[i][ep_idx].begin(),
+              splited_ids_vec_[i][ep_idx].end(), back_inserter(ids));
   }
 
   auto size = ids.size();
@@ -551,7 +601,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, int batches) {
   auto dims1 = t_latest.dims()[1];
   auto cpu_ctx = paddle::platform::CPUDeviceContext();
 
-  auto *var_delta = delta_scope_->Var(varname);
+  auto *var_delta = delta_scope_->Var(send_varname);
   auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
   t_delta->set_height(ids.size());
   t_delta->mutable_rows()->assign(ids.begin(), ids.end());
@@ -575,9 +625,14 @@ void GeoCommunicator::SendSparse(const std::string &varname, int batches) {
                    values[j][0]->data());
   }
 
-  auto &ctx = send_varname_to_ctx_.at(varname);
-  auto send = distributed::ParameterSend<float>();
-  send(ctx, *delta_scope_, true, 1);
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
+  distributed::RPCClient *rpc_client =
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
+
+  auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
+                                      *delta_scope_.get(), send_varname);
+  ret->Wait();
 }
 
 void GeoCommunicator::SendDense(const std::string &varname) {
@@ -614,39 +669,29 @@ void GeoCommunicator::SendDense(const std::string &varname) {
   send(ctx, *delta_scope_, true, 1);
 }
 
-void GeoCommunicator::RecvByCommunicator() {
-  std::vector<std::future<void>> tasks;
-  tasks.reserve(recv_varname_to_ctx_.size());
+void GeoCommunicator::RecvByCommunicator() { return; }
 
-  for (auto &iter : recv_varname_to_ctx_) {
-    auto &var_name = iter.first;
-    auto &recv_ctx = iter.second;
+void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
+  auto train_id = recv_varname_to_ctx_.at(varname).trainer_id;
+  auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx];
+  auto splited_var_name =
+      send_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
 
-    auto recv_task = [this, &var_name, &recv_ctx] {
-      if (recv_ctx.is_sparse) {
-        RecvSparse(var_name);
-      } else {
-        VLOG(1) << "recv dense " << var_name << " begin";
-        RecvDense(var_name);
-        VLOG(1) << "recv dense " << var_name << " done";
-      }
-    };
-    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
-  }
-  for (auto &task : tasks) {
-    task.wait();
-  }
-}
+  VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name;
 
-void GeoCommunicator::RecvSparse(const std::string &varname) {
-  VLOG(1) << "RecvSparse receive var: " << varname;
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace());
+  distributed::RPCClient *rpc_client =
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id);
 
+  auto *var_psrever = pserver_scope_->Var(splited_var_name);
+  auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv,
+                                        *pserver_scope_.get(), splited_var_name,
+                                        splited_var_name, splited_var_name);
+  handle->Wait();
 
-  auto *var_latest = recv_scope_->FindVar(varname);
-  auto *var_psrever = pserver_scope_->Var(varname);
+  VLOG(1) << "Finish to RecvSparse receive var: " << splited_var_name;
 
-  auto &ctx = recv_varname_to_ctx_.at(varname);
-  auto recv = distributed::ParameterRecv<float>();
-  recv(ctx, *pserver_scope_, true);
+  auto *var_latest = recv_scope_->FindVar(varname);
 
   PADDLE_ENFORCE_EQ(
       var_psrever->IsInitialized(), true,
@@ -657,7 +702,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname) {
   std::vector<int64_t> ids;
   ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(),
              var_psrever->Get<framework::SelectedRows>().rows().end());
-  VLOG(1) << "RecvSparse receive var: " << varname
+  VLOG(1) << "RecvSparse receive var: " << splited_var_name
           << " ids Size: " << ids.size();
 
   auto t_psrever = var_psrever->Get<framework::SelectedRows>().value();
diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h
index 2f6da150d1e..b00225a9091 100644
--- a/paddle/fluid/operators/distributed/communicator.h
+++ b/paddle/fluid/operators/distributed/communicator.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <ThreadPool.h>
+
 #include <atomic>
 #include <deque>
 #include <map>
@@ -25,8 +26,8 @@ limitations under the License. */
 #include <unordered_set>
 #include <utility>
 #include <vector>
-#include "gflags/gflags.h"
 
+#include "gflags/gflags.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/operators/distributed/communicator_common.h"
@@ -250,6 +251,8 @@ class Communicator {
   std::unordered_map<std::string, std::string> envs;
 };
 
+using SplitedSparseIds = std::vector<std::unordered_set<int64_t>>;
+
 class AsyncCommunicator : public Communicator {
  public:
   AsyncCommunicator() : Communicator() {}
@@ -423,7 +426,7 @@ class GeoCommunicator : public AsyncCommunicator {
 
   void SendByCommunicator(int batches) override;
 
-  void SendSparse(const std::string &varname, int batches);
+  void SendSparse(const std::string &varname, int ep_idx);
 
   void SendDense(const std::string &varname);
 
@@ -431,7 +434,7 @@ class GeoCommunicator : public AsyncCommunicator {
 
   void RecvByCommunicator() override;
 
-  void RecvSparse(const std::string &varname);
+  void RecvSparse(const std::string &varname, int ep_idx);
 
   void RecvDense(const std::string &varname);
@@ -454,11 +457,13 @@ class GeoCommunicator : public AsyncCommunicator {
   // parameter on pserver
   std::shared_ptr<Scope> pserver_scope_;
 
-  std::unordered_map<std::string,
-                     std::shared_ptr<BlockingQueue<std::vector<int64_t>>>>
+  std::unordered_map<
+      std::string,
+      std::shared_ptr<BlockingQueue<std::shared_ptr<SplitedSparseIds>>>>
       send_ids_to_queue_;
 
   std::unordered_map<std::string, std::vector<float>> old_sparses_;
+  std::vector<SplitedSparseIds> splited_ids_vec_;
 };
 
 }  // namespace distributed
-- 
GitLab