diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 5ce04cf1301fc8abb0cadf3c043cf0575dda39af..9a72e1baa34274201c40bd83a7aace549a7fc6ae 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -104,7 +104,7 @@ ParallelExecutor::ParallelExecutor(
   }
 
   if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
-    BCastParamsToDevs(bcast_vars);
+    BCastParamsToDevices(bcast_vars);
   }
 
   // Startup Program has been run. All local scopes has correct parameters.
@@ -140,7 +140,7 @@ ParallelExecutor::ParallelExecutor(
       member_->places_, std::move(member_->executor_)));
 }
 
-void ParallelExecutor::BCastParamsToDevs(
+void ParallelExecutor::BCastParamsToDevices(
     const std::unordered_set<std::string> &vars) const {
   // the initializing bcast, all vars would be bcast from device(0),
   // otherwise
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 6985b6540690c6218bcee51ba0e69f3d34812bfc..ffb9934a2d702b2bf6db7ad75a6bf9867e1e9901 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -66,7 +66,7 @@ class ParallelExecutor {
   void Run(const std::vector<std::string> &fetch_tensors,
            const std::string &fetched_var_name);
 
-  void BCastParamsToDevs(const std::unordered_set<std::string> &vars) const;
+  void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
 
  private:
   ParallelExecutorPrivate *member_;
diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc
index c4219a429a53eb4869426a2674109555fb784b85..3a2527e407bb179c4873fa3ffe2e8f22fb47faf7 100644
--- a/paddle/fluid/operators/checkpoint_notify_op.cc
+++ b/paddle/fluid/operators/checkpoint_notify_op.cc
@@ -48,7 +48,7 @@ class CheckpointNotifyOp : public framework::OperatorBase {
       VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
               << " and dir:" << dir << " to " << epmap[i];
     }
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
   }
 };
diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc
index 35318a805898de645c844a2224f6df8c458d346c..4d60801b6a6ecaabf1165321e0cb19044d27aa34 100644
--- a/paddle/fluid/operators/distributed/grpc_client.cc
+++ b/paddle/fluid/operators/distributed/grpc_client.cc
@@ -281,9 +281,10 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
   req_count_++;
 }
 
-void GRPCClient::Wait() {
+bool GRPCClient::Wait() {
   std::unique_lock<std::mutex> lk(sync_mutex_);
-  sync_cond_.wait(lk, [this] { return req_count_ == 0; });
+  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
+  return ok_;
 }
 
 void GRPCClient::Proceed() {
@@ -297,6 +298,14 @@ void GRPCClient::Proceed() {
     if (c->status_.ok()) {
       VLOG(3) << c->var_h_.String() << " process";
       c->Process();
+    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
+      LOG(ERROR) << c->var_h_.String()
+                 << " meets grpc error:" << c->status_.error_message();
+      {
+        std::lock_guard<std::mutex> lk(sync_mutex_);
+        ok_ = false;
+      }
+      sync_cond_.notify_all();
     } else {
       LOG(FATAL) << c->var_h_.String()
                  << " meets grpc error:" << c->status_.error_message();
diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h
index 5dae20155edcf9edd746a5d9a9bbe0ccd789f431..d03a3e56aedbe4a008ee9ff187111f7635d14b58 100644
--- a/paddle/fluid/operators/distributed/grpc_client.h
+++ b/paddle/fluid/operators/distributed/grpc_client.h
@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
 
 class GRPCClient : public RPCClient {
  public:
-  GRPCClient() {}
+  GRPCClient() : ok_(true) {}
   virtual ~GRPCClient();
 
   bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
@@ -221,7 +221,7 @@ class GRPCClient : public RPCClient {
   void AsyncSendEndPass(const std::string& ep,
                         int64_t time_out = FLAGS_rpc_deadline) override;
 
-  void Wait() override;
+  bool Wait() override;
 
   void SendBeginPass() override;
 
@@ -247,6 +247,7 @@ class GRPCClient : public RPCClient {
   std::mutex sync_mutex_;
   std::condition_variable sync_cond_;
   std::atomic<int64_t> req_count_{0};
+  bool ok_;
 
   // mutex for GetChannel thread safety
   std::mutex chan_mutex_;
diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h
index 6479d3a97bafba37b74a1d1c04852a6e60e01be8..4d87376fbf776e29156b78d826f5012bc53460df 100644
--- a/paddle/fluid/operators/distributed/rpc_client.h
+++ b/paddle/fluid/operators/distributed/rpc_client.h
@@ -72,7 +72,7 @@ class RPCClient {
   virtual void SendBeginPass() = 0;
   virtual void SendEndPass() = 0;
 
-  virtual void Wait() = 0;
+  virtual bool Wait() = 0;
 
   template <typename T>
   static RPCClient* GetInstance() {
diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc
index 02beb80fc8a9f451393dcdd54492c4f88f908497..680fde19eefe57475b7526ebc29d4ff977a16977 100644
--- a/paddle/fluid/operators/fetch_barrier_op.cc
+++ b/paddle/fluid/operators/fetch_barrier_op.cc
@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
     distributed::RPCClient* rpc_client =
         distributed::RPCClient::GetInstance<RPCCLIENT_T>();
 
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
 
     for (auto& ep : eps) {
       VLOG(3) << "fetch barrier, ep: " << ep;
       rpc_client->AsyncSendFetchBarrier(ep);
     }
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
   }
 };
diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc
index 8734282fe496b8e90af19abd5549566d62316fc3..4b804740a06f9e29704f2b3f58a90191e3559347 100644
--- a/paddle/fluid/operators/prefetch_op.cc
+++ b/paddle/fluid/operators/prefetch_op.cc
@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
         VLOG(3) << "don't send no-initialied variable: " << ins[i];
       }
     }
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
   }
 };
diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc
index 9854a31f5b10f5ecd940c0d41c2c3e468fc17bad..1ba684014904e61a86bebacd7d29d7e10d313092 100644
--- a/paddle/fluid/operators/recv_op.cc
+++ b/paddle/fluid/operators/recv_op.cc
@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
       rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
     }
     if (sync_mode) {
-      rpc_client->Wait();
+      PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
     }
   }
 };
diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc
index 6b4572dcccc21e783f1df0b9bcde11d532ff4ba8..d7f8e994afd7e656bd5a9dd7c5ab45f0d52fe88b 100644
--- a/paddle/fluid/operators/send_barrier_op.cc
+++ b/paddle/fluid/operators/send_barrier_op.cc
@@ -50,13 +50,13 @@ class SendBarrierOp : public framework::OperatorBase {
     VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
 
     // need to wait before sending send_barrier message
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
 
     if (sync_mode) {
       for (auto& ep : eps) {
         VLOG(3) << "send barrier, ep: " << ep;
         rpc_client->AsyncSendBatchBarrier(ep);
       }
-      rpc_client->Wait();
+      PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
     }
   }
 };
diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc
index 0cac329aafa8c4c67cae48ba62a48575f5edba92..829f310d4233c01a7fbb9ccf7427f6e47ce8d384 100644
--- a/paddle/fluid/operators/send_op.cc
+++ b/paddle/fluid/operators/send_op.cc
@@ -59,7 +59,7 @@ class SendOp : public framework::OperatorBase {
       }
     }
     if (sync_send) {
-      rpc_client->Wait();
+      PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
    }
  }
};
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index d8dc421bed711cfc1a149592c24b11c4ef115ec9..216c4666c0a311f93f29692b2ca1d17bf9dafab8 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -68,7 +68,7 @@ bool IsCompiledWithCUDA() {
 }
 
 bool IsCompiledWithDIST() {
-#ifdef PADDLE_WITH_DIST
+#ifdef PADDLE_WITH_DISTRIBUTE
   return true;
 #else
   return false;
@@ -669,7 +669,7 @@ All parameter, weight, gradient are variables in Paddle.
                     const std::string &, Scope *, std::vector<Scope *> &,
                     const ExecutionStrategy &, const BuildStrategy &, size_t,
                     size_t>())
-      .def("bcast_params", &ParallelExecutor::BCastParamsToDevs)
+      .def("bcast_params", &ParallelExecutor::BCastParamsToDevices)
       // NOTE: even we return a vec<Scope*>* to Python use reference policy.
      // We still cannot get local_scope from this vector, since the element
      // of vec<Scope*> will be freed by Python GC. We can only return Scope*
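
Note: the standalone sketch below only illustrates the error-propagation pattern this patch introduces (a completion thread records a deadline-style RPC failure in an ok_ flag, Wait() returns bool instead of void, and callers promote a false return into a hard error). ToyRpcClient, ENFORCE, AsyncCall and OnReply are hypothetical stand-ins for illustration, not Paddle's actual GRPCClient / PADDLE_ENFORCE API.

// Hypothetical sketch of the "Wait() returns bool" pattern from the patch above.
// ToyRpcClient is NOT Paddle's GRPCClient; all names here are illustrative only.
#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <cstdlib>
#include <mutex>
#include <thread>

// Stand-in for PADDLE_ENFORCE: abort with a message when the condition fails.
#define ENFORCE(cond, msg)                 \
  do {                                     \
    if (!(cond)) {                         \
      std::fprintf(stderr, "%s\n", (msg)); \
      std::abort();                        \
    }                                      \
  } while (0)

class ToyRpcClient {
 public:
  // Issue a request; the reply is handled later on a completion thread.
  void AsyncCall() { req_count_++; }

  // Completion-thread callback: on a deadline error, flip ok_ under the lock
  // (mirroring the new branch in GRPCClient::Proceed) and wake any waiter.
  void OnReply(bool deadline_exceeded) {
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      if (deadline_exceeded) ok_ = false;
      req_count_--;
    }
    sync_cond_.notify_all();
  }

  // Block until all requests finish or one fails; report success to the caller.
  bool Wait() {
    std::unique_lock<std::mutex> lk(sync_mutex_);
    sync_cond_.wait(lk, [this] { return req_count_ == 0 || !ok_; });
    return ok_;
  }

 private:
  std::mutex sync_mutex_;
  std::condition_variable sync_cond_;
  std::atomic<int> req_count_{0};
  bool ok_ = true;
};

int main() {
  ToyRpcClient client;
  client.AsyncCall();
  // Simulate a completion thread delivering a successful reply; passing true
  // instead would clear ok_ and make the ENFORCE below fire.
  std::thread completion([&client] { client.OnReply(/*deadline_exceeded=*/false); });
  // Caller-side usage, mirroring the operator changes in the patch.
  ENFORCE(client.Wait(), "internal error in RPCClient");
  completion.join();
  return 0;
}

Waking the waiter when ok_ flips means Wait() can return false as soon as the first deadline error is observed instead of blocking until every outstanding request drains, and the operators turn that into a PADDLE_ENFORCE failure rather than a LOG(FATAL) on the completion thread.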