From 613df7079abad037d237274ed72fedd2e4aab83f Mon Sep 17 00:00:00 2001 From: malin10 Date: Mon, 21 Sep 2020 21:52:47 +0800 Subject: [PATCH] test=develop, bug fix --- .../operators/distributed/communicator.cc | 47 ++++++++++++++----- .../operators/distributed/communicator.h | 1 + 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index c171835647f..84c05d2f501 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -425,6 +425,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, } else { auto &send_ctx = iter.second; + send_var_nums_ += send_ctx.splited_varnames.size(); if (!send_ctx.is_sparse) { continue; } @@ -462,16 +463,17 @@ void GeoCommunicator::Send(const std::vector &var_names, for (size_t i = 0; i < var_tables.size(); i++) { auto table_name = var_tables[i]; if (table_name == STEP_COUNTER) { - auto &queue = send_varname_to_queue_.at(table_name); - - auto tmp_var = std::make_shared(); - auto *tensor = tmp_var->GetMutable(); - tensor->Resize(framework::make_ddim({1})); - auto *out_d = tensor->mutable_data(platform::CPUPlace()); - out_d[0] = 1; - VLOG(3) << "send to " << table_name << " with queue size " - << queue->Size(); - queue->Push(tmp_var); + continue; + // auto &queue = send_varname_to_queue_.at(table_name); + + // auto tmp_var = std::make_shared(); + // auto *tensor = tmp_var->GetMutable(); + // tensor->Resize(framework::make_ddim({1})); + // auto *out_d = tensor->mutable_data(platform::CPUPlace()); + // out_d[0] = 1; + // VLOG(3) << "send to " << table_name << " with queue size " + // << queue->Size(); + // queue->Push(tmp_var); } else { auto splited_var_nums = recv_varname_to_ctx_[table_name].splited_varnames.size(); @@ -506,18 +508,22 @@ void GeoCommunicator::MainThread() { while (running_) { // int meet = Meet(); - VLOG(1) << "async_meet: " << meet; + // VLOG(1) << "async_meet: " << meet; // SendGlobalStep(meet); + auto before = GetCurrentUS(); SendByCommunicator(0); + auto after = GetCurrentUS(); + VLOG(0) << "finish one SendByCommunicator using " << (after - before); } VLOG(1) << "geo-communicator stopped, send thread exit"; } void GeoCommunicator::SendByCommunicator(int batches) { std::vector> tasks; - tasks.reserve(send_varname_to_ctx_.size()); + tasks.reserve(send_var_nums_); + auto before_send_by_communicator = GetCurrentUS(); size_t wait_times = 0; while (ids_send_vec_.size() < static_cast(max_merge_var_num_)) { VLOG(1) << "ids_send_vec_ Size: " << ids_send_vec_.size(); @@ -537,6 +543,13 @@ void GeoCommunicator::SendByCommunicator(int batches) { } if (ids_send_vec_.size() >= static_cast(max_merge_var_num_)) { + auto before_send_global_step = GetCurrentUS(); + VLOG(0) << "finish ins_send_vec using time " + << (before_send_global_step - before_send_by_communicator); + SendGlobalStep(max_merge_var_num_); + auto after_send_global_step = GetCurrentUS(); + VLOG(0) << "finish send global_step using " + << (after_send_global_step - before_send_global_step); for (auto &iter : send_varname_to_ctx_) { VLOG(1) << "debug " << iter.first; auto &var_name = iter.first; @@ -550,11 +563,20 @@ void GeoCommunicator::SendByCommunicator(int batches) { for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) { auto send_recv_task = [this, ep_idx, &var_name] { + auto before_send_sparse = GetCurrentUS(); if (var_name == STEP_COUNTER) { return; } SendSparse(var_name, ep_idx); + auto after_send_sparse = GetCurrentUS(); RecvSparse(var_name, ep_idx); + auto after_recv_sparse = GetCurrentUS(); + VLOG(0) + << "send recv " + << send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx] + << " finish, using " << (after_send_sparse - before_send_sparse) + << " and " << (after_recv_sparse - after_send_sparse) + << "; total = " << (after_recv_sparse - before_send_sparse); }; tasks.emplace_back( send_threadpool_->enqueue(std::move(send_recv_task))); @@ -562,6 +584,7 @@ void GeoCommunicator::SendByCommunicator(int batches) { } } else { auto send_recv_task = [this, &var_name, &send_ctx] { + return; if (var_name == STEP_COUNTER) { return; } diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index d3a16aa0100..9526bd7cb0a 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -457,6 +457,7 @@ class GeoCommunicator : public AsyncCommunicator { // parameter on pserver std::shared_ptr pserver_scope_; + int send_var_nums_ = 0; std::unordered_map> old_sparses_; std::shared_ptr>> need_push_queue_; -- GitLab