diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index b2cc9390fa2267404ac246c6b36800833d0dd679..194731d631cfeae1a14092b7fd4ab85ce9d02f24 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -157,16 +157,18 @@ void AsyncCommunicator::MainThread() { } while (running_) { - int meet = Meet(); - - VLOG(1) << "async_meet: " << meet; - - SendGlobalStep(meet); - SendByCommunicator(meet); - BarrierSend(); - RecvByCommunicator(); - BarrierRecv(); - BarrierWeakUp(); + int batches = BatchesCounter(); + + if (batches > 0) { + SendGlobalStep(batches); + SendByCommunicator(batches); + BarrierSend(); + RecvByCommunicator(); + BarrierRecv(); + BarrierWeakUp(); + } else { + VLOG(1) << "get nothing from sending queue, will skip send/recv"; + } } VLOG(1) << "communicator stopped, send thread exit"; } @@ -197,7 +199,7 @@ void AsyncCommunicator::RecvNoBarrier() { } } -int AsyncCommunicator::Meet() { +int AsyncCommunicator::BatchesCounter() { auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER); size_t merged_var_num = 0; @@ -316,7 +318,7 @@ void HalfAsyncCommunicator::Clean() { } } -int HalfAsyncCommunicator::Meet() { +int HalfAsyncCommunicator::BatchesCounter() { while (running_) { if (barrier_counter_.load() >= barrier_trigger_.load() && barrier_trigger_.load() != 0) { diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 2f6da150d1e1375c332f7e55ea5b16c07f067a40..98a2aba2ec2c25503636753e583656da4eabe869 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -293,7 +293,7 @@ class AsyncCommunicator : public Communicator { virtual void RecvNoBarrier(); - virtual int Meet(); + virtual int BatchesCounter(); virtual void BarrierSend() {} @@ -350,7 +350,7 @@ class HalfAsyncCommunicator : public AsyncCommunicator { void BarrierTriggerReset(int initial_val) override; - int Meet(); + int BatchesCounter(); void BarrierWeakUp();