From a585b585dd3f16c45d57a88bf2f6f0bb101d0398 Mon Sep 17 00:00:00 2001 From: Yancey Date: Mon, 29 Jan 2018 16:48:05 +0800 Subject: [PATCH] Batch barrier in send/recv op (#7847) * initialize batch barrier * add some comments * update * fix batch barrier * use sendvariable rpc interface to send batch barrier * fix comment * fix method * fix by comment * fix by comment --- paddle/operators/detail/grpc_client.cc | 15 ++++++ paddle/operators/detail/grpc_client.h | 24 ++++++++++ paddle/operators/detail/grpc_server.cc | 13 ++--- paddle/operators/detail/grpc_server.h | 3 +- paddle/operators/detail/sendrecvop_utils.h | 3 ++ paddle/operators/recv_op.cc | 55 +++++++++++++--------- paddle/operators/send_op.cc | 12 ++++- 7 files changed, 92 insertions(+), 33 deletions(-) diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc index c4339455426..9b5f7afc6a4 100644 --- a/paddle/operators/detail/grpc_client.cc +++ b/paddle/operators/detail/grpc_client.cc @@ -97,6 +97,21 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } +bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { + const auto ch = GetChannel(ep); + + BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); + s->Prepare(time_out); + + sendrecv::VariableMessage req; + req.set_varname(BATCH_BARRIER_MESSAGE); + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + req_count_++; + + return true; +} + bool RPCClient::Wait() { if (req_count_ <= 0) { return true; diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h index a62e70a2533..f9499f6dc70 100644 --- a/paddle/operators/detail/grpc_client.h +++ b/paddle/operators/detail/grpc_client.h @@ -71,6 +71,15 @@ class ClientBase { context_->set_deadline(deadline); } + virtual void Prepare(int64_t time_out) { + context_.reset(new grpc::ClientContext()); + + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); + + context_->set_deadline(deadline); + } + virtual void Process() = 0; std::unique_ptr stub_; @@ -117,6 +126,17 @@ class GetProcessor : public ClientBase { RequestGetCallBack response_call_back_ = ProcGetResponse; }; +class BatchBarrierProcessor : public ClientBase { + public: + explicit BatchBarrierProcessor(std::shared_ptr ch) + : ClientBase(ch) {} + + virtual ~BatchBarrierProcessor() {} + + virtual void Process() {} + sendrecv::VoidMessage reply_; +}; + class RPCClient { public: bool AsyncSendVariable(const std::string& ep, @@ -130,6 +150,10 @@ class RPCClient { const framework::Scope& scope, const std::string& var_name, int64_t time_out = 600 * 1000); + + bool AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out = 600 * 1000); + bool Wait(); private: diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index 3ddcd839bdd..4f94e1315fb 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -132,6 +132,7 @@ void AsyncGRPCServer::RunSyncUpdate() { cq_send_ = builder.AddCompletionQueue(); cq_get_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); LOG(INFO) << "Server listening on " << address_ << std::endl; @@ -141,11 +142,11 @@ void AsyncGRPCServer::RunSyncUpdate() { std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); t_send_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false, + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_send_.get(), "cq_send", send_register))); t_get_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true, + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); // wait server @@ -174,7 +175,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { } RequestSend* send = new RequestSend(&service_, cq_send_.get(), &var_recv_queue_); - VLOG(4) << "create RequestSend status:" << send->Status(); + VLOG(4) << "Create RequestSend status:" << send->Status(); } void AsyncGRPCServer::TryToRegisterNewGetOne() { @@ -184,11 +185,11 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { } RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, &var_get_queue_); - VLOG(4) << "create Requestget status:" << get->Status(); + VLOG(4) << "Create RequestGet status:" << get->Status(); } -// FIXME(typhoonzero): remove wait argument and change cq_name to enum. -void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, +// FIXME(typhoonzero): change cq_name to enum. +void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq, std::string cq_name, std::function TryToRegisterNewOne) { TryToRegisterNewOne(); diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 1ca9086c744..3f8b9d93176 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -57,8 +57,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void ShutDown(); protected: - void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq, - std::string cq_name, + void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name, std::function TryToRegisterNewOne); void TryToRegisterNewSendOne(); void TryToRegisterNewGetOne(); diff --git a/paddle/operators/detail/sendrecvop_utils.h b/paddle/operators/detail/sendrecvop_utils.h index bc6581afab9..8e66f7299c7 100644 --- a/paddle/operators/detail/sendrecvop_utils.h +++ b/paddle/operators/detail/sendrecvop_utils.h @@ -30,6 +30,9 @@ namespace paddle { namespace operators { namespace detail { +#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" +#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" + void SerializeToMessage(const std::string& name, const framework::Variable* var, const platform::DeviceContext& ctx, sendrecv::VariableMessage* msg); diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 5d1df566aff..49e1eb34024 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -29,8 +29,6 @@ limitations under the License. */ #include "paddle/operators/detail/simple_block_queue.h" #include "paddle/string/printf.h" -#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" - namespace paddle { namespace operators { @@ -95,7 +93,6 @@ class RecvOp : public framework::OperatorBase { auto param_list = Attr>("ParamList"); auto grad_list = Attr>("GradList"); auto fan_in = Attr("Fanin"); - size_t param_count = param_list.size(); auto *block = Attr(kOptimizeBlock); auto *program = block->Program(); @@ -103,38 +100,50 @@ class RecvOp : public framework::OperatorBase { // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; - size_t barrier_size = param_count * fan_in; while (!exit_flag) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. rpc_service_->SetCond(0); - for (size_t i = 0; i < barrier_size; ++i) { + size_t recv_var_cnt = 0; + int batch_barrier = 0; + while (batch_barrier != fan_in) { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { LOG(INFO) << "received terminate message and exit"; exit_flag = true; break; - } - auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name); - std::string param_var_name; - if (it != grad_list.end()) { - param_var_name = param_list[it - grad_list.begin()]; + } else if (grad_var_name == BATCH_BARRIER_MESSAGE) { + VLOG(3) << "recv batch barrier message"; + batch_barrier++; + continue; } else { - LOG(ERROR) << "grad has no paired param:" << grad_var_name; - } - VLOG(3) << "received grad: " << grad_var_name - << " updating param: " << param_var_name; - if (fan_in > 1) { - grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); - } - auto *var = recv_scope.FindVar(grad_var_name); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << grad_var_name; - PADDLE_THROW("Can not find server side var"); + // receive a variable + recv_var_cnt++; + auto it = + std::find(grad_list.begin(), grad_list.end(), grad_var_name); + std::string param_var_name; + if (it != grad_list.end()) { + param_var_name = param_list[it - grad_list.begin()]; + } else { + LOG(ERROR) << "grad has no paired param:" << grad_var_name; + } + VLOG(3) << "received grad: " << grad_var_name + << " updating param: " << param_var_name; + + if (fan_in > 1) { + grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); + } + auto *var = recv_scope.FindVar(grad_var_name); + if (var == nullptr) { + LOG(ERROR) << "Can not find server side var: " << grad_var_name; + PADDLE_THROW("Can not find server side var"); + } + detail::DeserializeFromMessage(v.second, dev_ctx, var); } - detail::DeserializeFromMessage(v.second, dev_ctx, var); } + VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier."; + // TODO(Yancey1989): merge SelectedRows variables here if (exit_flag) { break; } @@ -146,7 +155,7 @@ class RecvOp : public framework::OperatorBase { LOG(ERROR) << "run sub program error " << e.what(); } rpc_service_->SetCond(1); - rpc_service_->WaitClientGet(barrier_size); + rpc_service_->WaitClientGet(recv_var_cnt); grads_counter_.clear(); } // while(true) } diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 5aa66c20eaf..bb719dc2a8a 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -37,17 +37,25 @@ class SendOp : public framework::OperatorBase { auto ins = Inputs("X"); auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); + std::vector endpoints = + Attr>("endpoints"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); for (size_t i = 0; i < ins.size(); i++) { - VLOG(3) << "sending " << ins[i]; + VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); } PADDLE_ENFORCE(client_.Wait()); + for (auto& ep : endpoints) { + VLOG(3) << "batch barrier, ep: " << ep; + client_.AsyncSendBatchBarrier(ep); + } + PADDLE_ENFORCE(client_.Wait()); + for (size_t i = 0; i < outs.size(); i++) { - VLOG(3) << "getting " << outs[i]; + VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } -- GitLab