diff --git a/paddle/operators/detail/recv_impl.cc b/paddle/operators/detail/recv_impl.cc index bc930cbb007b73b6bbf9d4e999ee6c4388c8d0f8..47decb6d7eb763322b2af26d8531b12e816b0305 100644 --- a/paddle/operators/detail/recv_impl.cc +++ b/paddle/operators/detail/recv_impl.cc @@ -33,21 +33,40 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context, } Status SendRecvServerImpl::GetVariable(ServerContext *context, - const VoidMessage *in_var, + const VariableMessage *in_var, VariableMessage *out_var) { - // Block util the sub graph is done. - auto out_tensor_with_name = var_return_queue_.Pop(); + std::string get_var_name = in_var->varname(); + auto *var = scope_->FindVar(get_var_name); + auto tensor = var->Get(); std::ostringstream oss; - framework::SerializeToStream(oss, out_tensor_with_name.second, - platform::CPUDeviceContext()); + framework::SerializeToStream(oss, tensor, platform::CPUDeviceContext()); std::string *varname = out_var->mutable_varname(); - *varname = out_tensor_with_name.first; + *varname = get_var_name; std::string *serialized = out_var->mutable_serialized(); *serialized = oss.str(); return Status::OK; } +Status SendRecvServerImpl::Wait(ServerContext *context, + const VoidMessage *in_var, + VoidMessage *out_var) { + std::unique_lock lock(this->mutex_); + condition_.wait(lock, [=] { return this->done_ == true; }); + return Status::OK; +} + +void SendRecvServerImpl::Start() { + std::unique_lock lock(this->mutex_); + done_ = false; +} + +void SendRecvServerImpl::Done() { + std::unique_lock lock(this->mutex_); + done_ = true; + condition_.notify_all(); +} + } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/operators/detail/send_impl.cc b/paddle/operators/detail/send_impl.cc index bf22d3df818358de5362c82c60e955b976238483..7555cc63fb24e03c71fcc8a0da55e198ddc57feb 100644 --- a/paddle/operators/detail/send_impl.cc +++ b/paddle/operators/detail/send_impl.cc @@ -43,19 +43,20 @@ bool RPCClient::SendVariable(const framework::Scope& scope, return true; } -bool RPCClient::GetVariable(const framework::Scope& scope) { +bool RPCClient::GetVariable(const framework::Scope& scope, + const std::string& outname) { ClientContext context; - VariableMessage msg; - VoidMessage void_msg; + VariableMessage call_msg, ret_msg; + call_msg.set_varname(outname); auto ctx = platform::CPUDeviceContext(); - Status status = stub_->GetVariable(&context, void_msg, &msg); + Status status = stub_->GetVariable(&context, call_msg, &ret_msg); if (!status.ok()) { LOG(ERROR) << "gRPC error: " << status.error_message(); return false; } - std::istringstream iss(msg.serialized()); - auto outname = msg.varname(); + std::istringstream iss(ret_msg.serialized()); + framework::LoDTensor ret_tensor; framework::DeserializeFromStream(iss, &ret_tensor); auto* outvar = scope.FindVar(outname); diff --git a/paddle/operators/detail/send_recv.proto b/paddle/operators/detail/send_recv.proto index d00c33fe42af1c63435db8c730a1d7b789420d12..ce729908062ad442e66cc00001e14ceb6f268560 100644 --- a/paddle/operators/detail/send_recv.proto +++ b/paddle/operators/detail/send_recv.proto @@ -22,7 +22,9 @@ service SendRecvService { // TODO(typhoonzero): add streaming API rpc SendVariable(VariableMessage) returns (VoidMessage) {} // Argument VariableMessage for GetVariable should only contain varname. - rpc GetVariable(VoidMessage) returns (VariableMessage) {} + rpc GetVariable(VariableMessage) returns (VariableMessage) {} + // wait for one execution of the program + rpc Wait(VoidMessage) returns (VoidMessage) {} } // VariableMessage is serialized paddle variable message. diff --git a/paddle/operators/detail/send_recv_impl.h b/paddle/operators/detail/send_recv_impl.h index df01345e342789d5816944f2e3637ea64f0c6960..6edbb2d83482c92c60f96ed0f8de4ff89f6fae24 100644 --- a/paddle/operators/detail/send_recv_impl.h +++ b/paddle/operators/detail/send_recv_impl.h @@ -20,10 +20,6 @@ #include "paddle/framework/selected_rows.h" #include "paddle/operators/detail/simple_block_queue.h" -// #include -// #include -// #include -// #include #include "paddle/operators/detail/send_recv.grpc.pb.h" #include "paddle/operators/detail/send_recv.pb.h" @@ -56,18 +52,24 @@ class SendRecvServerImpl final : public SendRecvService::Service { Status SendVariable(ServerContext *context, const VariableMessage *in_var, VoidMessage *out_var) override; - Status GetVariable(ServerContext *context, const VoidMessage *in_var, + Status GetVariable(ServerContext *context, const VariableMessage *in_var, VariableMessage *out_var) override; + Status Wait(ServerContext *context, const VoidMessage *in_var, + VoidMessage *out_var) override; + void Start(); + void Done(); + void SetScope(framework::Scope *scope) { scope_ = scope; }; const TensorWithName Get() { return this->var_recv_queue_.Pop(); } - void Push(const TensorWithName &var) { this->var_return_queue_.Push(var); } - private: // received variable from RPC, operators fetch variable from this queue. SimpleBlockQueue var_recv_queue_; - // calculated variable should push to this queue. - SimpleBlockQueue var_return_queue_; + framework::Scope *scope_; + // condition of the sub program + std::mutex mutex_; + bool done_; + std::condition_variable condition_; }; // RPCClient is a class to send tensors to pserver sub-network @@ -78,7 +80,7 @@ class RPCClient { : stub_(SendRecvService::NewStub(channel)) {} bool SendVariable(const framework::Scope &scope, const std::string &inname); - bool GetVariable(const framework::Scope &scope); + bool GetVariable(const framework::Scope &scope, const std::string &outname); private: std::unique_ptr stub_; diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 9c3e8953bb781079c5fbca611182d083350c2ea3..9af8d311d92393fe2b8de5859ad6e591cfb48b8d 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -76,12 +76,14 @@ class RecvOp : public framework::OperatorBase { const platform::DeviceContext &dev_ctx) const override { // FIXME(typhoonzero): no new scopes for every run. framework::Scope &recv_scope = scope.NewScope(); + rpc_service_.SetScope(&recv_scope); auto param_list = Attr>("ParamList"); auto grad_list = Attr>("GradList"); auto trainer_count = Attr("Trainers"); size_t param_count = param_list.size(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. while (true) { + rpc_service_.Start(); // Get from multiple trainers, we don't care about order in which // the gradient arrives, just add suffix 0~n then average the gradient. for (size_t i = 0; i < param_count * trainer_count; ++i) { @@ -125,13 +127,13 @@ class RecvOp : public framework::OperatorBase { LOG(ERROR) << "run sub program error " << e.what(); } - for (size_t i = 0; i < param_count; ++i) { - auto *out_var = recv_scope.FindVar(param_list[i]); - detail::TensorWithName out; - out.first = param_list[i]; - out.second = out_var->Get(); - rpc_service_->Push(out); - } + // for (size_t i = 0; i < param_count; ++i) { + // auto *out_var = recv_scope.FindVar(param_list[i]); + // detail::TensorWithName out; + // out.first = param_list[i]; + // out.second = out_var->Get(); + // rpc_service_->Push(out); + // } } // while(true) }