// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/operators/distributed/brpc_server.h" #include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/brpc_variable_response.h" #include "paddle/fluid/operators/distributed/request_handler.h" namespace sendrecv { namespace distributed = paddle::operators::distributed; typedef std::unordered_map HandlerMap; class BRPCServiceImpl : public SendRecvService { public: explicit BRPCServiceImpl(const HandlerMap& rpc_call_map, distributed::RPCServer* rpc_server) : rpc_server_(rpc_server) { VLOG(3) << "BRPCServiceImpl size: " << rpc_call_map.size(); auto it = rpc_call_map.find(distributed::kRequestSend); if (it != rpc_call_map.end()) { request_send_h_ = it->second; send_threads_.reset(new paddle::framework::ThreadPool( rpc_server_->GetThreadNum(distributed::kRequestSend))); } it = rpc_call_map.find(distributed::kRequestGet); if (it != rpc_call_map.end()) { request_get_h_ = it->second; get_threads_.reset(new paddle::framework::ThreadPool( rpc_server_->GetThreadNum(distributed::kRequestGet))); } it = rpc_call_map.find(distributed::kRequestPrefetch); if (it != rpc_call_map.end()) { request_prefetch_h_ = it->second; prefetch_threads_.reset(new paddle::framework::ThreadPool( rpc_server_->GetThreadNum(distributed::kRequestPrefetch))); } it = rpc_call_map.find(distributed::kRequestCheckpoint); if (it != rpc_call_map.end()) { request_checkpoint_h_ = it->second; checkpoint_notify_threads_.reset(new paddle::framework::ThreadPool( rpc_server_->GetThreadNum(distributed::kRequestPrefetch))); } it = rpc_call_map.find(distributed::kRequestGetMonomerVariable); if (it != rpc_call_map.end()) { request_get_monomer_handler_h_ = it->second; } it = rpc_call_map.find(distributed::kRequestGetMonomerBarrier); if (it != rpc_call_map.end()) { request_get_monomer_barrier_handler_h_ = it->second; } } virtual ~BRPCServiceImpl() {} void SendVariable(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VoidMessage* response, google::protobuf::Closure* done) override { send_threads_->Run( [=] { _SendVariable(cntl_butil, request, response, done); }); } void _SendVariable(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VoidMessage* response, google::protobuf::Closure* done) { PADDLE_ENFORCE(request_send_h_ != nullptr, "RequestSend handler should be registed first!"); brpc::ClosureGuard done_guard(done); brpc::Controller* cntl = static_cast(cntl_butil); std::string varname = request->varname(); VLOG(3) << "RequestSend var_name:" << varname << ", trainer_id:" << request->trainer_id() << ", from:" << cntl->remote_side(); distributed::BRPCVariableResponse resp(request_send_h_->scope(), request_send_h_->dev_ctx(), !request_send_h_->sync_mode()); PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0, "parse iobuf to tensor error!"); auto scope = resp.GetMutableLocalScope(); auto invar = resp.GetVar(); int trainer_id = request->trainer_id(); paddle::framework::Variable* outvar = nullptr; request_send_h_->Handle(varname, scope, invar, &outvar, trainer_id); } void GetVariable(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VariableMessage* response, google::protobuf::Closure* done) override { get_threads_->Run( [=] { _GetVariable(cntl_butil, request, response, done); }); } void _GetVariable(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VariableMessage* response, google::protobuf::Closure* done) { PADDLE_ENFORCE(request_get_h_ != nullptr, "RequestGet handler should be registed first!"); brpc::ClosureGuard done_guard(done); brpc::Controller* cntl = static_cast(cntl_butil); std::string varname = request->varname(); VLOG(3) << "RequestGet varname:" << varname << ", trainer_id:" << request->trainer_id() << ", from:" << cntl->remote_side(); auto scope = request_get_h_->scope(); auto invar = scope->FindVar(varname); int trainer_id = request->trainer_id(); paddle::framework::Variable* outvar = nullptr; request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id); if (outvar) { distributed::SerializeToIOBuf(varname, outvar, *request_get_h_->dev_ctx(), response, &cntl->response_attachment(), "", false); } } void PrefetchVariable(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VariableMessage* response, google::protobuf::Closure* done) override { prefetch_threads_->Run( [=] { _PrefetchVariable(cntl_butil, request, response, done); }); } void _PrefetchVariable(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VariableMessage* response, google::protobuf::Closure* done) { PADDLE_ENFORCE(request_prefetch_h_ != nullptr, "kRequestPrefetch handler should be registed first!"); brpc::ClosureGuard done_guard(done); brpc::Controller* cntl = static_cast(cntl_butil); // prefetch process... std::string in_var_name = request->varname(); std::string out_var_name = request->out_varname(); VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name << ", out_var_name: " << out_var_name << ", trainer_id:" << request->trainer_id() << ", from:" << cntl->remote_side(); distributed::BRPCVariableResponse resp( request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true); PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0, "parse iobuf to tensor error!"); auto scope = resp.GetMutableLocalScope(); auto invar = scope->FindVar(in_var_name); std::string table_name = request->table_name(); int trainer_id = request->trainer_id(); paddle::framework::Variable* outvar = scope->Var(out_var_name); request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id, out_var_name, table_name); distributed::SerializeToIOBuf(out_var_name, outvar, *request_prefetch_h_->dev_ctx(), response, &cntl->response_attachment(), "", true); } void CheckpointNotify(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VoidMessage* response, google::protobuf::Closure* done) override { checkpoint_notify_threads_->Run( [=] { _CheckpointNotify(cntl_butil, request, response, done); }); } void _CheckpointNotify(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VoidMessage* response, google::protobuf::Closure* done) { PADDLE_ENFORCE( request_checkpoint_h_ != nullptr, "kRequestCheckpointNotify handler should be registed first!"); brpc::ClosureGuard done_guard(done); brpc::Controller* cntl = static_cast(cntl_butil); distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(), request_checkpoint_h_->dev_ctx()); auto scope = resp.GetMutableLocalScope(); std::string checkpoint_notify = request->varname(); std::string checkpoint_dir = request->out_varname(); int trainer_id = request->trainer_id(); VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify << ", dir: " << checkpoint_dir << ", trainer_id:" << request->trainer_id() << ", from:" << cntl->remote_side(); request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr, trainer_id, checkpoint_dir); } void GetMonomerVariable(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VariableMessage* response, google::protobuf::Closure* done) override { PADDLE_ENFORCE( request_get_monomer_handler_h_ != nullptr, "kRequestGetMonomerVariable handler should be registed first!"); brpc::ClosureGuard done_guard(done); brpc::Controller* cntl = static_cast(cntl_butil); // proc request. std::string varname = request->varname(); VLOG(3) << "GetMonomerVariable " << varname << ", trainer_id:" << request->trainer_id() << ", from:" << cntl->remote_side(); rpc_server_->WaitVarCond(varname); distributed::MonomerHandle h = rpc_server_->GetMonomer(varname); auto scope = h.scope_; auto invar = scope->FindVar(varname); paddle::framework::Variable* outvar = nullptr; request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar, request->trainer_id()); if (outvar) { distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response, &cntl->response_attachment(), "", false); } } void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil, const VariableMessage* request, VoidMessage* response, google::protobuf::Closure* done) override { PADDLE_ENFORCE( request_get_monomer_barrier_handler_h_ != nullptr, "RequestGetMonomerBarrier handler should be registed first!"); brpc::ClosureGuard done_guard(done); brpc::Controller* cntl = static_cast(cntl_butil); std::string varname = request->varname(); VLOG(3) << "RequestGetMonomerBarrier var_name:" << varname << ", trainer_id:" << request->trainer_id() << ", from:" << cntl->remote_side(); rpc_server_->WaitVarCond(varname); distributed::MonomerHandle h = rpc_server_->GetMonomer(varname); paddle::framework::Scope* scope = nullptr; paddle::framework::Variable* invar = nullptr; paddle::framework::Variable* outvar = nullptr; request_get_monomer_barrier_handler_h_->Handle( varname, scope, invar, &outvar, request->trainer_id()); } private: distributed::RequestHandler* request_send_h_{nullptr}; distributed::RequestHandler* request_get_h_{nullptr}; distributed::RequestHandler* request_prefetch_h_{nullptr}; distributed::RequestHandler* request_checkpoint_h_{nullptr}; distributed::RequestHandler* request_get_monomer_handler_h_{nullptr}; distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr}; distributed::RPCServer* rpc_server_{nullptr}; // FIXME(gongwb): brpc should support process one rpce use one threadpool. std::unique_ptr send_threads_; std::unique_ptr get_threads_; std::unique_ptr prefetch_threads_; std::unique_ptr checkpoint_notify_threads_; }; } // namespace sendrecv namespace paddle { namespace operators { namespace distributed { void AsyncBRPCServer::StartServer() { // Instance of your service. sendrecv::BRPCServiceImpl service_impl(rpc_call_map_, this); // Add the service into server. Notice the second parameter, because the // service is put on stack, we don't want server to delete it, otherwise // use brpc::SERVER_OWNS_SERVICE. if (server_.AddService(&service_impl, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { LOG(FATAL) << "Fail to add service"; return; } brpc::ServerOptions options; #ifdef PADDLE_WITH_BRPC_RDMA options.use_rdma = true; #endif options.idle_timeout_sec = idle_timeout_s_; options.max_concurrency = max_concurrency_; if (server_.Start(bind_address_.c_str(), &options) != 0) { LOG(FATAL) << "Fail to start EchoServer" << bind_address_; return; } butil::EndPoint ep = server_.listen_address(); selected_port_ = ep.port; { std::lock_guard lock(this->mutex_ready_); ready_ = 1; } condition_ready_.notify_all(); server_.Join(); } void AsyncBRPCServer::ShutDownImpl() { server_.Stop(1000); } void AsyncBRPCServer::WaitServerReady() { VLOG(3) << "AsyncGRPCServer is wait server ready"; std::unique_lock lock(this->mutex_ready_); condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); VLOG(3) << "AsyncGRPCServer WaitSeverReady"; } }; // namespace distributed }; // namespace operators }; // namespace paddle