未验证 提交 99cf2e61 编写于 作者: T tangwei12 提交者: GitHub

fix communicator when break under pyreder mode (#22911) (#22919)

* fix communicator when breaking under PyReader mode, test=develop
* revert some vlog level to 0, test=develop
上级 f3ce7dda
......@@ -213,7 +213,7 @@ void AsyncCommunicator::SendThread() {
<< after_run_send_graph - before_run_send_graph;
Recv();
}
VLOG(0) << "communicator stopped, send thread exit";
VLOG(1) << "communicator stopped, send thread exit";
}
void AsyncCommunicator::RecvThread() {
......@@ -227,7 +227,7 @@ void AsyncCommunicator::RecvThread() {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
VLOG(0) << "communicator stopped, recv thread exit";
VLOG(1) << "communicator stopped, recv thread exit";
}
void AsyncCommunicator::Recv() {
......@@ -267,7 +267,7 @@ void AsyncCommunicator::RecvAll() {
}
void AsyncCommunicator::Start() {
VLOG(0) << "Communicator start";
VLOG(1) << "Communicator start";
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
......@@ -284,7 +284,7 @@ void AsyncCommunicator::Start() {
}
void AsyncCommunicator::Stop() {
VLOG(0) << "Communicator stop";
VLOG(1) << "Communicator stop";
running_ = false;
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
......@@ -300,7 +300,7 @@ void AsyncCommunicator::Stop() {
recv_thread_.reset(nullptr);
}
}
VLOG(0) << "Communicator stop done";
VLOG(1) << "Communicator stop done";
}
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
......@@ -385,11 +385,11 @@ void GeoSgdCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
}
void GeoSgdCommunicator::Start() {
VLOG(0) << "Geo Sgd Communicator start";
VLOG(1) << "Geo Sgd Communicator start";
if (!communicator_) {
VLOG(0) << "Geo Sgd Communicator is not inited, do nothing";
} else {
VLOG(0) << "start send thread ";
VLOG(1) << "start send thread ";
running_ = true;
// start send and recv thread
send_thread_.reset(
......@@ -398,7 +398,7 @@ void GeoSgdCommunicator::Start() {
}
void GeoSgdCommunicator::Stop() {
VLOG(0) << "Geo Sgd Communicator stop";
VLOG(1) << "Geo Sgd Communicator stop";
running_ = false;
if (!communicator_) {
VLOG(0) << "Geo Sgd Communicator is not inited, do nothing";
......@@ -409,7 +409,7 @@ void GeoSgdCommunicator::Stop() {
send_thread_.reset(nullptr);
}
}
VLOG(0) << "Geo Sgd Communicator stop done";
VLOG(1) << "Geo Sgd Communicator stop done";
}
void GeoSgdCommunicator::Send(const std::vector<std::string> &sparse_var_names,
......@@ -463,7 +463,7 @@ void GeoSgdCommunicator::Send(const std::vector<std::string> &sparse_var_names,
}
void GeoSgdCommunicator::SendThread() {
VLOG(0) << "SendThread start!";
VLOG(1) << "SendThread start!";
auto before_run_training = GetCurrentUS();
while (running_) {
......@@ -1024,6 +1024,19 @@ HalfAsyncCommunicator::~HalfAsyncCommunicator() {
if (consume_thread_) consume_thread_->join();
}
void HalfAsyncCommunicator::Clean() {
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
while (var_queue->Size() > 0) {
var_queue->Pop();
}
VLOG(3) << "clean var: " << var_name << " done";
}
}
void HalfAsyncCommunicator::ConsumeThread() {
VLOG(3) << "ConsumeThread start!";
while (running_) {
......@@ -1099,7 +1112,10 @@ void HalfAsyncCommunicator::ConsumeThread() {
BarrierRecv();
BarrierWeakUp();
}
VLOG(0) << "communicator stopped, send thread exit";
Clean();
VLOG(1) << "communicator stopped, send thread exit";
}
void HalfAsyncCommunicator::Send(const std::vector<std::string> &var_names,
......@@ -1146,6 +1162,12 @@ void HalfAsyncCommunicator::Recv() {
void HalfAsyncCommunicator::Barrier() {
barrier_counter_++;
if (!running_) {
VLOG(3) << "Communicator is not running, release barrier";
return;
}
{
std::unique_lock<std::mutex> lk(barrier_mutex_);
barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
......@@ -1171,7 +1193,7 @@ void HalfAsyncCommunicator::BarrierWeakUp() {
}
void HalfAsyncCommunicator::Start() {
VLOG(0) << "Communicator start";
VLOG(1) << "Communicator start";
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
......@@ -1185,7 +1207,7 @@ void HalfAsyncCommunicator::Start() {
}
void HalfAsyncCommunicator::Stop() {
VLOG(0) << "Communicator stop";
VLOG(1) << "Communicator stop";
running_ = false;
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
......@@ -1196,7 +1218,7 @@ void HalfAsyncCommunicator::Stop() {
consume_thread_.reset(nullptr);
}
}
VLOG(0) << "Communicator stop done";
VLOG(1) << "Communicator stop done";
}
void SyncCommunicator::BarrierSend() {
......
......@@ -183,6 +183,8 @@ class Communicator {
virtual void Stop() = 0;
virtual bool IsRunning() { return running_; }
virtual void Clean() {}
virtual void Send(const std::vector<std::string>& var_names,
const std::vector<std::string>& var_tables,
const framework::Scope& scope) = 0;
......@@ -309,6 +311,8 @@ class HalfAsyncCommunicator : public Communicator {
void Start() override;
void Stop() override;
void Clean() override;
void Send(const std::vector<std::string>& var_names,
const std::vector<std::string>& var_tables,
const framework::Scope& scope) override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册