未验证 提交 99cf2e61 编写于 作者: T tangwei12 提交者: GitHub

fix communicator when break under pyreder mode (#22911) (#22919)

* fix communicator when breaking under PyReader mode, test=develop
* revert some vlog level to 0, test=develop
上级 f3ce7dda
...@@ -213,7 +213,7 @@ void AsyncCommunicator::SendThread() { ...@@ -213,7 +213,7 @@ void AsyncCommunicator::SendThread() {
<< after_run_send_graph - before_run_send_graph; << after_run_send_graph - before_run_send_graph;
Recv(); Recv();
} }
VLOG(0) << "communicator stopped, send thread exit"; VLOG(1) << "communicator stopped, send thread exit";
} }
void AsyncCommunicator::RecvThread() { void AsyncCommunicator::RecvThread() {
...@@ -227,7 +227,7 @@ void AsyncCommunicator::RecvThread() { ...@@ -227,7 +227,7 @@ void AsyncCommunicator::RecvThread() {
std::this_thread::sleep_for(std::chrono::milliseconds(10)); std::this_thread::sleep_for(std::chrono::milliseconds(10));
} }
} }
VLOG(0) << "communicator stopped, recv thread exit"; VLOG(1) << "communicator stopped, recv thread exit";
} }
void AsyncCommunicator::Recv() { void AsyncCommunicator::Recv() {
...@@ -267,7 +267,7 @@ void AsyncCommunicator::RecvAll() { ...@@ -267,7 +267,7 @@ void AsyncCommunicator::RecvAll() {
} }
void AsyncCommunicator::Start() { void AsyncCommunicator::Start() {
VLOG(0) << "Communicator start"; VLOG(1) << "Communicator start";
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing"; VLOG(0) << "Communicator is not inited, do nothing";
} else { } else {
...@@ -284,7 +284,7 @@ void AsyncCommunicator::Start() { ...@@ -284,7 +284,7 @@ void AsyncCommunicator::Start() {
} }
void AsyncCommunicator::Stop() { void AsyncCommunicator::Stop() {
VLOG(0) << "Communicator stop"; VLOG(1) << "Communicator stop";
running_ = false; running_ = false;
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing"; VLOG(0) << "Communicator is not inited, do nothing";
...@@ -300,7 +300,7 @@ void AsyncCommunicator::Stop() { ...@@ -300,7 +300,7 @@ void AsyncCommunicator::Stop() {
recv_thread_.reset(nullptr); recv_thread_.reset(nullptr);
} }
} }
VLOG(0) << "Communicator stop done"; VLOG(1) << "Communicator stop done";
} }
void AsyncCommunicator::Send(const std::vector<std::string> &var_names, void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
...@@ -385,11 +385,11 @@ void GeoSgdCommunicator::InitImpl(const paddle::framework::ProgramDesc &program, ...@@ -385,11 +385,11 @@ void GeoSgdCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
} }
void GeoSgdCommunicator::Start() { void GeoSgdCommunicator::Start() {
VLOG(0) << "Geo Sgd Communicator start"; VLOG(1) << "Geo Sgd Communicator start";
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Geo Sgd Communicator is not inited, do nothing"; VLOG(0) << "Geo Sgd Communicator is not inited, do nothing";
} else { } else {
VLOG(0) << "start send thread "; VLOG(1) << "start send thread ";
running_ = true; running_ = true;
// start send and recv thread // start send and recv thread
send_thread_.reset( send_thread_.reset(
...@@ -398,7 +398,7 @@ void GeoSgdCommunicator::Start() { ...@@ -398,7 +398,7 @@ void GeoSgdCommunicator::Start() {
} }
void GeoSgdCommunicator::Stop() { void GeoSgdCommunicator::Stop() {
VLOG(0) << "Geo Sgd Communicator stop"; VLOG(1) << "Geo Sgd Communicator stop";
running_ = false; running_ = false;
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Geo Sgd Communicator is not inited, do nothing"; VLOG(0) << "Geo Sgd Communicator is not inited, do nothing";
...@@ -409,7 +409,7 @@ void GeoSgdCommunicator::Stop() { ...@@ -409,7 +409,7 @@ void GeoSgdCommunicator::Stop() {
send_thread_.reset(nullptr); send_thread_.reset(nullptr);
} }
} }
VLOG(0) << "Geo Sgd Communicator stop done"; VLOG(1) << "Geo Sgd Communicator stop done";
} }
void GeoSgdCommunicator::Send(const std::vector<std::string> &sparse_var_names, void GeoSgdCommunicator::Send(const std::vector<std::string> &sparse_var_names,
...@@ -463,7 +463,7 @@ void GeoSgdCommunicator::Send(const std::vector<std::string> &sparse_var_names, ...@@ -463,7 +463,7 @@ void GeoSgdCommunicator::Send(const std::vector<std::string> &sparse_var_names,
} }
void GeoSgdCommunicator::SendThread() { void GeoSgdCommunicator::SendThread() {
VLOG(0) << "SendThread start!"; VLOG(1) << "SendThread start!";
auto before_run_training = GetCurrentUS(); auto before_run_training = GetCurrentUS();
while (running_) { while (running_) {
...@@ -1024,6 +1024,19 @@ HalfAsyncCommunicator::~HalfAsyncCommunicator() { ...@@ -1024,6 +1024,19 @@ HalfAsyncCommunicator::~HalfAsyncCommunicator() {
if (consume_thread_) consume_thread_->join(); if (consume_thread_) consume_thread_->join();
} }
void HalfAsyncCommunicator::Clean() {
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
while (var_queue->Size() > 0) {
var_queue->Pop();
}
VLOG(3) << "clean var: " << var_name << " done";
}
}
void HalfAsyncCommunicator::ConsumeThread() { void HalfAsyncCommunicator::ConsumeThread() {
VLOG(3) << "ConsumeThread start!"; VLOG(3) << "ConsumeThread start!";
while (running_) { while (running_) {
...@@ -1099,7 +1112,10 @@ void HalfAsyncCommunicator::ConsumeThread() { ...@@ -1099,7 +1112,10 @@ void HalfAsyncCommunicator::ConsumeThread() {
BarrierRecv(); BarrierRecv();
BarrierWeakUp(); BarrierWeakUp();
} }
VLOG(0) << "communicator stopped, send thread exit";
Clean();
VLOG(1) << "communicator stopped, send thread exit";
} }
void HalfAsyncCommunicator::Send(const std::vector<std::string> &var_names, void HalfAsyncCommunicator::Send(const std::vector<std::string> &var_names,
...@@ -1146,6 +1162,12 @@ void HalfAsyncCommunicator::Recv() { ...@@ -1146,6 +1162,12 @@ void HalfAsyncCommunicator::Recv() {
void HalfAsyncCommunicator::Barrier() { void HalfAsyncCommunicator::Barrier() {
barrier_counter_++; barrier_counter_++;
if (!running_) {
VLOG(3) << "Communicator is not running, release barrier";
return;
}
{ {
std::unique_lock<std::mutex> lk(barrier_mutex_); std::unique_lock<std::mutex> lk(barrier_mutex_);
barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); }); barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
...@@ -1171,7 +1193,7 @@ void HalfAsyncCommunicator::BarrierWeakUp() { ...@@ -1171,7 +1193,7 @@ void HalfAsyncCommunicator::BarrierWeakUp() {
} }
void HalfAsyncCommunicator::Start() { void HalfAsyncCommunicator::Start() {
VLOG(0) << "Communicator start"; VLOG(1) << "Communicator start";
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing"; VLOG(0) << "Communicator is not inited, do nothing";
} else { } else {
...@@ -1185,7 +1207,7 @@ void HalfAsyncCommunicator::Start() { ...@@ -1185,7 +1207,7 @@ void HalfAsyncCommunicator::Start() {
} }
void HalfAsyncCommunicator::Stop() { void HalfAsyncCommunicator::Stop() {
VLOG(0) << "Communicator stop"; VLOG(1) << "Communicator stop";
running_ = false; running_ = false;
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing"; VLOG(0) << "Communicator is not inited, do nothing";
...@@ -1196,7 +1218,7 @@ void HalfAsyncCommunicator::Stop() { ...@@ -1196,7 +1218,7 @@ void HalfAsyncCommunicator::Stop() {
consume_thread_.reset(nullptr); consume_thread_.reset(nullptr);
} }
} }
VLOG(0) << "Communicator stop done"; VLOG(1) << "Communicator stop done";
} }
void SyncCommunicator::BarrierSend() { void SyncCommunicator::BarrierSend() {
......
...@@ -183,6 +183,8 @@ class Communicator { ...@@ -183,6 +183,8 @@ class Communicator {
virtual void Stop() = 0; virtual void Stop() = 0;
virtual bool IsRunning() { return running_; } virtual bool IsRunning() { return running_; }
virtual void Clean() {}
virtual void Send(const std::vector<std::string>& var_names, virtual void Send(const std::vector<std::string>& var_names,
const std::vector<std::string>& var_tables, const std::vector<std::string>& var_tables,
const framework::Scope& scope) = 0; const framework::Scope& scope) = 0;
...@@ -309,6 +311,8 @@ class HalfAsyncCommunicator : public Communicator { ...@@ -309,6 +311,8 @@ class HalfAsyncCommunicator : public Communicator {
void Start() override; void Start() override;
void Stop() override; void Stop() override;
void Clean() override;
void Send(const std::vector<std::string>& var_names, void Send(const std::vector<std::string>& var_names,
const std::vector<std::string>& var_tables, const std::vector<std::string>& var_tables,
const framework::Scope& scope) override; const framework::Scope& scope) override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册