From ad9c8f6d2d23613bb2634170bc4ebf6768160520 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 9 Mar 2020 15:51:33 +0800 Subject: [PATCH] fix communicator when breaking under PyReader mode (#22911) * fix communicator when breaking under PyReader mode, test=develop * revert some vlog level to 0, test=develop --- .../operators/distributed/communicator.cc | 50 +++++++++++++------ .../operators/distributed/communicator.h | 4 ++ 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 5099a767729..094062bb21e 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -213,7 +213,7 @@ void AsyncCommunicator::SendThread() { << after_run_send_graph - before_run_send_graph; Recv(); } - VLOG(0) << "communicator stopped, send thread exit"; + VLOG(1) << "communicator stopped, send thread exit"; } void AsyncCommunicator::RecvThread() { @@ -227,7 +227,7 @@ void AsyncCommunicator::RecvThread() { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } } - VLOG(0) << "communicator stopped, recv thread exit"; + VLOG(1) << "communicator stopped, recv thread exit"; } void AsyncCommunicator::Recv() { @@ -267,7 +267,7 @@ void AsyncCommunicator::RecvAll() { } void AsyncCommunicator::Start() { - VLOG(0) << "Communicator start"; + VLOG(1) << "Communicator start"; if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; } else { @@ -284,7 +284,7 @@ void AsyncCommunicator::Start() { } void AsyncCommunicator::Stop() { - VLOG(0) << "Communicator stop"; + VLOG(1) << "Communicator stop"; running_ = false; if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; @@ -300,7 +300,7 @@ void AsyncCommunicator::Stop() { recv_thread_.reset(nullptr); } } - VLOG(0) << "Communicator stop done"; + VLOG(1) << "Communicator stop done"; } void AsyncCommunicator::Send(const std::vector 
&var_names, @@ -385,11 +385,11 @@ void GeoSgdCommunicator::InitImpl(const paddle::framework::ProgramDesc &program, } void GeoSgdCommunicator::Start() { - VLOG(0) << "Geo Sgd Communicator start"; + VLOG(1) << "Geo Sgd Communicator start"; if (!communicator_) { VLOG(0) << "Geo Sgd Communicator is not inited, do nothing"; } else { - VLOG(0) << "start send thread "; + VLOG(1) << "start send thread "; running_ = true; // start send and recv thread send_thread_.reset( @@ -398,7 +398,7 @@ void GeoSgdCommunicator::Start() { } void GeoSgdCommunicator::Stop() { - VLOG(0) << "Geo Sgd Communicator stop"; + VLOG(1) << "Geo Sgd Communicator stop"; running_ = false; if (!communicator_) { VLOG(0) << "Geo Sgd Communicator is not inited, do nothing"; @@ -409,7 +409,7 @@ void GeoSgdCommunicator::Stop() { send_thread_.reset(nullptr); } } - VLOG(0) << "Geo Sgd Communicator stop done"; + VLOG(1) << "Geo Sgd Communicator stop done"; } void GeoSgdCommunicator::Send(const std::vector &sparse_var_names, @@ -463,7 +463,7 @@ void GeoSgdCommunicator::Send(const std::vector &sparse_var_names, } void GeoSgdCommunicator::SendThread() { - VLOG(0) << "SendThread start!"; + VLOG(1) << "SendThread start!"; auto before_run_training = GetCurrentUS(); while (running_) { @@ -1024,6 +1024,19 @@ HalfAsyncCommunicator::~HalfAsyncCommunicator() { if (consume_thread_) consume_thread_->join(); } +void HalfAsyncCommunicator::Clean() { + for (auto &iter : send_varname_to_queue_) { + auto &var_name = iter.first; + auto &var_queue = iter.second; + + while (var_queue->Size() > 0) { + var_queue->Pop(); + } + + VLOG(3) << "clean var: " << var_name << " done"; + } +} + void HalfAsyncCommunicator::ConsumeThread() { VLOG(3) << "ConsumeThread start!"; while (running_) { @@ -1099,7 +1112,10 @@ void HalfAsyncCommunicator::ConsumeThread() { BarrierRecv(); BarrierWeakUp(); } - VLOG(0) << "communicator stopped, send thread exit"; + + Clean(); + + VLOG(1) << "communicator stopped, send thread exit"; } void 
HalfAsyncCommunicator::Send(const std::vector &var_names, @@ -1146,6 +1162,12 @@ void HalfAsyncCommunicator::Recv() { void HalfAsyncCommunicator::Barrier() { barrier_counter_++; + + if (!running_) { + VLOG(3) << "Communicator is not running, release barrier"; + return; + } + { std::unique_lock lk(barrier_mutex_); barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); }); @@ -1171,7 +1193,7 @@ void HalfAsyncCommunicator::BarrierWeakUp() { } void HalfAsyncCommunicator::Start() { - VLOG(0) << "Communicator start"; + VLOG(1) << "Communicator start"; if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; } else { @@ -1185,7 +1207,7 @@ void HalfAsyncCommunicator::Start() { } void HalfAsyncCommunicator::Stop() { - VLOG(0) << "Communicator stop"; + VLOG(1) << "Communicator stop"; running_ = false; if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; @@ -1196,7 +1218,7 @@ void HalfAsyncCommunicator::Stop() { consume_thread_.reset(nullptr); } } - VLOG(0) << "Communicator stop done"; + VLOG(1) << "Communicator stop done"; } void SyncCommunicator::BarrierSend() { diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 0e241205557..2c504a27e57 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -183,6 +183,8 @@ class Communicator { virtual void Stop() = 0; virtual bool IsRunning() { return running_; } + virtual void Clean() {} + virtual void Send(const std::vector& var_names, const std::vector& var_tables, const framework::Scope& scope) = 0; @@ -309,6 +311,8 @@ class HalfAsyncCommunicator : public Communicator { void Start() override; void Stop() override; + void Clean() override; + void Send(const std::vector& var_names, const std::vector& var_tables, const framework::Scope& scope) override; -- GitLab