未验证 提交 8cee9f61 编写于 作者: G gongweibao 提交者: GitHub

Fix rpcclient's wait action in aync env. (#13307)

上级 7f692b8e
...@@ -20,6 +20,7 @@ if(WITH_GRPC) ...@@ -20,6 +20,7 @@ if(WITH_GRPC)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL) DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc)
return() return()
endif() endif()
......
...@@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() { ...@@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() {
} }
channels_.clear(); channels_.clear();
} }
client_thread_->join(); client_thread_->join();
} }
bool GRPCClient::AsyncSendVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, int64_t time_out) { const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
this] {
auto* var = p_scope->FindVar(var_name_val); auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req); SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);
// varhandle VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Send";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = nullptr; s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
...@@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep, ...@@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep,
}); });
req_count_++; req_count_++;
return true; return h;
} }
void ProcGetResponse(const VarHandle& var_h, void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) { const ::grpc::ByteBuffer& ret_msg) {
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar); DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
} }
template <typename T> template <typename T>
...@@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { ...@@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
result->Swap(&tmp); result->Swap(&tmp);
} }
bool GRPCClient::AsyncGetVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, int64_t time_out) { const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
this] {
// prepare input // prepare input
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(var_name_val); req.set_varname(var_name_val);
::grpc::ByteBuffer buf; ::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf); RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
// var handle VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Get";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
...@@ -160,10 +145,10 @@ bool GRPCClient::AsyncGetVar(const std::string& ep, ...@@ -160,10 +145,10 @@ bool GRPCClient::AsyncGetVar(const std::string& ep,
req_count_++; req_count_++;
return true; return h;
} }
bool GRPCClient::AsyncPrefetchVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
...@@ -175,27 +160,21 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -175,27 +160,21 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
const std::string out_var_name_val = out_var_name; const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] { time_out, s, this] {
auto* var = p_scope->FindVar(in_var_name_val); auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
// var handle VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = out_var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Prefetch";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
...@@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
}); });
req_count_++; req_count_++;
return true; return h;
} }
void GRPCClient::AsyncSendBatchBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE); req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE); req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE); req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
void GRPCClient::AsyncCheckpointNotify(const std::string& ep, VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir, const std::string& dir,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE); req.set_varname(CHECKPOINT_SAVE_MESSAGE);
...@@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
bool GRPCClient::Wait() { bool GRPCClient::Wait() {
...@@ -276,25 +268,28 @@ void GRPCClient::Proceed() { ...@@ -276,25 +268,28 @@ void GRPCClient::Proceed() {
void* tag = nullptr; void* tag = nullptr;
bool ok = false; bool ok = false;
VLOG(3) << "GRPCClient Proceed begin";
while (!stopped_ && cq_.Next(&tag, &ok)) { while (!stopped_ && cq_.Next(&tag, &ok)) {
BaseProcessor* c = static_cast<BaseProcessor*>(tag); BaseProcessor* c = static_cast<BaseProcessor*>(tag);
GPR_ASSERT(ok); GPR_ASSERT(ok);
PADDLE_ENFORCE(c); PADDLE_ENFORCE(c);
if (c->status_.ok()) { if (c->status_.ok()) {
VLOG(3) << c->var_h_.String() << " process"; VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process(); c->Process();
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
LOG(ERROR) << c->var_h_.String() LOG(ERROR) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message(); << " meets grpc error:" << c->status_.error_message();
{ {
std::lock_guard<std::mutex> lk(sync_mutex_); std::lock_guard<std::mutex> lk(sync_mutex_);
ok_ = false; ok_ = false;
} }
sync_cond_.notify_all(); c->Finish(false);
} else { } else {
LOG(FATAL) << c->var_h_.String() LOG(FATAL) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message(); << " meets grpc error:" << c->status_.error_message();
c->Finish(false);
} }
delete c; delete c;
{ {
std::lock_guard<std::mutex> lk(sync_mutex_); std::lock_guard<std::mutex> lk(sync_mutex_);
...@@ -302,6 +297,7 @@ void GRPCClient::Proceed() { ...@@ -302,6 +297,7 @@ void GRPCClient::Proceed() {
} }
sync_cond_.notify_all(); sync_cond_.notify_all();
} }
VLOG(3) << "GRPCClient Proceed end";
} }
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) { std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
......
...@@ -53,15 +53,14 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); ...@@ -53,15 +53,14 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor { class BaseProcessor {
public: public:
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { BaseProcessor() { context_ = nullptr; }
context_ = nullptr;
}
virtual ~BaseProcessor() {} virtual ~BaseProcessor() {}
virtual void Prepare(const VarHandle& var_info, int64_t time_out) { virtual void Prepare(VarHandlePtr h, int64_t time_out) {
var_h_ = h;
context_.reset(new grpc::ClientContext()); context_.reset(new grpc::ClientContext());
var_h_ = var_info;
context_->set_wait_for_ready(true); context_->set_wait_for_ready(true);
if (time_out) { if (time_out) {
std::chrono::system_clock::time_point deadline = std::chrono::system_clock::time_point deadline =
...@@ -71,21 +70,21 @@ class BaseProcessor { ...@@ -71,21 +70,21 @@ class BaseProcessor {
} }
} }
virtual void Prepare(int64_t time_out) { void Process() {
context_.reset(new grpc::ClientContext()); ProcessImpl();
context_->set_wait_for_ready(true); var_h_->Finish(true);
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
context_->set_deadline(deadline);
} }
virtual void Process() = 0; VarHandlePtr GetVarHandlePtr() { return var_h_; }
bool Wait() { return var_h_->Wait(); }
void Finish(bool ok) { return var_h_->Finish(ok); }
virtual void ProcessImpl() = 0;
std::unique_ptr<grpc::ClientContext> context_; std::unique_ptr<grpc::ClientContext> context_;
grpc::Status status_; grpc::Status status_;
VarHandle var_h_;
protected:
VarHandlePtr var_h_;
}; };
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
...@@ -94,13 +93,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> ...@@ -94,13 +93,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
class SendProcessor : public BaseProcessor { class SendProcessor : public BaseProcessor {
public: public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch), stub_g_(ch) {} : BaseProcessor(), stub_g_(ch) {}
virtual ~SendProcessor() {} virtual ~SendProcessor() {}
virtual void Process() { void ProcessImpl() override {
if (response_call_back_) { if (response_call_back_) {
response_call_back_(var_h_, reply_); response_call_back_(*var_h_.get(), reply_);
} }
} }
...@@ -115,13 +114,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> ...@@ -115,13 +114,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
class GetProcessor : public BaseProcessor { class GetProcessor : public BaseProcessor {
public: public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch), stub_g_(ch) {} : BaseProcessor(), stub_g_(ch) {}
virtual ~GetProcessor() {} virtual ~GetProcessor() {}
virtual void Process() { void ProcessImpl() override {
if (response_call_back_) { if (response_call_back_) {
response_call_back_(var_h_, reply_); response_call_back_(*var_h_.get(), reply_);
} }
} }
...@@ -133,13 +132,13 @@ class GetProcessor : public BaseProcessor { ...@@ -133,13 +132,13 @@ class GetProcessor : public BaseProcessor {
class BatchBarrierProcessor : public BaseProcessor { class BatchBarrierProcessor : public BaseProcessor {
public: public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch) explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) { : BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
} }
virtual ~BatchBarrierProcessor() {} virtual ~BatchBarrierProcessor() {}
virtual void Process() {} void ProcessImpl() override {}
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
...@@ -147,13 +146,13 @@ class BatchBarrierProcessor : public BaseProcessor { ...@@ -147,13 +146,13 @@ class BatchBarrierProcessor : public BaseProcessor {
class FetchBarrierProcessor : public BaseProcessor { class FetchBarrierProcessor : public BaseProcessor {
public: public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch) explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) { : BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
} }
virtual ~FetchBarrierProcessor() {} virtual ~FetchBarrierProcessor() {}
virtual void Process() {} void ProcessImpl() override {}
sendrecv::VariableMessage reply_; sendrecv::VariableMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
...@@ -161,13 +160,13 @@ class FetchBarrierProcessor : public BaseProcessor { ...@@ -161,13 +160,13 @@ class FetchBarrierProcessor : public BaseProcessor {
class CheckpointNotifyProcessor : public BaseProcessor { class CheckpointNotifyProcessor : public BaseProcessor {
public: public:
explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch) explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) { : BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
} }
virtual ~CheckpointNotifyProcessor() {} virtual ~CheckpointNotifyProcessor() {}
virtual void Process() {} void ProcessImpl() override {}
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
...@@ -177,32 +176,37 @@ class GRPCClient : public RPCClient { ...@@ -177,32 +176,37 @@ class GRPCClient : public RPCClient {
GRPCClient() : ok_(true), completed_(false), stopped_(false) {} GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
virtual ~GRPCClient(); virtual ~GRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, VarHandlePtr AsyncSendVar(const std::string& ep,
const framework::Scope& scope, const std::string& var_name, const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, VarHandlePtr AsyncGetVar(const std::string& ep,
const framework::Scope& scope, const std::string& var_name, const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep, VarHandlePtr AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBatchBarrier(const std::string& ep, VarHandlePtr AsyncSendBatchBarrier(
int64_t time_out = FLAGS_rpc_deadline) override; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(const std::string& ep, VarHandlePtr AsyncSendFetchBarrier(
int64_t time_out = FLAGS_rpc_deadline) override; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendComplete(const std::string& ep, VarHandlePtr AsyncSendComplete(
int64_t time_out = FLAGS_rpc_deadline) override; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override; bool Wait() override;
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -49,23 +50,77 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; ...@@ -49,23 +50,77 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
class RPCServer; class RPCServer;
struct VarHandle { class VarHandle {
// RPC endpoint. public:
std::string ep; VarHandle(const std::string ep, const std::string& method,
const platform::DeviceContext* ctx; const std::string& name,
const framework::Scope* scope; const platform::DeviceContext* p_ctx = nullptr,
// Variable name. const framework::Scope* p_scope = nullptr)
std::string name; : ok_(kVarHandleDefaultState) {
// RPC method name. ep_ = ep;
std::string method; ctx_ = p_ctx;
scope_ = p_scope;
name_ = name;
method_ = method;
}
virtual ~VarHandle() {}
public:
bool Wait() {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
wait_cond_.wait(lk, [this] { return ok_ != kVarHandleDefaultState; });
}
VLOG(7) << "VarHandle wait:" << ok_;
return ok_ != 0;
}
void Finish(bool ok) {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
ok_ = ok;
}
VLOG(7) << "VarHandle finish:" << ok;
wait_cond_.notify_all();
}
std::string String() const { std::string String() const {
std::ostringstream s; std::ostringstream s;
s << method << " name:[" << name << "], ep:[" << ep << "]"; s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], ok:[" << ok_
<< "]";
return s.str(); return s.str();
} }
std::string ep() const { return ep_; }
const platform::DeviceContext* ctx() const { return ctx_; }
const framework::Scope* scope() const { return scope_; }
std::string name() const { return name_; }
std::string method() const { return method_; }
protected:
// RPC endpoint.
std::string ep_;
const platform::DeviceContext* ctx_;
const framework::Scope* scope_;
// Variable name.
std::string name_;
// RPC method name.
std::string method_;
protected:
std::mutex sync_mutex_;
std::condition_variable wait_cond_;
int ok_;
static const int kVarHandleDefaultState = -1;
private:
DISABLE_COPY_AND_ASSIGN(VarHandle);
}; };
typedef std::shared_ptr<VarHandle> VarHandlePtr;
class RequestHandler { class RequestHandler {
public: public:
explicit RequestHandler(bool sync_mode) explicit RequestHandler(bool sync_mode)
......
...@@ -14,12 +14,14 @@ ...@@ -14,12 +14,14 @@
#pragma once #pragma once
#include <condition_variable> // NOLINT
#include <string> #include <string>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
DECLARE_int32(rpc_deadline); DECLARE_int32(rpc_deadline);
...@@ -31,37 +33,36 @@ class RPCClient { ...@@ -31,37 +33,36 @@ class RPCClient {
public: public:
RPCClient() {} RPCClient() {}
virtual ~RPCClient() {} virtual ~RPCClient() {}
virtual bool AsyncSendVar(const std::string& ep, virtual VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncGetVar(const std::string& ep, virtual VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep, virtual VarHandlePtr AsyncPrefetchVar(
const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope, const std::string& in_var_name,
const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendBatchBarrier(const std::string& ep, virtual VarHandlePtr AsyncSendBatchBarrier(
int64_t time_out = FLAGS_rpc_deadline) = 0; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendFetchBarrier(const std::string& ep, virtual VarHandlePtr AsyncSendFetchBarrier(
int64_t time_out = FLAGS_rpc_deadline) = 0; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncCheckpointNotify(const std::string& ep, virtual VarHandlePtr AsyncCheckpointNotify(
const std::string& dir, const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendComplete(const std::string& ep, virtual VarHandlePtr AsyncSendComplete(
int64_t time_out = FLAGS_rpc_deadline) = 0; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
// Complete tells all the pserver instances that finishe the training, // Complete tells all the pserver instances that finishe the training,
// the pserver can reduce it's barrier count, and continue to train // the pserver can reduce it's barrier count, and continue to train
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
using paddle::operators::distributed::VarHandlePtr;
using paddle::operators::distributed::VarHandle;
void WaitTrue(VarHandlePtr s) { EXPECT_TRUE(s->Wait()); }
void WaitFalse(VarHandlePtr s) { EXPECT_FALSE(s->Wait()); }
TEST(VarHandle, Run) {
std::vector<VarHandlePtr> a;
for (int i = 0; i < 12; i++) {
VarHandlePtr s(new VarHandle("", "", "", nullptr, nullptr));
a.push_back(s);
}
std::vector<std::unique_ptr<std::thread>> t;
for (int i = 0; i < 6; i++) {
t.emplace_back(new std::thread(WaitFalse, a[i]));
}
for (int i = 0; i < 6; i++) {
a[i]->Finish(false);
t[i]->join();
}
for (int i = 6; i < 12; i++) {
t.emplace_back(new std::thread(WaitTrue, a[i]));
}
for (int i = 6; i < 12; i++) {
a[i]->Finish(true);
t[i]->join();
}
}
...@@ -44,16 +44,20 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -44,16 +44,20 @@ class PrefetchOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back"; << outs[i] << " back";
rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]); rets.push_back(rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope,
ins[i], outs[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
}; };
......
...@@ -44,12 +44,15 @@ class RecvOp : public framework::OperatorBase { ...@@ -44,12 +44,15 @@ class RecvOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]));
} }
if (sync_mode) { if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
} }
}; };
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <future> // NOLINT #include <future> // NOLINT
#include <ostream> #include <ostream>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -45,18 +46,19 @@ class SendOp : public framework::OperatorBase { ...@@ -45,18 +46,19 @@ class SendOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
if (sync_send) { if (sync_send) {
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册