From 63cd70a8b84905adc83d0fc082e4eaf15d91361b Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 8 Mar 2019 17:36:02 +0800 Subject: [PATCH] fix blocking problem --- .../operators/distributed/communicator.cc | 51 +++++++++++-------- .../operators/distributed/communicator.h | 38 +++++++------- .../operators/distributed/parameter_recv.cc | 2 + .../operators/distributed_ops/send_op.cc | 13 +++-- 4 files changed, 60 insertions(+), 44 deletions(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index f5d274b66d9..a7bce26234d 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -75,10 +75,11 @@ void Communicator::SendThread() { while (running_) { std::vector> task_futures; task_futures.reserve(send_varname_to_ctx_.size()); + VLOG(3) << "run send graph"; for (auto &iter : send_varname_to_queue_) { auto &var_name = iter.first; auto &var_queue = iter.second; - if (var_queue->NotEmpty()) { // will block if queue is empty + if (var_queue->Size() > 0) { auto send_task = [this, &var_name, &var_queue] { VLOG(3) << "merge var " << var_name << " and send"; std::vector> vars; @@ -96,33 +97,41 @@ void Communicator::SendThread() { }; task_futures.emplace_back( send_threadpool_->enqueue(std::move(send_task))); + } else { + VLOG(3) << var_name << " queue empty"; } } for (auto &task_f : task_futures) { task_f.wait(); } + VLOG(3) << "run send graph done"; + RecvAll(); } } +void Communicator::RecvAll() { + VLOG(3) << "parallel run recv graph"; + std::vector> task_futures; + task_futures.reserve(recv_varname_to_ctx_.size()); + for (auto &iter : recv_varname_to_ctx_) { + auto recv_task = [this, &iter] { + auto &var_name = iter.first; + VLOG(3) << "recv var " << var_name; + auto recv_functor = distributed::ParameterRecv(); + recv_functor(iter.second, *recv_scope_); + }; + task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task))); + } + for (auto &task : task_futures) { + task.wait(); + } + VLOG(3) << "run recv graph done"; +} + void Communicator::RecvThread() { VLOG(3) << "RecvThread start!"; while (running_) { - // parallel run recv graph - std::vector> task_futures; - task_futures.reserve(recv_varname_to_ctx_.size()); - for (auto &iter : recv_varname_to_ctx_) { - auto recv_task = [this, &iter] { - auto &var_name = iter.first; - VLOG(3) << "recv var " << var_name; - auto recv_functor = distributed::ParameterRecv(); - recv_functor(iter.second, *recv_scope_); - }; - task_futures.emplace_back( - recv_threadpool_->enqueue(std::move(recv_task))); - } - for (auto &task : task_futures) { - task.wait(); - } + RecvAll(); // TODO(qiao) need to be configuable std::this_thread::sleep_for(std::chrono::milliseconds(200)); } @@ -136,7 +145,9 @@ void Communicator::Send(const std::string &var_name, PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited"); auto tmp_grad_var = std::make_shared(); framework::CopyVariable(*grad_var, tmp_grad_var.get()); - send_varname_to_queue_[var_name]->Push(tmp_grad_var); + auto &queue = send_varname_to_queue_.at(var_name); + VLOG(3) << "send " << var_name << " queue size " << queue->Size(); + queue->Push(tmp_grad_var); } Communicator *Communicator::GetInstance() { return communicator_.get(); } @@ -146,8 +157,8 @@ void Communicator::Start() { // start send and recv thread send_thread_.reset( new std::thread(std::bind(&Communicator::SendThread, this))); - recv_thread_.reset( - new std::thread(std::bind(&Communicator::RecvThread, this))); + // recv_thread_.reset( + // new std::thread(std::bind(&Communicator::RecvThread, this))); } } // namespace distributed diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index c93ad02555e..3c98b36b747 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -43,37 +43,36 @@ class BlockingQueue { } bool Push(const T& elem) { - std::unique_lock lock(mutex_); - send_cv_.wait(lock, [&] { return queue_.size() < capacity_; }); - PADDLE_ENFORCE_LT(queue_.size(), capacity_); - queue_.push_back(elem); - recv_cv_.notify_one(); + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + PADDLE_ENFORCE_LT(queue_.size(), capacity_); + queue_.push_back(elem); + } + cv_.notify_one(); return true; } bool Push(T&& elem) { - std::unique_lock lock(mutex_); - send_cv_.wait(lock, [&] { return queue_.size() < capacity_; }); - PADDLE_ENFORCE_LT(queue_.size(), capacity_); - queue_.emplace_back(std::move(elem)); - recv_cv_.notify_one(); + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + PADDLE_ENFORCE_LT(queue_.size(), capacity_); + queue_.emplace_back(std::move(elem)); + } + cv_.notify_one(); return true; } T Pop() { std::unique_lock lock(mutex_); - recv_cv_.wait(lock, [=] { return !queue_.empty(); }); + cv_.wait(lock, [=] { return !queue_.empty(); }); T rc(std::move(queue_.front())); queue_.pop_front(); + cv_.notify_one(); return rc; } - bool NotEmpty() { - std::unique_lock lock(mutex_); - recv_cv_.wait(lock, [=] { return !queue_.empty(); }); - return true; - } - size_t Cap() const { std::lock_guard lock(mutex_); return capacity_; @@ -89,8 +88,7 @@ class BlockingQueue { std::deque queue_; mutable std::mutex mutex_; - std::condition_variable recv_cv_; - std::condition_variable send_cv_; + std::condition_variable cv_; }; using RpcCtxMap = std::unordered_map; @@ -127,6 +125,8 @@ class Communicator { void Send(const std::string& var_name, const framework::Scope& scope); private: + // recv all parameter + void RecvAll(); void SendThread(); void RecvThread(); diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc index fecc76955de..c3238f28f63 100644 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -41,6 +41,7 @@ using DDim = framework::DDim; template void ParameterRecv::operator()(const RpcContext &rpc_ctx, const framework::Scope &scope) { + VLOG(3) << "ParameterRecv in"; framework::Scope *local_scope = scope.NewTmpScope(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); @@ -90,6 +91,7 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, } delete local_scope; + VLOG(3) << "ParameterRecv out"; } template struct ParameterRecv; diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index 347395b7ccd..67de7b4185b 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -48,12 +48,15 @@ class SendOp : public framework::OperatorBase { if (send_varnames.size() > 0) { PADDLE_ENFORCE_EQ(ins.size(), 1, ""); - // auto send_functor = distributed::ParameterSend(); - // auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, - // epmap, - // height_sections); - // send_functor(rpc_ctx, scope, static_cast(sync_send)); + /* + auto send_functor = distributed::ParameterSend(); + auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, + height_sections); + send_functor(rpc_ctx, scope, static_cast(sync_send)); + */ + VLOG(3) << "send " << ins[0]; distributed::Communicator::GetInstance()->Send(ins[0], scope); + VLOG(3) << "send " << ins[0] << " done"; } else { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); -- GitLab