提交 b3526fb4 编写于 作者: M malin10

tmp

上级 613df707
......@@ -458,7 +458,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
// var_tables.size(), 1,
// platform::errors::InvalidArgument("var_tables.size() == 1 is
// permitted"));
auto before_send = GetCurrentUS();
std::shared_ptr<SparseIdsMap> ids_table = std::make_shared<SparseIdsMap>();
for (size_t i = 0; i < var_tables.size(); i++) {
auto table_name = var_tables[i];
......@@ -494,7 +494,8 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
}
}
need_push_queue_->Push(ids_table);
VLOG(1) << "run send_op finish.";
auto after_send = GetCurrentUS();
VLOG(0) << "run send_op finish. using " << (after_send - before_send);
}
void GeoCommunicator::MainThread() {
......@@ -506,106 +507,98 @@ void GeoCommunicator::MainThread() {
}
while (running_) {
// int meet = Meet();
// VLOG(1) << "async_meet: " << meet;
// SendGlobalStep(meet);
auto before = GetCurrentUS();
SendByCommunicator(0);
auto after = GetCurrentUS();
VLOG(0) << "finish one SendByCommunicator using " << (after - before);
}
VLOG(1) << "geo-communicator stopped, send thread exit";
}
void GeoCommunicator::SendByCommunicator(int batches) {
std::vector<std::future<void>> tasks;
tasks.reserve(send_var_nums_);
auto before_send_by_communicator = GetCurrentUS();
size_t wait_times = 0;
while (ids_send_vec_.size() < static_cast<size_t>(max_merge_var_num_)) {
VLOG(1) << "ids_send_vec_ Size: " << ids_send_vec_.size();
if (need_push_queue_->Size() > 0) {
wait_times = 0;
ids_send_vec_.push_back(*(need_push_queue_->Pop()));
VLOG(1) << "ids_send_vec_ pushed";
} else if (need_push_queue_->Size() == 0) {
VLOG(1) << "wait_times -> " << wait_times;
if (wait_times >= static_cast<size_t>(send_wait_times_)) {
break;
std::vector<std::future<void>> tasks;
tasks.reserve(send_var_nums_);
auto before_send_by_communicator = GetCurrentUS();
size_t wait_times = 0;
while (ids_send_vec_.size() < static_cast<size_t>(max_merge_var_num_)) {
VLOG(1) << "ids_send_vec_ Size: " << ids_send_vec_.size();
if (need_push_queue_->Size() > 0) {
wait_times = 0;
ids_send_vec_.push_back(*(need_push_queue_->Pop()));
VLOG(1) << "ids_send_vec_ pushed";
} else if (need_push_queue_->Size() == 0) {
VLOG(1) << "wait_times -> " << wait_times;
if (wait_times >= static_cast<size_t>(send_wait_times_)) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
wait_times++;
continue;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
wait_times++;
continue;
}
}
if (ids_send_vec_.size() >= static_cast<size_t>(max_merge_var_num_)) {
auto before_send_global_step = GetCurrentUS();
VLOG(0) << "finish ins_send_vec using time "
<< (before_send_global_step - before_send_by_communicator);
SendGlobalStep(max_merge_var_num_);
auto after_send_global_step = GetCurrentUS();
VLOG(0) << "finish send global_step using "
<< (after_send_global_step - before_send_global_step);
for (auto &iter : send_varname_to_ctx_) {
VLOG(1) << "debug " << iter.first;
auto &var_name = iter.first;
auto &send_ctx = iter.second;
int pserver_num = static_cast<int>(send_ctx.epmap.size());
if (send_ctx.is_sparse) {
if (var_name == STEP_COUNTER) {
continue;
}
if (ids_send_vec_.size() >= static_cast<size_t>(max_merge_var_num_)) {
auto before_send_global_step = GetCurrentUS();
VLOG(0) << "finish ins_send_vec using time "
<< (before_send_global_step - before_send_by_communicator)
<< "; send_var_nums_ = " << send_var_nums_;
SendGlobalStep(max_merge_var_num_);
auto after_send_global_step = GetCurrentUS();
VLOG(0) << "finish send global_step using "
<< (after_send_global_step - before_send_global_step);
for (auto &iter : send_varname_to_ctx_) {
VLOG(1) << "debug " << iter.first;
auto &var_name = iter.first;
auto &send_ctx = iter.second;
int pserver_num = static_cast<int>(send_ctx.epmap.size());
if (send_ctx.is_sparse) {
if (var_name == STEP_COUNTER) {
continue;
}
for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
auto send_recv_task = [this, ep_idx, &var_name] {
auto before_send_sparse = GetCurrentUS();
for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
auto send_recv_task = [this, ep_idx, &var_name] {
auto before_send_sparse = GetCurrentUS();
if (var_name == STEP_COUNTER) {
return;
}
SendSparse(var_name, ep_idx);
auto after_send_sparse = GetCurrentUS();
RecvSparse(var_name, ep_idx);
auto after_recv_sparse = GetCurrentUS();
VLOG(0)
<< "send recv "
<< send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]
<< " finish, using "
<< (after_send_sparse - before_send_sparse) << " and "
<< (after_recv_sparse - after_send_sparse)
<< "; total = " << (after_recv_sparse - before_send_sparse);
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
// tasks[tasks.size() - 1].wait();
}
} else {
auto send_recv_task = [this, &var_name, &send_ctx] {
return;
if (var_name == STEP_COUNTER) {
return;
}
SendSparse(var_name, ep_idx);
auto after_send_sparse = GetCurrentUS();
RecvSparse(var_name, ep_idx);
auto after_recv_sparse = GetCurrentUS();
VLOG(0)
<< "send recv "
<< send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]
<< " finish, using " << (after_send_sparse - before_send_sparse)
<< " and " << (after_recv_sparse - after_send_sparse)
<< "; total = " << (after_recv_sparse - before_send_sparse);
VLOG(1) << "send dense " << var_name << " begin";
SendDense(var_name);
VLOG(1) << "send dense " << var_name << " done";
VLOG(1) << "recv dense " << var_name << " begin";
RecvDense(var_name);
VLOG(1) << "recv dense " << var_name << " done";
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
// tasks[tasks.size() - 1].wait();
}
} else {
auto send_recv_task = [this, &var_name, &send_ctx] {
return;
if (var_name == STEP_COUNTER) {
return;
}
VLOG(1) << "send dense " << var_name << " begin";
SendDense(var_name);
VLOG(1) << "send dense " << var_name << " done";
VLOG(1) << "recv dense " << var_name << " begin";
RecvDense(var_name);
VLOG(1) << "recv dense " << var_name << " done";
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
}
}
for (auto &task : tasks) {
task.wait();
}
for (auto &task : tasks) {
task.wait();
}
ids_send_vec_.clear();
VLOG(1) << "Finish SendByCommunicator";
ids_send_vec_.clear();
auto finish_one_comm = GetCurrentUS();
VLOG(0) << "Finish SendByCommunicator "
<< (finish_one_comm - after_send_global_step);
}
}
}
......
......@@ -424,7 +424,7 @@ class GeoCommunicator : public AsyncCommunicator {
const std::vector<std::string> &var_tables,
const framework::Scope &scope) override;
void SendByCommunicator(int batches) override;
// void SendByCommunicator(int batches) override;
void SendSparse(const std::string &varname, int ep_idx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册