未验证 提交 0042ba93 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #12127 from Yancey1989/enforce_rpc_timeout

Enforce rpc timeout
......@@ -104,7 +104,7 @@ ParallelExecutor::ParallelExecutor(
}
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
BCastParamsToDevs(bcast_vars);
BCastParamsToDevices(bcast_vars);
}
// Startup Program has been run. All local scopes has correct parameters.
......@@ -140,7 +140,7 @@ ParallelExecutor::ParallelExecutor(
member_->places_, std::move(member_->executor_)));
}
void ParallelExecutor::BCastParamsToDevs(
void ParallelExecutor::BCastParamsToDevices(
const std::unordered_set<std::string> &vars) const {
// the initializing bcast, all vars would be bcast from device(0),
// otherwise
......
......@@ -66,7 +66,7 @@ class ParallelExecutor {
void Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name);
void BCastParamsToDevs(const std::unordered_set<std::string> &vars) const;
void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
private:
ParallelExecutorPrivate *member_;
......
......@@ -48,7 +48,7 @@ class CheckpointNotifyOp : public framework::OperatorBase {
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
<< " and dir:" << dir << " to " << epmap[i];
}
rpc_client->Wait();
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
};
......
......@@ -281,9 +281,10 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
req_count_++;
}
void GRPCClient::Wait() {
bool GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
return ok_;
}
void GRPCClient::Proceed() {
......@@ -297,6 +298,14 @@ void GRPCClient::Proceed() {
if (c->status_.ok()) {
VLOG(3) << c->var_h_.String() << " process";
c->Process();
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
LOG(ERROR) << c->var_h_.String()
<< " meets grpc error:" << c->status_.error_message();
{
std::lock_guard<std::mutex> lk(sync_mutex_);
ok_ = false;
}
sync_cond_.notify_all();
} else {
LOG(FATAL) << c->var_h_.String()
<< " meets grpc error:" << c->status_.error_message();
......
......@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
class GRPCClient : public RPCClient {
public:
GRPCClient() {}
GRPCClient() : ok_(true) {}
virtual ~GRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
......@@ -221,7 +221,7 @@ class GRPCClient : public RPCClient {
void AsyncSendEndPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void Wait() override;
bool Wait() override;
void SendBeginPass() override;
......@@ -247,6 +247,7 @@ class GRPCClient : public RPCClient {
std::mutex sync_mutex_;
std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0};
bool ok_;
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
......
......@@ -72,7 +72,7 @@ class RPCClient {
virtual void SendBeginPass() = 0;
virtual void SendEndPass() = 0;
virtual void Wait() = 0;
virtual bool Wait() = 0;
template <typename T>
static RPCClient* GetInstance() {
......
......@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
rpc_client->Wait();
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
rpc_client->Wait();
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
};
......
......@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
rpc_client->Wait();
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
};
......
......@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
}
if (sync_mode) {
rpc_client->Wait();
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
}
};
......
......@@ -50,13 +50,13 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
// need to wait before sending send_barrier message
rpc_client->Wait();
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
if (sync_mode) {
for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
rpc_client->Wait();
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
}
};
......
......@@ -59,7 +59,7 @@ class SendOp : public framework::OperatorBase {
}
}
if (sync_send) {
rpc_client->Wait();
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
}
};
......
......@@ -68,7 +68,7 @@ bool IsCompiledWithCUDA() {
}
bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DIST
#ifdef PADDLE_WITH_DISTRIBUTE
return true;
#else
return false;
......@@ -669,7 +669,7 @@ All parameter, weight, gradient are variables in Paddle.
const std::string &, Scope *, std::vector<Scope *> &,
const ExecutionStrategy &, const BuildStrategy &, size_t,
size_t>())
.def("bcast_params", &ParallelExecutor::BCastParamsToDevs)
.def("bcast_params", &ParallelExecutor::BCastParamsToDevices)
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册