diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 7e574f29552d1a9d386fd8895b80a2dd05da5214..910c6bdafda6c68d59fcfa8fdcdce6b4256d74ed 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -507,6 +507,11 @@ void GeoCommunicator::MainThread() { VLOG(3) << "wait for running"; } + for (auto &iter : send_varname_to_ctx_) { + splited_ids_vec_.insert( + std::pair>{ + iter.first, std::vector()}); + } while (running_) { int meet = Meet(); @@ -528,9 +533,9 @@ void GeoCommunicator::SendByCommunicator(int batches) { int pserver_num = static_cast(send_ctx.epmap.size()); auto &ids_queue = send_ids_to_queue_.at(var_name); - splited_ids_vec_.clear(); + splited_ids_vec_[var_name].clear(); for (int i = 0; i < batches; ++i) { - splited_ids_vec_.push_back(*(ids_queue->Pop())); + splited_ids_vec_[var_name].push_back(*(ids_queue->Pop())); } if (send_ctx.is_sparse) { @@ -544,6 +549,7 @@ void GeoCommunicator::SendByCommunicator(int batches) { }; tasks.emplace_back( send_threadpool_->enqueue(std::move(send_recv_task))); + tasks[tasks.size() - 1].wait(); } } else { auto send_recv_task = [this, &var_name, &send_ctx] { @@ -561,9 +567,9 @@ void GeoCommunicator::SendByCommunicator(int batches) { } } - for (auto &task : tasks) { - task.wait(); - } + // for (auto &task : tasks) { + // task.wait(); + // } } void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { @@ -576,10 +582,11 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { auto endpoint = rpc_ctx.epmap[ep_idx]; auto pserver_num = rpc_ctx.epmap.size(); - int batches = static_cast(splited_ids_vec_.size()); + int batches = static_cast(splited_ids_vec_[varname].size()); for (int i = 0; i < batches; ++i) { - std::copy(splited_ids_vec_[i].at(ep_idx).begin(), - splited_ids_vec_[i].at(ep_idx).end(), back_inserter(ids)); + std::copy(splited_ids_vec_[varname][i].at(ep_idx).begin(), + splited_ids_vec_[varname][i].at(ep_idx).end(), + back_inserter(ids)); } auto size = ids.size(); @@ -701,7 +708,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) { auto train_id = recv_varname_to_ctx_.at(varname).trainer_id; auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx]; auto splited_var_name = - DeltaVarToVar(send_varname_to_ctx_.at(varname).splited_varnames[ep_idx]); + recv_varname_to_ctx_.at(varname).splited_varnames[ep_idx]; auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size(); VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name; diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index c8e0dbae72f631374040510012ea00bb0ed597ac..7c601cae01ffcf7f9f7257cc59101e284d774f35 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -444,19 +444,6 @@ class GeoCommunicator : public AsyncCommunicator { void InitDense(const std::string varname); - const std::string VarToDeltaVar(const std::string var_name) { - std::string delta_name = var_name; - const std::string send_name = delta_name.append(".delta"); - return send_name; - } - - const std::string DeltaVarToVar(const std::string var_name) { - std::string origin_name = var_name; - origin_name.erase(origin_name.find(".delta"), 6); - const std::string param_name = origin_name; - return param_name; - } - private: int trainers_; std::string sparse_attrs_; @@ -476,7 +463,8 @@ class GeoCommunicator : public AsyncCommunicator { send_ids_to_queue_; std::unordered_map> old_sparses_; - std::vector splited_ids_vec_; + std::unordered_map> + splited_ids_vec_; }; } // namespace distributed