提交 3fea70a3 编写于 作者: M malin10

test=develop, bug fix

上级 fc8ba050
......@@ -507,6 +507,11 @@ void GeoCommunicator::MainThread() {
VLOG(3) << "wait for running";
}
for (auto &iter : send_varname_to_ctx_) {
splited_ids_vec_.insert(
std::pair<std::string, std::vector<SplitedSparseIds>>{
iter.first, std::vector<SplitedSparseIds>()});
}
while (running_) {
int meet = Meet();
......@@ -528,9 +533,9 @@ void GeoCommunicator::SendByCommunicator(int batches) {
int pserver_num = static_cast<int>(send_ctx.epmap.size());
auto &ids_queue = send_ids_to_queue_.at(var_name);
splited_ids_vec_.clear();
splited_ids_vec_[var_name].clear();
for (int i = 0; i < batches; ++i) {
splited_ids_vec_.push_back(*(ids_queue->Pop()));
splited_ids_vec_[var_name].push_back(*(ids_queue->Pop()));
}
if (send_ctx.is_sparse) {
......@@ -544,6 +549,7 @@ void GeoCommunicator::SendByCommunicator(int batches) {
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
tasks[tasks.size() - 1].wait();
}
} else {
auto send_recv_task = [this, &var_name, &send_ctx] {
......@@ -561,9 +567,9 @@ void GeoCommunicator::SendByCommunicator(int batches) {
}
}
for (auto &task : tasks) {
task.wait();
}
// for (auto &task : tasks) {
// task.wait();
// }
}
void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
......@@ -576,10 +582,11 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto endpoint = rpc_ctx.epmap[ep_idx];
auto pserver_num = rpc_ctx.epmap.size();
int batches = static_cast<int>(splited_ids_vec_.size());
int batches = static_cast<int>(splited_ids_vec_[varname].size());
for (int i = 0; i < batches; ++i) {
std::copy(splited_ids_vec_[i].at(ep_idx).begin(),
splited_ids_vec_[i].at(ep_idx).end(), back_inserter(ids));
std::copy(splited_ids_vec_[varname][i].at(ep_idx).begin(),
splited_ids_vec_[varname][i].at(ep_idx).end(),
back_inserter(ids));
}
auto size = ids.size();
......@@ -701,7 +708,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto train_id = recv_varname_to_ctx_.at(varname).trainer_id;
auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx];
auto splited_var_name =
DeltaVarToVar(send_varname_to_ctx_.at(varname).splited_varnames[ep_idx]);
recv_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size();
VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name;
......
......@@ -444,19 +444,6 @@ class GeoCommunicator : public AsyncCommunicator {
void InitDense(const std::string varname);
// Returns the given variable name with the literal suffix ".delta" appended.
// The argument is taken by value and is not modified for the caller.
const std::string VarToDeltaVar(const std::string var_name) {
  return var_name + ".delta";
}
// Returns a copy of var_name with the first occurrence of ".delta" removed.
//
// If var_name does not contain ".delta", it is returned unchanged. Without
// this guard, find() yields std::string::npos and erase(npos, 6) throws
// std::out_of_range, so a name lacking the marker would abort the caller.
const std::string DeltaVarToVar(const std::string var_name) {
  const std::string delta_marker = ".delta";
  std::string origin_name = var_name;
  const std::string::size_type pos = origin_name.find(delta_marker);
  if (pos != std::string::npos) {
    origin_name.erase(pos, delta_marker.size());
  }
  return origin_name;
}
private:
int trainers_;
std::string sparse_attrs_;
......@@ -476,7 +463,8 @@ class GeoCommunicator : public AsyncCommunicator {
send_ids_to_queue_;
std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_;
std::vector<SplitedSparseIds> splited_ids_vec_;
std::unordered_map<std::string, std::vector<SplitedSparseIds>>
splited_ids_vec_;
};
} // namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册