Unverified commit 0042ba93, authored by Yancey, committed by GitHub

Merge pull request #12127 from Yancey1989/enforce_rpc_timeout

Enforce rpc timeout
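Summary of the change: RPCClient::Wait() now returns a bool that is false when any outstanding gRPC request hit its deadline, and every caller (the send, recv, prefetch, barrier, and checkpoint-notify operators) wraps the call in PADDLE_ENFORCE so a timed-out RPC surfaces as a hard error instead of being silently ignored. Below is a minimal, self-contained C++ sketch of that caller-side pattern; ENFORCE and RpcClient are simplified stand-ins for PADDLE_ENFORCE and distributed::RPCClient, not the real Paddle definitions.

// Minimal sketch of the caller-side pattern introduced by this PR: Wait()
// reports success, and callers turn a failed wait into a hard error.
// ENFORCE and RpcClient are simplified stand-ins, not the Paddle code.
#include <cstdio>
#include <cstdlib>

#define ENFORCE(cond, msg)                 \
  do {                                     \
    if (!(cond)) {                         \
      std::fprintf(stderr, "%s\n", (msg)); \
      std::abort();                        \
    }                                      \
  } while (0)

struct RpcClient {
  bool ok = true;  // cleared when a request hits its gRPC deadline
  // The real Wait() also blocks until all outstanding requests finish.
  bool Wait() { return ok; }
};

int main() {
  RpcClient client;
  client.ok = false;  // simulate an RPC that timed out
  ENFORCE(client.Wait(), "internal error in RPCClient");  // aborts with the message
  return 0;
}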
@@ -104,7 +104,7 @@ ParallelExecutor::ParallelExecutor(
   }
   if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
-    BCastParamsToDevs(bcast_vars);
+    BCastParamsToDevices(bcast_vars);
   }
   // Startup Program has been run. All local scopes has correct parameters.
@@ -140,7 +140,7 @@ ParallelExecutor::ParallelExecutor(
       member_->places_, std::move(member_->executor_)));
 }
 
-void ParallelExecutor::BCastParamsToDevs(
+void ParallelExecutor::BCastParamsToDevices(
     const std::unordered_set<std::string> &vars) const {
   // the initializing bcast, all vars would be bcast from device(0),
   // otherwise
......
@@ -66,7 +66,7 @@ class ParallelExecutor {
   void Run(const std::vector<std::string> &fetch_tensors,
            const std::string &fetched_var_name);
 
-  void BCastParamsToDevs(const std::unordered_set<std::string> &vars) const;
+  void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
 
  private:
   ParallelExecutorPrivate *member_;
......
@@ -48,7 +48,7 @@ class CheckpointNotifyOp : public framework::OperatorBase {
       VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
               << " and dir:" << dir << " to " << epmap[i];
     }
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
   }
 };
......
@@ -281,9 +281,10 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
   req_count_++;
 }
 
-void GRPCClient::Wait() {
+bool GRPCClient::Wait() {
   std::unique_lock<std::mutex> lk(sync_mutex_);
-  sync_cond_.wait(lk, [this] { return req_count_ == 0; });
+  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
+  return ok_;
 }
 
 void GRPCClient::Proceed() {
@@ -297,6 +298,14 @@ void GRPCClient::Proceed() {
     if (c->status_.ok()) {
       VLOG(3) << c->var_h_.String() << " process";
       c->Process();
+    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
+      LOG(ERROR) << c->var_h_.String()
+                 << " meets grpc error:" << c->status_.error_message();
+      {
+        std::lock_guard<std::mutex> lk(sync_mutex_);
+        ok_ = false;
+      }
+      sync_cond_.notify_all();
     } else {
       LOG(FATAL) << c->var_h_.String()
                  << " meets grpc error:" << c->status_.error_message();
......
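On the client side, the key change is in GRPCClient::Wait() and GRPCClient::Proceed(): the completion thread clears an ok_ flag under sync_mutex_ when a request returns DEADLINE_EXCEEDED and notifies the waiter, and Wait() now wakes on either "all requests done" or "a request failed" and returns that flag. The following standalone C++ sketch shows the same condition-variable handshake with simplified names; it is an illustration of the pattern, not the actual GRPCClient code.

// Standalone sketch of the wait/notify handshake (simplified names, not the
// real GRPCClient). The waiter wakes when every request has completed OR
// when one of them timed out, and reports which case occurred.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class Client {
 public:
  void AddRequest() {
    std::lock_guard<std::mutex> lk(mu_);
    ++req_count_;
  }

  // Called by the completion thread for each finished request.
  void OnRequestDone(bool deadline_exceeded) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      if (deadline_exceeded) ok_ = false;  // mirrors ok_ = false in Proceed()
      --req_count_;
    }
    cv_.notify_all();
  }

  // Mirrors bool GRPCClient::Wait(): wake on "all done" or "something failed".
  bool Wait() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return req_count_ == 0 || !ok_; });
    return ok_;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int req_count_ = 0;
  bool ok_ = true;
};

int main() {
  Client c;
  c.AddRequest();
  std::thread completion([&] { c.OnRequestDone(/*deadline_exceeded=*/true); });
  std::cout << (c.Wait() ? "all RPCs finished" : "RPC deadline exceeded") << "\n";
  completion.join();
  return 0;
}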
@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
 
 class GRPCClient : public RPCClient {
  public:
-  GRPCClient() {}
+  GRPCClient() : ok_(true) {}
   virtual ~GRPCClient();
 
   bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
@@ -221,7 +221,7 @@ class GRPCClient : public RPCClient {
   void AsyncSendEndPass(const std::string& ep,
                         int64_t time_out = FLAGS_rpc_deadline) override;
 
-  void Wait() override;
+  bool Wait() override;
 
   void SendBeginPass() override;
@@ -247,6 +247,7 @@ class GRPCClient : public RPCClient {
   std::mutex sync_mutex_;
   std::condition_variable sync_cond_;
   std::atomic<int64_t> req_count_{0};
+  bool ok_;
 
   // mutex for GetChannel thread safety
   std::mutex chan_mutex_;
......
@@ -72,7 +72,7 @@ class RPCClient {
   virtual void SendBeginPass() = 0;
   virtual void SendEndPass() = 0;
 
-  virtual void Wait() = 0;
+  virtual bool Wait() = 0;
 
   template <typename T>
   static RPCClient* GetInstance() {
......
@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
     distributed::RPCClient* rpc_client =
         distributed::RPCClient::GetInstance<RPCCLIENT_T>();
 
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
 
     for (auto& ep : eps) {
       VLOG(3) << "fetch barrier, ep: " << ep;
       rpc_client->AsyncSendFetchBarrier(ep);
     }
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
   }
 };
......
@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
         VLOG(3) << "don't send no-initialied variable: " << ins[i];
       }
     }
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
   }
 };
......
@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
       rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
     }
     if (sync_mode) {
-      rpc_client->Wait();
+      PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
     }
   }
 };
......
@@ -50,13 +50,13 @@ class SendBarrierOp : public framework::OperatorBase {
     VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
 
     // need to wait before sending send_barrier message
-    rpc_client->Wait();
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
     if (sync_mode) {
       for (auto& ep : eps) {
         VLOG(3) << "send barrier, ep: " << ep;
         rpc_client->AsyncSendBatchBarrier(ep);
       }
-      rpc_client->Wait();
+      PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
     }
   }
 };
......
@@ -59,7 +59,7 @@ class SendOp : public framework::OperatorBase {
       }
     }
     if (sync_send) {
-      rpc_client->Wait();
+      PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
     }
   }
 };
......
@@ -68,7 +68,7 @@ bool IsCompiledWithCUDA() {
 }
 
 bool IsCompiledWithDIST() {
-#ifdef PADDLE_WITH_DIST
+#ifdef PADDLE_WITH_DISTRIBUTE
   return true;
 #else
   return false;
@@ -669,7 +669,7 @@ All parameter, weight, gradient are variables in Paddle.
                     const std::string &, Scope *, std::vector<Scope *> &,
                     const ExecutionStrategy &, const BuildStrategy &, size_t,
                     size_t>())
-      .def("bcast_params", &ParallelExecutor::BCastParamsToDevs)
+      .def("bcast_params", &ParallelExecutor::BCastParamsToDevices)
       // NOTE: even we return a vec<Scope*>* to Python use reference policy.
       // We still cannot get local_scope from this vector, since the element
       // of vec<Scope*> will be freed by Python GC. We can only return Scope*
......