// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/operators/distributed/brpc_client.h" #include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { namespace distributed { DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds"); DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)"); BRPCClient::~BRPCClient() { Wait(); } void HandleSendResponse(brpc::Controller* cntl, sendrecv::VoidMessage* response, VarHandlePtr var_h, ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx, BRPCClient* cls) { // std::unique_ptr makes sure cntl/response will be deleted before returning. std::unique_ptr cntl_guard(cntl); std::unique_ptr response_guard(response); // this channel can be used by other now. ch_ptr->Push(ch_ctx); if (cntl->Failed()) { LOG(FATAL) << "Fail to send SendVar: " << var_h->name() << ", error text: " << cntl->ErrorText(); var_h->Finish(false); cls->DecreaseReqCount(); return; } var_h->Finish(true); cls->DecreaseReqCount(); VLOG(4) << "HandleSendResponse from: " << cntl->remote_side() << ", varname: " << var_h->name() << ", latency: " << cntl->latency_us() << "us"; VLOG(4) << "Finish HandleSendResponse"; } VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; const auto ch_ptr = GetChannel(ep_val); const std::string method = "SendRPC"; VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); framework::AsyncIO([=] { auto ch_ctx = ch_ptr->Pop(); brpc::Controller* cntl = new brpc::Controller(); sendrecv::VoidMessage* response = new sendrecv::VoidMessage(); cntl->set_timeout_ms(time_out); auto* var = p_scope->FindVar(var_name_val); sendrecv::VariableMessage request; distributed::SerializeToIOBuf(var_name_val, var, *p_ctx, &request, &cntl->request_attachment(), "", false, trainer_id_); google::protobuf::Closure* done = brpc::NewCallback( &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); platform::RecordRPCEvent record_event(method, p_ctx); ch_ctx->stub->SendVariable(cntl, &request, response, done); if (UNLIKELY(platform::IsProfileEnabled())) { var_h->Wait(); } }); req_count_++; return var_h; } void HandleFetchBarrierResponse(brpc::Controller* cntl, sendrecv::VariableMessage* response, VarHandlePtr var_h, ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx, BRPCClient* cls) { // std::unique_ptr makes sure cntl/response will be deleted before returning. std::unique_ptr cntl_guard(cntl); std::unique_ptr response_guard(response); // this channel can be used other now. ch_ptr->Push(ch_ctx); if (cntl->Failed()) { LOG(FATAL) << "Fail to get HandleFetchBarrierResponse: " << var_h->name() << ", error text: " << cntl->ErrorText(); var_h->Finish(false); cls->DecreaseReqCount(); return; } var_h->Finish(true); cls->DecreaseReqCount(); VLOG(4) << "HandleFetchBarrierResponse from: " << cntl->remote_side() << ", varname: " << var_h->name() << ", latency: " << cntl->latency_us() << "us"; VLOG(4) << "Finish HandleFetchBarrierResponse"; } void HandleGetResponse(brpc::Controller* cntl, sendrecv::VariableMessage* response, VarHandlePtr var_h, ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx, BRPCClient* cls) { // std::unique_ptr makes sure cntl/response will be deleted before returning. std::unique_ptr cntl_guard(cntl); std::unique_ptr response_guard(response); // this channel can be used other now. ch_ptr->Push(ch_ctx); if (cntl->Failed()) { LOG(FATAL) << "Fail to GetVar: " << var_h->name() << ", error text: " << cntl->ErrorText(); cls->DecreaseReqCount(); var_h->Finish(false); return; } VLOG(4) << "HandleGetResponse from: " << cntl->remote_side() << ", varname: " << var_h->name() << ", latency: " << cntl->latency_us() << "us"; framework::Variable* outvar = nullptr; int trainer_id; distributed::DeserializeFromIOBuf(*response, cntl->response_attachment(), *var_h->ctx(), var_h->scope(), &outvar, &trainer_id); VLOG(4) << "Finish HandleGetResponse"; cls->DecreaseReqCount(); var_h->Finish(true); } VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, const std::string& method_name, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; const auto ch_ptr = GetChannel(ep_val); const std::string method = "GetRPC"; VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); framework::AsyncIO([=] { auto ch_ctx = ch_ptr->Pop(); brpc::Controller* cntl = new brpc::Controller(); sendrecv::VariableMessage* response = new sendrecv::VariableMessage(); cntl->set_timeout_ms(time_out); sendrecv::VariableMessage req; req.set_varname(var_name_val); req.set_trainer_id(trainer_id_); google::protobuf::Closure* done = brpc::NewCallback( &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); platform::RecordRPCEvent record_event(method, p_ctx); if (method_name == "GetMonomerVariable") { ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done); } else { ch_ctx->stub->GetVariable(cntl, &req, response, done); } if (UNLIKELY(platform::IsProfileEnabled())) { var_h->Wait(); } }); req_count_++; return var_h; } VarHandlePtr BRPCClient::AsyncGetMonomerVariable( const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, var_name, "GetMonomerVariable", time_out); } VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep, const std::string& var_name, int64_t time_out) { return AsyncSendMessage(ep, "GetMonomerBarrier", var_name, time_out); } VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, var_name, "GetVariable", time_out); } VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& in_var_name, const std::string& out_var_name, const std::string& table_name, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string in_var_name_val = in_var_name; const std::string out_var_name_val = out_var_name; const std::string table_name_val = table_name; const framework::Scope* p_scope = &scope; const auto ch_ptr = GetChannel(ep_val); const std::string method = "PrefetchRPC"; VarHandlePtr var_h( new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope)); framework::AsyncIO([=] { auto ch_ctx = ch_ptr->Pop(); brpc::Controller* cntl = new brpc::Controller(); sendrecv::VariableMessage* response = new sendrecv::VariableMessage(); cntl->set_timeout_ms(time_out); auto* var = p_scope->FindVar(in_var_name_val); sendrecv::VariableMessage req; distributed::SerializeToIOBuf(in_var_name_val, var, *p_ctx, &req, &cntl->request_attachment(), out_var_name_val, false, 0, table_name_val); platform::RecordRPCEvent record_event(method, p_ctx); google::protobuf::Closure* done = brpc::NewCallback( &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); ch_ctx->stub->PrefetchVariable(cntl, &req, response, done); if (UNLIKELY(platform::IsProfileEnabled())) { var_h->Wait(); } }); req_count_++; return var_h; } VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { return AsyncSendMessage(ep, "BatchBarrierRPC", BATCH_BARRIER_MESSAGE, time_out); } VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { auto ch_ptr = GetChannel(ep); auto ch_ctx = ch_ptr->Pop(); brpc::Controller* cntl = new brpc::Controller(); sendrecv::VariableMessage* response = new sendrecv::VariableMessage(); cntl->set_timeout_ms(time_out); sendrecv::VariableMessage req; req.set_varname(FETCH_BARRIER_MESSAGE); const std::string method = "FetchBarrierRPC"; // var handle VarHandlePtr var_h( new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr)); platform::RecordRPCEvent record_event(method, nullptr); google::protobuf::Closure* done = brpc::NewCallback( &HandleFetchBarrierResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); ch_ctx->stub->GetVariable(cntl, &req, response, done); req_count_++; if (UNLIKELY(platform::IsProfileEnabled())) { var_h->Wait(); } return var_h; } bool BRPCClient::Wait() { VLOG(9) << "begin to brpcclient wait"; { std::unique_lock lk(sync_mutex_); sync_cond_.wait(lk, [this] { return req_count_ == 0; }); } VLOG(9) << "end to brpcclient wait"; return true; } ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) { VLOG(4) << "begin to GetChannel:" << ep; { std::lock_guard guard(chan_mutex_); auto it = channels_.find(ep); if (it != channels_.end()) { VLOG(4) << "end to GetChannel:" << ep; return it->second; } } ChannelQueuePtr q(new framework::BlockingQueue()); brpc::ChannelOptions options; #ifdef PADDLE_WITH_BRPC_RDMA options.use_rdma = true; #endif options.protocol = "baidu_std"; // don't use pooled type. the server can't afford that. options.connection_type = "single"; options.connect_timeout_ms = 1000; options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/; options.max_retry = FLAGS_max_retry; VLOG(1) << "create " << brpc_channel_num_per_server_ << " brpc channels to pserver:" << ep; for (int i = 0; i < brpc_channel_num_per_server_; ++i) { std::shared_ptr c(new ChannelContext()); if (c->channel.Init(ep.c_str(), &options) != 0) { LOG(FATAL) << "Fail to initialize channel"; return nullptr; } c->stub.reset(new sendrecv::SendRecvService_Stub( static_cast(&c->channel))); q->Push(c); } { std::lock_guard guard(chan_mutex_); channels_[ep] = q; } VLOG(4) << "end to GetChannel:" << ep; return q; } VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { return AsyncSendMessage(ep, "SendCompleteRPC", COMPLETE_MESSAGE, time_out); } void BRPCClient::SendComplete() { for (auto& kv : channels_) { AsyncSendComplete(kv.first); } } VarHandlePtr BRPCClient::AsyncSendVarMessage( const std::string& ep, const std::string& method_name, const sendrecv::VariableMessage& req, int64_t time_out) { auto ch_ptr = GetChannel(ep); auto ch_ctx = ch_ptr->Pop(); brpc::Controller* cntl = new brpc::Controller(); sendrecv::VoidMessage* response = new sendrecv::VoidMessage(); cntl->set_timeout_ms(time_out); platform::RecordRPCEvent record_event(method_name, nullptr); VarHandlePtr var_h( new VarHandle(ep, method_name, req.varname(), nullptr, nullptr)); google::protobuf::Closure* done = brpc::NewCallback( &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); if (method_name == "CheckPointNotifyRPC") { ch_ctx->stub->CheckpointNotify(cntl, &req, response, done); } else if (method_name == "GetMonomerBarrier") { ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done); } else { ch_ctx->stub->SendVariable(cntl, &req, response, done); } req_count_++; if (UNLIKELY(platform::IsProfileEnabled())) { var_h->Wait(); } return var_h; } VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep, const std::string& method_name, const std::string& message, int64_t time_out) { sendrecv::VariableMessage req; req.set_varname(message); return AsyncSendVarMessage(ep, method_name, req, time_out); } VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dir, int64_t time_out) { sendrecv::VariableMessage req; req.set_varname(CHECKPOINT_SAVE_MESSAGE); req.set_out_varname(dir); return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out); } } // namespace distributed } // namespace operators } // namespace paddle