提交 3fea70a3 编写于 作者: M malin10

test=develop, bug fix

上级 fc8ba050
...@@ -507,6 +507,11 @@ void GeoCommunicator::MainThread() { ...@@ -507,6 +507,11 @@ void GeoCommunicator::MainThread() {
VLOG(3) << "wait for running"; VLOG(3) << "wait for running";
} }
for (auto &iter : send_varname_to_ctx_) {
splited_ids_vec_.insert(
std::pair<std::string, std::vector<SplitedSparseIds>>{
iter.first, std::vector<SplitedSparseIds>()});
}
while (running_) { while (running_) {
int meet = Meet(); int meet = Meet();
...@@ -528,9 +533,9 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -528,9 +533,9 @@ void GeoCommunicator::SendByCommunicator(int batches) {
int pserver_num = static_cast<int>(send_ctx.epmap.size()); int pserver_num = static_cast<int>(send_ctx.epmap.size());
auto &ids_queue = send_ids_to_queue_.at(var_name); auto &ids_queue = send_ids_to_queue_.at(var_name);
splited_ids_vec_.clear(); splited_ids_vec_[var_name].clear();
for (int i = 0; i < batches; ++i) { for (int i = 0; i < batches; ++i) {
splited_ids_vec_.push_back(*(ids_queue->Pop())); splited_ids_vec_[var_name].push_back(*(ids_queue->Pop()));
} }
if (send_ctx.is_sparse) { if (send_ctx.is_sparse) {
...@@ -544,6 +549,7 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -544,6 +549,7 @@ void GeoCommunicator::SendByCommunicator(int batches) {
}; };
tasks.emplace_back( tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task))); send_threadpool_->enqueue(std::move(send_recv_task)));
tasks[tasks.size() - 1].wait();
} }
} else { } else {
auto send_recv_task = [this, &var_name, &send_ctx] { auto send_recv_task = [this, &var_name, &send_ctx] {
...@@ -561,9 +567,9 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -561,9 +567,9 @@ void GeoCommunicator::SendByCommunicator(int batches) {
} }
} }
for (auto &task : tasks) { // for (auto &task : tasks) {
task.wait(); // task.wait();
} // }
} }
void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
...@@ -576,10 +582,11 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { ...@@ -576,10 +582,11 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto endpoint = rpc_ctx.epmap[ep_idx]; auto endpoint = rpc_ctx.epmap[ep_idx];
auto pserver_num = rpc_ctx.epmap.size(); auto pserver_num = rpc_ctx.epmap.size();
int batches = static_cast<int>(splited_ids_vec_.size()); int batches = static_cast<int>(splited_ids_vec_[varname].size());
for (int i = 0; i < batches; ++i) { for (int i = 0; i < batches; ++i) {
std::copy(splited_ids_vec_[i].at(ep_idx).begin(), std::copy(splited_ids_vec_[varname][i].at(ep_idx).begin(),
splited_ids_vec_[i].at(ep_idx).end(), back_inserter(ids)); splited_ids_vec_[varname][i].at(ep_idx).end(),
back_inserter(ids));
} }
auto size = ids.size(); auto size = ids.size();
...@@ -701,7 +708,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) { ...@@ -701,7 +708,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto train_id = recv_varname_to_ctx_.at(varname).trainer_id; auto train_id = recv_varname_to_ctx_.at(varname).trainer_id;
auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx]; auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx];
auto splited_var_name = auto splited_var_name =
DeltaVarToVar(send_varname_to_ctx_.at(varname).splited_varnames[ep_idx]); recv_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size(); auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size();
VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name; VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name;
......
...@@ -444,19 +444,6 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -444,19 +444,6 @@ class GeoCommunicator : public AsyncCommunicator {
void InitDense(const std::string varname); void InitDense(const std::string varname);
// Returns the delta-variable name for `var_name`, i.e. `var_name` with the
// ".delta" suffix appended. The argument itself is left untouched.
const std::string VarToDeltaVar(const std::string var_name) {
  std::string send_name(var_name);
  send_name += ".delta";
  return send_name;
}
// Returns `var_name` with its first ".delta" marker removed.
//
// A name that does not contain ".delta" is returned unchanged instead of
// raising: std::string::find yields npos in that case, and passing npos to
// std::string::erase throws std::out_of_range.
const std::string DeltaVarToVar(const std::string var_name) {
  const std::string delta_marker(".delta");
  std::string param_name = var_name;
  const std::string::size_type marker_pos = param_name.find(delta_marker);
  if (marker_pos != std::string::npos) {
    // Erase exactly the marker's length rather than a hard-coded count.
    param_name.erase(marker_pos, delta_marker.size());
  }
  return param_name;
}
private: private:
int trainers_; int trainers_;
std::string sparse_attrs_; std::string sparse_attrs_;
...@@ -476,7 +463,8 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -476,7 +463,8 @@ class GeoCommunicator : public AsyncCommunicator {
send_ids_to_queue_; send_ids_to_queue_;
std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_; std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_;
std::vector<SplitedSparseIds> splited_ids_vec_; std::unordered_map<std::string, std::vector<SplitedSparseIds>>
splited_ids_vec_;
}; };
} // namespace distributed } // namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册