diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 7266f3276477891d3c7b6827316a428ef7a31c6e..ddeeebec58e02f1686fd2e3d3e5ac1a4c4fd3c59 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } -bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { +void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); @@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, (void*)s); req_count_++; +} - return true; +void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { + const auto ch = GetChannel(ep); + FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); + s->Prepare(time_out); + + sendrecv::VariableMessage req; + req.set_varname(FETCH_BARRIER_MESSAGE); + auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + req_count_++; } bool RPCClient::Wait() { @@ -154,7 +164,7 @@ bool RPCClient::Proceed() { PADDLE_ENFORCE(tag); // TODO(gongwb): add more retries. - ClientBase* c = static_cast(tag); + BaseProcessor* c = static_cast(tag); if (!c->status_.ok()) { LOG(ERROR) << "proc param error:" << c->var_h_.String() << " grpc error:" << c->status_.error_message(); @@ -174,6 +184,8 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep) { } grpc::ChannelArguments args; + args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000); + args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); args.SetMaxSendMessageSize(std::numeric_limits::max()); args.SetMaxReceiveMessageSize(std::numeric_limits::max()); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 669838810de240358857300c1014aa622c17f808..f520367dd981288416631fdad15241fb5d811d07 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -52,14 +52,14 @@ struct VarHandle { void ProcGetResponse(const VarHandle& var_h, const sendrecv::VariableMessage& msg); -class ClientBase { +class BaseProcessor { public: - explicit ClientBase(std::shared_ptr ch) { + explicit BaseProcessor(std::shared_ptr ch) { stub_ = sendrecv::SendRecvService::NewStub(ch); context_ = NULL; } - virtual ~ClientBase() {} + virtual ~BaseProcessor() {} virtual void Prepare(const VarHandle& var_info, int64_t time_out) { context_.reset(new grpc::ClientContext()); @@ -91,9 +91,10 @@ class ClientBase { typedef std::function RequestSendCallBack; -class SendProcessor : public ClientBase { +class SendProcessor : public BaseProcessor { public: - explicit SendProcessor(std::shared_ptr ch) : ClientBase(ch) {} + explicit SendProcessor(std::shared_ptr ch) + : BaseProcessor(ch) {} virtual ~SendProcessor() {} @@ -110,9 +111,10 @@ class SendProcessor : public ClientBase { typedef std::function RequestGetCallBack; -class GetProcessor : public ClientBase { +class GetProcessor : public BaseProcessor { public: - explicit GetProcessor(std::shared_ptr ch) : ClientBase(ch) {} + explicit GetProcessor(std::shared_ptr ch) + : BaseProcessor(ch) {} virtual ~GetProcessor() {} @@ -126,10 +128,10 @@ class GetProcessor : public ClientBase { RequestGetCallBack response_call_back_ = ProcGetResponse; }; -class BatchBarrierProcessor : public ClientBase { +class BatchBarrierProcessor : public BaseProcessor { public: explicit BatchBarrierProcessor(std::shared_ptr ch) - : ClientBase(ch) {} + : BaseProcessor(ch) {} virtual ~BatchBarrierProcessor() {} @@ -137,6 +139,17 @@ class BatchBarrierProcessor : public ClientBase { sendrecv::VoidMessage reply_; }; +class FetchBarrierProcessor : public BaseProcessor { + public: + explicit FetchBarrierProcessor(std::shared_ptr ch) + : BaseProcessor(ch) {} + + virtual ~FetchBarrierProcessor() {} + + virtual void Process() {} + sendrecv::VariableMessage reply_; +}; + class RPCClient { public: bool AsyncSendVariable(const std::string& ep, @@ -151,7 +164,10 @@ class RPCClient { const std::string& var_name, int64_t time_out = 600 * 1000); - bool AsyncSendBatchBarrier(const std::string& ep, + void AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out = 600 * 1000); + + void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = 600 * 1000); bool Wait(); diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 2a567516614aac39da9d069824f598ed26c5fb0c..8fff430cc4890925e4edba2fadb8eb7fc647d181 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -84,7 +84,7 @@ class RequestGet final : public RequestBase { explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, grpc::ServerCompletionQueue* cq, framework::Scope* scope, const platform::DeviceContext* dev_ctx, - SimpleBlockQueue* queue) + SimpleBlockQueue* queue) : RequestBase(service, cq), responder_(&ctx_), scope_(scope), @@ -101,11 +101,16 @@ class RequestGet final : public RequestBase { // proc request. std::string var_name = request_.varname(); auto* var = scope_->FindVar(var_name); - SerializeToMessage(var_name, var, *dev_ctx_, &reply_); + if (var_name != FETCH_BARRIER_MESSAGE) { + SerializeToMessage(var_name, var, *dev_ctx_, &reply_); + } // TODO(gongwb): check var's info. responder_.Finish(reply_, grpc::Status::OK, this); status_ = FINISH; - queue_->Push('c'); + MessageWithName msg_with_name = + // request name reply + std::make_pair(var_name, std::move(reply_)); + queue_->Push(msg_with_name); } protected: @@ -114,12 +119,16 @@ class RequestGet final : public RequestBase { ServerAsyncResponseWriter responder_; framework::Scope* scope_; const platform::DeviceContext* dev_ctx_; - SimpleBlockQueue* queue_; + SimpleBlockQueue* queue_; }; void AsyncGRPCServer::WaitClientGet(int count) { - for (int i = 0; i < count; ++i) { - var_get_queue_.Pop(); + int fetch_barriers = 0; + while (fetch_barriers < count) { + auto msg = var_get_queue_.Pop(); + if (msg.first == FETCH_BARRIER_MESSAGE) { + fetch_barriers++; + } } } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index e9402ff6aafe73684b57442c03ee4c573a902a46..b6666bcf96e484b0b17b935c0efb2930f19b19f2 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { const platform::DeviceContext *dev_ctx_; // received variable from RPC, operators fetch variable from this queue. SimpleBlockQueue var_recv_queue_; - SimpleBlockQueue var_get_queue_; + SimpleBlockQueue var_get_queue_; // condition of the sub program std::mutex barrier_mutex_; diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index 5208091e54b4da2bb0265f84827ce23b57e954dc..4fa6aefd3e0b1bd45ac52b1eff3b29126d79f03a 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -32,6 +32,7 @@ namespace detail { #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" +#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" typedef void (*DestroyCallback)(void*); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 8e9923c87ce22ed229f78ef15430e50cab16c947..4253300788462a3704076fc79241a864f2f130a0 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase { } } if (exit_flag) { - rpc_service_->ShutDown(); rpc_service_->SetCond(1); + rpc_service_->ShutDown(); break; } try { @@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase { } rpc_service_->SetCond(1); // FIXME(typhoonzero): use another condition to sync wait clients get. - rpc_service_->WaitClientGet(ins.size()); + rpc_service_->WaitClientGet(fan_in); sparse_vars.clear(); } // while(true) } diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 8fdd08eae6b22cd57506d6e75182c1a7e2022562..443f40e803ea31c3961ed77842bd0775e0f74f35 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase { rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } PADDLE_ENFORCE(rpc_client->Wait()); + // tell pservers that current trainer have called fetch + for (auto& ep : endpoints) { + VLOG(3) << "send fetch barrier, ep: " << ep; + rpc_client->AsyncSendFetchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); } } }; diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index bb2ce4d45d5da6b2fbd097a94479f3696271e5ec..3d3a6c116eeb39fb7236d0e9707415cdd6b828bd 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -250,6 +250,8 @@ class DistributeTranspiler: def get_trainer_program(self): # remove optimize ops and add a send op to main_program self.program.global_block().delete_ops(self.optimize_ops) + # FIXME(typhoonzero): serialize once will fix error occurs when clone. + self.program.__str__() return self.program def get_pserver_program(self, endpoint): @@ -309,7 +311,8 @@ class DistributeTranspiler: for _, opt_op in enumerate(opt_op_on_pserver): if ufind.is_connected(op, opt_op): if self._is_opt_op(op): - self._append_pserver_ops(optimize_block, op, endpoint) + self._append_pserver_ops(optimize_block, op, endpoint, + default_main_program()) else: self._append_pserver_non_opt_ops(optimize_block, op) break @@ -520,7 +523,8 @@ class DistributeTranspiler: orig_var_name = varname[:suff_idx] return orig_var_name - def _append_pserver_ops(self, optimize_block, opt_op, endpoint): + def _append_pserver_ops(self, optimize_block, opt_op, endpoint, + origin_program): program = optimize_block.program pserver_block = program.global_block() new_inputs = dict() @@ -576,7 +580,17 @@ class DistributeTranspiler: elif key == "LearningRate": # leraning rate variable has already be created by non-optimize op, # don't create it once again. - new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]] + lr_varname = opt_op.input(key)[0] + if pserver_block.vars.has_key(lr_varname): + new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]] + else: + origin_var = origin_program.global_block().vars[lr_varname] + tmpvar = pserver_block.create_var( + name=origin_var.name, + persistable=origin_var.persistable, + dtype=origin_var.dtype, + shape=origin_var.shape) + new_inputs[key] = tmpvar for key in opt_op.input_names: new_shape = None