diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index c4219a429a53eb4869426a2674109555fb784b85..3a2527e407bb179c4873fa3ffe2e8f22fb47faf7 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -48,7 +48,7 @@ class CheckpointNotifyOp : public framework::OperatorBase { VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name << " and dir:" << dir << " to " << epmap[i]; } - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } }; diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 35318a805898de645c844a2224f6df8c458d346c..4d60801b6a6ecaabf1165321e0cb19044d27aa34 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -281,9 +281,10 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, req_count_++; } -void GRPCClient::Wait() { +bool GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); - sync_cond_.wait(lk, [this] { return req_count_ == 0; }); + sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); }); + return ok_; } void GRPCClient::Proceed() { @@ -297,6 +298,14 @@ void GRPCClient::Proceed() { if (c->status_.ok()) { VLOG(3) << c->var_h_.String() << " process"; c->Process(); + } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { + LOG(ERROR) << c->var_h_.String() + << " meets grpc error:" << c->status_.error_message(); + { + std::lock_guard lk(sync_mutex_); + ok_ = false; + } + sync_cond_.notify_all(); } else { LOG(FATAL) << c->var_h_.String() << " meets grpc error:" << c->status_.error_message(); diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index 5dae20155edcf9edd746a5d9a9bbe0ccd789f431..d03a3e56aedbe4a008ee9ff187111f7635d14b58 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor { class GRPCClient : public RPCClient { public: - GRPCClient() {} + GRPCClient() : ok_(true) {} virtual ~GRPCClient(); bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, @@ -221,7 +221,7 @@ class GRPCClient : public RPCClient { void AsyncSendEndPass(const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; - void Wait() override; + bool Wait() override; void SendBeginPass() override; @@ -247,6 +247,7 @@ class GRPCClient : public RPCClient { std::mutex sync_mutex_; std::condition_variable sync_cond_; std::atomic req_count_{0}; + bool ok_; // mutex for GetChannel thread safety std::mutex chan_mutex_; diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 6479d3a97bafba37b74a1d1c04852a6e60e01be8..4d87376fbf776e29156b78d826f5012bc53460df 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -72,7 +72,7 @@ class RPCClient { virtual void SendBeginPass() = 0; virtual void SendEndPass() = 0; - virtual void Wait() = 0; + virtual bool Wait() = 0; template static RPCClient* GetInstance() { diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index 02beb80fc8a9f451393dcdd54492c4f88f908497..680fde19eefe57475b7526ebc29d4ff977a16977 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (auto& ep : eps) { VLOG(3) << "fetch barrier, ep: " << ep; rpc_client->AsyncSendFetchBarrier(ep); } - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } }; diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index 8734282fe496b8e90af19abd5549566d62316fc3..4b804740a06f9e29704f2b3f58a90191e3559347 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } } - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } }; diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 9854a31f5b10f5ecd940c0d41c2c3e468fc17bad..1ba684014904e61a86bebacd7d29d7e10d313092 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase { rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); } if (sync_mode) { - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } } }; diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 6b4572dcccc21e783f1df0b9bcde11d532ff4ba8..d7f8e994afd7e656bd5a9dd7c5ab45f0d52fe88b 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -50,13 +50,13 @@ class SendBarrierOp : public framework::OperatorBase { VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; // need to wait before sending send_barrier message - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); if (sync_mode) { for (auto& ep : eps) { VLOG(3) << "send barrier, ep: " << ep; rpc_client->AsyncSendBatchBarrier(ep); } - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } } }; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 0cac329aafa8c4c67cae48ba62a48575f5edba92..829f310d4233c01a7fbb9ccf7427f6e47ce8d384 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -59,7 +59,7 @@ class SendOp : public framework::OperatorBase { } } if (sync_send) { - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } } }; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 69a067b99b2413b7ec990de43963182a022d29c6..227fd442c28a5da222efb877dae18b5e1922c66a 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -67,7 +67,7 @@ bool IsCompiledWithCUDA() { } bool IsCompiledWithDIST() { -#ifdef PADDLE_WITH_DIST +#ifdef PADDLE_WITH_DISTRIBUTE return true; #else return false;