From 4a91a14549812961cecdf612a10371bacfe543c5 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 12 Jul 2018 23:56:48 +0800 Subject: [PATCH] enforce rpc client timeout --- paddle/fluid/operators/checkpoint_notify_op.cc | 2 +- paddle/fluid/operators/distributed/grpc_client.cc | 13 +++++++++++-- paddle/fluid/operators/distributed/grpc_client.h | 5 +++-- paddle/fluid/operators/distributed/rpc_client.h | 2 +- paddle/fluid/operators/fetch_barrier_op.cc | 4 ++-- paddle/fluid/operators/prefetch_op.cc | 2 +- paddle/fluid/operators/recv_op.cc | 2 +- paddle/fluid/operators/send_barrier_op.cc | 4 ++-- paddle/fluid/operators/send_op.cc | 2 +- paddle/fluid/pybind/pybind.cc | 2 +- 10 files changed, 24 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index c4219a429a5..3a2527e407b 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -48,7 +48,7 @@ class CheckpointNotifyOp : public framework::OperatorBase { VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name << " and dir:" << dir << " to " << epmap[i]; } - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } }; diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 35318a80589..4d60801b6a6 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -281,9 +281,10 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, req_count_++; } -void GRPCClient::Wait() { +bool GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); - sync_cond_.wait(lk, [this] { return req_count_ == 0; }); + sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); }); + return ok_; } void GRPCClient::Proceed() { @@ -297,6 +298,14 @@ void GRPCClient::Proceed() { if (c->status_.ok()) { VLOG(3) << c->var_h_.String() << " process"; c->Process(); + } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { + LOG(ERROR) << c->var_h_.String() + << " meets grpc error:" << c->status_.error_message(); + { + std::lock_guard lk(sync_mutex_); + ok_ = false; + } + sync_cond_.notify_all(); } else { LOG(FATAL) << c->var_h_.String() << " meets grpc error:" << c->status_.error_message(); diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index 5dae20155ed..d03a3e56aed 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor { class GRPCClient : public RPCClient { public: - GRPCClient() {} + GRPCClient() : ok_(true) {} virtual ~GRPCClient(); bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, @@ -221,7 +221,7 @@ class GRPCClient : public RPCClient { void AsyncSendEndPass(const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; - void Wait() override; + bool Wait() override; void SendBeginPass() override; @@ -247,6 +247,7 @@ class GRPCClient : public RPCClient { std::mutex sync_mutex_; std::condition_variable sync_cond_; std::atomic req_count_{0}; + bool ok_; // mutex for GetChannel thread safety std::mutex chan_mutex_; diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 6479d3a97ba..4d87376fbf7 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -72,7 +72,7 @@ class RPCClient { virtual void SendBeginPass() = 0; virtual void SendEndPass() = 0; - virtual void Wait() = 0; + virtual bool Wait() = 0; template static RPCClient* GetInstance() { diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index 02beb80fc8a..680fde19eef 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (auto& ep : eps) { VLOG(3) << "fetch barrier, ep: " << ep; rpc_client->AsyncSendFetchBarrier(ep); } - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } }; diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index 8734282fe49..4b804740a06 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } } - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } }; diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 9854a31f5b1..1ba68401490 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase { rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); } if (sync_mode) { - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } } }; diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 6b4572dcccc..d7f8e994afd 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -50,13 +50,13 @@ class SendBarrierOp : public framework::OperatorBase { VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; // need to wait before sending send_barrier message - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); if (sync_mode) { for (auto& ep : eps) { VLOG(3) << "send barrier, ep: " << ep; rpc_client->AsyncSendBatchBarrier(ep); } - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } } }; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 0cac329aafa..829f310d423 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -59,7 +59,7 @@ class SendOp : public framework::OperatorBase { } } if (sync_send) { - rpc_client->Wait(); + PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); } } }; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 69a067b99b2..227fd442c28 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -67,7 +67,7 @@ bool IsCompiledWithCUDA() { } bool IsCompiledWithDIST() { -#ifdef PADDLE_WITH_DIST +#ifdef PADDLE_WITH_DISTRIBUTE return true; #else return false; -- GitLab