diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index a398fa8d12685c7e2dea57f1473ca1c9faceeeb4..7e574f29552d1a9d386fd8895b80a2dd05da5214 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -377,9 +377,8 @@ void SyncCommunicator::BarrierSend() { } for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE_NE( - rets[i]->Wait(), 0U, - platform::errors::External("internal error in RPCClient")); + PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External( + "internal error in RPCClient")); } VLOG(4) << "BarrierSend with SyncCommunicator"; @@ -397,9 +396,8 @@ void SyncCommunicator::BarrierRecv() { } for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE_NE( - rets[i]->Wait(), 0U, - platform::errors::External("internal error in RPCClient")); + PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External( + "internal error in RPCClient")); } VLOG(4) << "BarrierRecv with SyncCommunicator"; @@ -432,7 +430,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, } send_ids_to_queue_[varname] = - std::make_shared>>( + std::make_shared>>( send_queue_size_); } } @@ -489,12 +487,13 @@ void GeoCommunicator::Send(const std::vector &var_names, "Only LodTensor can be send in GeoCommunicator::Send")); } - auto pserver_num = send_varname_to_ctx_.at[table_name].epmap.size(); + auto pserver_num = send_varname_to_ctx_.at(table_name).epmap.size(); auto ids = std::make_shared(pserver_num); + auto &rows = var->Get().rows(); // split rows index into output sparse vars for (size_t i = 0; i < rows.size(); ++i) { auto ep_idx = rows[i] % pserver_num; - ids[ep_idx].add(rows[i]); + ids->at(ep_idx).insert(rows[i]); } queue->Push(ids); } @@ -526,7 +525,8 @@ void GeoCommunicator::SendByCommunicator(int batches) { for (auto &iter : send_varname_to_ctx_) { auto &var_name = iter.first; auto &send_ctx = iter.second; - auto &pserver_num = send_ctx.epmap.size(); + int pserver_num = static_cast(send_ctx.epmap.size()); + auto &ids_queue = send_ids_to_queue_.at(var_name); splited_ids_vec_.clear(); for (int i = 0; i < batches; ++i) { @@ -534,7 +534,7 @@ void GeoCommunicator::SendByCommunicator(int batches) { } if (send_ctx.is_sparse) { - for (auto ep_idx = 0; ep_idx < pserver_num; ep_idx++) { + for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) { auto send_recv_task = [this, ep_idx, &var_name] { if (var_name == STEP_COUNTER) { return; @@ -568,29 +568,50 @@ void GeoCommunicator::SendByCommunicator(int batches) { void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { std::vector ids; - auto &ids_queue = send_ids_to_queue_.at(varname); - auto send_varname = send_varname_to_ctx_.at[varname].splited_varnames[ep_idx]; - auto trainer_id = send_varname_to_ctx_.at[varname].trainer_id; - auto endpoint = send_varname_to_ctx_.at[varname].epmap[ep_idx]; + auto &rpc_ctx = send_varname_to_ctx_.at(varname); + VLOG(1) << rpc_ctx.print(); + auto send_varname = rpc_ctx.splited_varnames[ep_idx]; + auto trainer_id = rpc_ctx.trainer_id; + auto endpoint = rpc_ctx.epmap[ep_idx]; + auto pserver_num = rpc_ctx.epmap.size(); - for (int i = 0; i < splited_ids_vec_.size(); ++i) { - std::copy((*splited_ids_vec_[i])[ep_idx].begin(), - (*splited_ids_vec_[i])[ep_idx].end(), back_inserter(ids)); + int batches = static_cast(splited_ids_vec_.size()); + for (int i = 0; i < batches; ++i) { + std::copy(splited_ids_vec_[i].at(ep_idx).begin(), + splited_ids_vec_[i].at(ep_idx).end(), back_inserter(ids)); } auto size = ids.size(); std::set st(ids.begin(), ids.end()); ids.assign(st.begin(), st.end()); - VLOG(1) << "SendSparse receive var: " << varname << " unset: " << size - << " set: " << ids.size(); + + std::stringstream list_str; + for (uint64_t i = 0; i < ids.size(); i++) { + list_str << ids[i] << ","; + } + VLOG(1) << "SendSparse receive var: " << send_varname << " unset: " << size + << " set: " << ids.size() << ": " << list_str.str(); if (ids.empty()) { LOG(WARNING) << "WARNING: GEO has nothing to send, return directly "; return; } + std::vector outs_rows_idx; + + if (!rpc_ctx.is_distributed) { + for (size_t i = 0; i < ids.size(); ++i) { + auto id = ids[i] / pserver_num; + outs_rows_idx.push_back(id); + } + } else { + for (size_t i = 0; i < ids.size(); ++i) { + outs_rows_idx.push_back(ids[i]); + } + } + auto *var_latest = recv_scope_->FindVar(varname); PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, @@ -603,8 +624,10 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { auto cpu_ctx = paddle::platform::CPUDeviceContext(); auto *var_delta = delta_scope_->Var(send_varname); auto *t_delta = var_delta->GetMutable(); - t_delta->set_height(ids.size()); - t_delta->mutable_rows()->assign(ids.begin(), ids.end()); + + t_delta->set_height(rpc_ctx.height_sections[ep_idx]); + t_delta->mutable_rows()->assign(outs_rows_idx.begin(), outs_rows_idx.end()); + auto *t_value = t_delta->mutable_value(); t_value->mutable_data( framework::make_ddim({static_cast(ids.size()), dims1}), @@ -625,6 +648,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { values[j][0]->data()); } + VLOG(1) << "begin to real send " << send_varname; platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &cpu_ctx_send = *pool.Get(platform::CPUPlace()); distributed::RPCClient *rpc_client = @@ -632,7 +656,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send, *delta_scope_.get(), send_varname); - ret.wait(); + VLOG(1) << "need to wait for send " << send_varname; + ret->Wait(); + VLOG(1) << "finish to send " << send_varname; } void GeoCommunicator::SendDense(const std::string &varname) { @@ -672,10 +698,11 @@ void GeoCommunicator::SendDense(const std::string &varname) { void GeoCommunicator::RecvByCommunicator() { return; } void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) { - auto train_id = recv_varname_to_ctx_.at(var_name).trainer_id; - auto endpoint = recv_varname_to_ctx_.at(var_name).epmap[ep_idx]; + auto train_id = recv_varname_to_ctx_.at(varname).trainer_id; + auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx]; auto splited_var_name = - send_varname_to_ctx_.at(varname).splited_varnames[ep_idx]; + DeltaVarToVar(send_varname_to_ctx_.at(varname).splited_varnames[ep_idx]); + auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size(); VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name; @@ -683,6 +710,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) { auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace()); distributed::RPCClient *rpc_client = distributed::RPCClient::GetInstance(train_id); + auto *var_psrever = pserver_scope_->Var(splited_var_name); auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv, *pserver_scope_.get(), splited_var_name, @@ -724,14 +752,15 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) { auto blas = math::GetBlas(cpu_ctx); for (auto j = 0; j < static_cast(ids.size()); ++j) { + auto id = ids[j] * pserver_num + ep_idx; blas.VSUB(dims1, t_psrever.data() + j * dims1, old_values[j][0]->data(), v_delta.data() + j * dims1); - blas.VADD(dims1, t_latest->data() + ids[j] * dims1, - v_delta.data() + j * dims1, - t_latest->data() + ids[j] * dims1); + blas.VADD(dims1, t_latest->data() + id * dims1, + v_delta.data() + j * dims1, t_latest->data() + id * dims1); blas.VCOPY(dims1, t_psrever.data() + j * dims1, old_values[j][0]->data()); } + VLOG(1) << "receive finish"; } void GeoCommunicator::RecvDense(const std::string &varname) { diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index b00225a9091483f388d285c431ce423187e4d542..c8e0dbae72f631374040510012ea00bb0ed597ac 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -282,7 +282,7 @@ class AsyncCommunicator : public Communicator { const RpcCtxMap &recv_varname_to_ctx, Scope *recv_scope) override; - void MainThread(); + virtual void MainThread(); void Send(const std::vector &var_names, const std::vector &var_tables, @@ -406,7 +406,7 @@ class GeoCommunicator : public AsyncCommunicator { void InitImpl(const RpcCtxMap &send_varname_to_ctx, const RpcCtxMap &recv_varname_to_ctx, Scope *recv_scope) override; - + void MainThread() override; void InitEnvs() { min_send_grad_num_before_recv_ = 0; @@ -444,6 +444,19 @@ class GeoCommunicator : public AsyncCommunicator { void InitDense(const std::string varname); + const std::string VarToDeltaVar(const std::string var_name) { + std::string delta_name = var_name; + const std::string send_name = delta_name.append(".delta"); + return send_name; + } + + const std::string DeltaVarToVar(const std::string var_name) { + std::string origin_name = var_name; + origin_name.erase(origin_name.find(".delta"), 6); + const std::string param_name = origin_name; + return param_name; + } + private: int trainers_; std::string sparse_attrs_;