提交 4a91a145 作者: Yancey1989

enforce rpc client timeout

上级 486121d5
...@@ -48,7 +48,7 @@ class CheckpointNotifyOp : public framework::OperatorBase { ...@@ -48,7 +48,7 @@ class CheckpointNotifyOp : public framework::OperatorBase {
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
<< " and dir:" << dir << " to " << epmap[i]; << " and dir:" << dir << " to " << epmap[i];
} }
rpc_client->Wait(); PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
}; };
......
...@@ -281,9 +281,10 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -281,9 +281,10 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
req_count_++; req_count_++;
} }
void GRPCClient::Wait() { bool GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_); std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; }); sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
return ok_;
} }
void GRPCClient::Proceed() { void GRPCClient::Proceed() {
...@@ -297,6 +298,14 @@ void GRPCClient::Proceed() { ...@@ -297,6 +298,14 @@ void GRPCClient::Proceed() {
if (c->status_.ok()) { if (c->status_.ok()) {
VLOG(3) << c->var_h_.String() << " process"; VLOG(3) << c->var_h_.String() << " process";
c->Process(); c->Process();
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
LOG(ERROR) << c->var_h_.String()
<< " meets grpc error:" << c->status_.error_message();
{
std::lock_guard<std::mutex> lk(sync_mutex_);
ok_ = false;
}
sync_cond_.notify_all();
} else { } else {
LOG(FATAL) << c->var_h_.String() LOG(FATAL) << c->var_h_.String()
<< " meets grpc error:" << c->status_.error_message(); << " meets grpc error:" << c->status_.error_message();
......
...@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor { ...@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
class GRPCClient : public RPCClient { class GRPCClient : public RPCClient {
public: public:
GRPCClient() {} GRPCClient() : ok_(true) {}
virtual ~GRPCClient(); virtual ~GRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
...@@ -221,7 +221,7 @@ class GRPCClient : public RPCClient { ...@@ -221,7 +221,7 @@ class GRPCClient : public RPCClient {
void AsyncSendEndPass(const std::string& ep, void AsyncSendEndPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
void Wait() override; bool Wait() override;
void SendBeginPass() override; void SendBeginPass() override;
...@@ -247,6 +247,7 @@ class GRPCClient : public RPCClient { ...@@ -247,6 +247,7 @@ class GRPCClient : public RPCClient {
std::mutex sync_mutex_; std::mutex sync_mutex_;
std::condition_variable sync_cond_; std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0}; std::atomic<int64_t> req_count_{0};
bool ok_;
// mutex for GetChannel thread safety // mutex for GetChannel thread safety
std::mutex chan_mutex_; std::mutex chan_mutex_;
......
...@@ -72,7 +72,7 @@ class RPCClient { ...@@ -72,7 +72,7 @@ class RPCClient {
virtual void SendBeginPass() = 0; virtual void SendBeginPass() = 0;
virtual void SendEndPass() = 0; virtual void SendEndPass() = 0;
virtual void Wait() = 0; virtual bool Wait() = 0;
template <typename T> template <typename T>
static RPCClient* GetInstance() { static RPCClient* GetInstance() {
......
...@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
rpc_client->Wait(); PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep; VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep); rpc_client->AsyncSendFetchBarrier(ep);
} }
rpc_client->Wait(); PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
}; };
......
...@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
rpc_client->Wait(); PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
}; };
......
...@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
} }
if (sync_mode) { if (sync_mode) {
rpc_client->Wait(); PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
} }
}; };
......
...@@ -50,13 +50,13 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -50,13 +50,13 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
// need to wait before sending send_barrier message // need to wait before sending send_barrier message
rpc_client->Wait(); PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
if (sync_mode) { if (sync_mode) {
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep; VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
rpc_client->Wait(); PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
} }
}; };
......
...@@ -59,7 +59,7 @@ class SendOp : public framework::OperatorBase { ...@@ -59,7 +59,7 @@ class SendOp : public framework::OperatorBase {
} }
} }
if (sync_send) { if (sync_send) {
rpc_client->Wait(); PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
} }
}; };
......
...@@ -67,7 +67,7 @@ bool IsCompiledWithCUDA() { ...@@ -67,7 +67,7 @@ bool IsCompiledWithCUDA() {
} }
bool IsCompiledWithDIST() { bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DIST #ifdef PADDLE_WITH_DISTRIBUTE
return true; return true;
#else #else
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册