Commit 613df707 authored by malin10

test=develop, bug fix

Parent 78b603b2
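
What this bug fix does, as read from the diff below: the per-Send() STEP_COUNTER queue push is commented out and global-step reporting moves into SendByCommunicator(), which now calls SendGlobalStep(max_merge_var_num_) once per merged batch; a new send_var_nums_ counter sums the split var names of every send context so the futures vector can be reserved exactly; timing VLOGs are added around each SendSparse/RecvSparse pair; and the dense send/recv task body is short-circuited with an early return.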
......@@ -425,6 +425,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
} else {
auto &send_ctx = iter.second;
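// Count every split var name across all send contexts up front, so
// SendByCommunicator() below can reserve its futures vector in one call.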
send_var_nums_ += send_ctx.splited_varnames.size();
if (!send_ctx.is_sparse) {
continue;
}
......@@ -462,16 +463,17 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
for (size_t i = 0; i < var_tables.size(); i++) {
auto table_name = var_tables[i];
if (table_name == STEP_COUNTER) {
auto &queue = send_varname_to_queue_.at(table_name);
auto tmp_var = std::make_shared<Variable>();
auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({1}));
auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
out_d[0] = 1;
VLOG(3) << "send to " << table_name << " with queue size "
<< queue->Size();
queue->Push(tmp_var);
continue;
// auto &queue = send_varname_to_queue_.at(table_name);
// auto tmp_var = std::make_shared<Variable>();
// auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
// tensor->Resize(framework::make_ddim({1}));
// auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
// out_d[0] = 1;
// VLOG(3) << "send to " << table_name << " with queue size "
// << queue->Size();
// queue->Push(tmp_var);
} else {
auto splited_var_nums =
recv_varname_to_ctx_[table_name].splited_varnames.size();
......@@ -506,18 +508,22 @@ void GeoCommunicator::MainThread() {
while (running_) {
// int meet = Meet();
VLOG(1) << "async_meet: " << meet;
// VLOG(1) << "async_meet: " << meet;
// SendGlobalStep(meet);
auto before = GetCurrentUS();
SendByCommunicator(0);
auto after = GetCurrentUS();
VLOG(0) << "finish one SendByCommunicator using " << (after - before);
}
VLOG(1) << "geo-communicator stopped, send thread exit";
}
void GeoCommunicator::SendByCommunicator(int batches) {
std::vector<std::future<void>> tasks;
tasks.reserve(send_varname_to_ctx_.size());
tasks.reserve(send_var_nums_);
auto before_send_by_communicator = GetCurrentUS();
size_t wait_times = 0;
while (ids_send_vec_.size() < static_cast<size_t>(max_merge_var_num_)) {
VLOG(1) << "ids_send_vec_ Size: " << ids_send_vec_.size();
......@@ -537,6 +543,13 @@ void GeoCommunicator::SendByCommunicator(int batches) {
}
if (ids_send_vec_.size() >= static_cast<size_t>(max_merge_var_num_)) {
auto before_send_global_step = GetCurrentUS();
VLOG(0) << "finish ins_send_vec using time "
<< (before_send_global_step - before_send_by_communicator);
SendGlobalStep(max_merge_var_num_);
auto after_send_global_step = GetCurrentUS();
VLOG(0) << "finish send global_step using "
<< (after_send_global_step - before_send_global_step);
for (auto &iter : send_varname_to_ctx_) {
VLOG(1) << "debug " << iter.first;
auto &var_name = iter.first;
......@@ -550,11 +563,20 @@ void GeoCommunicator::SendByCommunicator(int batches) {
for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
auto send_recv_task = [this, ep_idx, &var_name] {
auto before_send_sparse = GetCurrentUS();
if (var_name == STEP_COUNTER) {
return;
}
SendSparse(var_name, ep_idx);
auto after_send_sparse = GetCurrentUS();
RecvSparse(var_name, ep_idx);
auto after_recv_sparse = GetCurrentUS();
VLOG(0)
<< "send recv "
<< send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]
<< " finish, using " << (after_send_sparse - before_send_sparse)
<< " and " << (after_recv_sparse - after_send_sparse)
<< "; total = " << (after_recv_sparse - before_send_sparse);
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
......@@ -562,6 +584,7 @@ void GeoCommunicator::SendByCommunicator(int batches) {
}
} else {
auto send_recv_task = [this, &var_name, &send_ctx] {
return;
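// The unconditional return above makes the rest of this dense-path
// lambda dead code in this revision; only the sparse path runs.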
if (var_name == STEP_COUNTER) {
return;
}
......
......@@ -457,6 +457,7 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver
std::shared_ptr<Scope> pserver_scope_;
int send_var_nums_ = 0;
std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_;
std::shared_ptr<BlockingQueue<std::shared_ptr<SparseIdsMap>>>
need_push_queue_;
......
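For readers outside the Paddle tree, the scheduling pattern in SendByCommunicator boils down to: size one futures vector by the total split-var count, enqueue one send/recv task per (sparse var, pserver endpoint) pair, and time each task. The sketch below reproduces that shape in standard C++ only; it is an illustration, not the Paddle implementation. Paddle's ThreadPool::enqueue, GetCurrentUS, and VLOG are stood in for by std::async, a std::chrono helper, and std::cerr, and the var names and pserver count are made up.

```cpp
// Minimal sketch of the send/recv task scheduling in SendByCommunicator.
#include <chrono>
#include <future>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for Paddle's GetCurrentUS(): current time in microseconds.
static double NowUS() {
  return std::chrono::duration<double, std::micro>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

int main() {
  // Hypothetical sparse vars and endpoint count, for illustration only.
  const std::vector<std::string> sparse_vars = {"emb_w", "fc_w"};
  const int pserver_num = 2;

  // Like send_var_nums_: one task per (sparse var, pserver endpoint),
  // so the futures vector is reserved exactly once.
  std::vector<std::future<void>> tasks;
  tasks.reserve(sparse_vars.size() * pserver_num);

  for (const auto &var_name : sparse_vars) {
    for (int ep_idx = 0; ep_idx < pserver_num; ++ep_idx) {
      // std::async stands in for send_threadpool_->enqueue(...).
      tasks.emplace_back(std::async(std::launch::async, [var_name, ep_idx] {
        double before = NowUS();
        // Real code would call SendSparse(var_name, ep_idx) and then
        // RecvSparse(var_name, ep_idx) here.
        double after = NowUS();
        std::cerr << "send/recv " << var_name << "@" << ep_idx
                  << " took " << (after - before) << " us\n";
      }));
    }
  }
  for (auto &t : tasks) t.wait();  // join all send/recv tasks
  return 0;
}
```

Reserving by the precomputed send_var_nums_ (rather than by send_varname_to_ctx_.size(), as before this commit) matters because each sparse var fans out into one task per endpoint, so the number of futures exceeds the number of send contexts.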