From 2f41f389228e48bab4b677a7f14248c39c67782f Mon Sep 17 00:00:00 2001 From: ziyoujiyi <73728031+ziyoujiyi@users.noreply.github.com> Date: Thu, 31 Mar 2022 22:31:14 +0800 Subject: [PATCH] heter & multi-cloud brpc communication (#40965) * back fl * delete ssl cert * . * make warning * . * unittest paral degree * solve unittest * heter & multi cloud commm ready * . * . --- .../distributed/ps/service/CMakeLists.txt | 2 +- .../distributed/ps/service/brpc_ps_client.cc | 2 + .../distributed/ps/service/heter_client.cc | 289 ++++++++--- .../distributed/ps/service/heter_client.h | 129 ++++- .../distributed/ps/service/heter_server.cc | 236 ++++++++- .../distributed/ps/service/heter_server.h | 474 +++++++++++++----- .../distributed/ps/service/sendrecv.proto | 13 + paddle/fluid/framework/CMakeLists.txt | 4 +- paddle/fluid/operators/pscore/CMakeLists.txt | 7 +- .../pscore/heter_cloud_comm_cpu_test.cc | 247 +++++++++ .../pscore/heter_listen_and_serv_op.cc | 40 +- .../pscore/heter_listen_and_serv_op.h | 8 +- .../pscore/heter_listen_and_server_test.cc | 30 +- .../operators/pscore/heter_server_test.cc | 49 +- .../pscore/send_and_recv_op_cpu_test.cc | 15 +- .../pscore/send_and_recv_op_gpu_test.cc | 19 +- tools/parallel_UT_rule.py | 1 + 17 files changed, 1236 insertions(+), 329 deletions(-) mode change 100644 => 100755 paddle/fluid/distributed/ps/service/CMakeLists.txt mode change 100644 => 100755 paddle/fluid/distributed/ps/service/brpc_ps_client.cc mode change 100644 => 100755 paddle/fluid/distributed/ps/service/heter_client.h mode change 100644 => 100755 paddle/fluid/distributed/ps/service/sendrecv.proto mode change 100644 => 100755 paddle/fluid/operators/pscore/CMakeLists.txt create mode 100644 paddle/fluid/operators/pscore/heter_cloud_comm_cpu_test.cc mode change 100644 => 100755 paddle/fluid/operators/pscore/heter_listen_and_serv_op.h mode change 100644 => 100755 paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc diff --git a/paddle/fluid/distributed/ps/service/CMakeLists.txt b/paddle/fluid/distributed/ps/service/CMakeLists.txt old mode 100644 new mode 100755 index ab6c2e2600..b8de291072 --- a/paddle/fluid/distributed/ps/service/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/service/CMakeLists.txt @@ -39,8 +39,8 @@ cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS}) cc_library(communicator SRCS communicator/communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS}) cc_library(ps_service SRCS ps_service/service.cc DEPS communicator client server boost ${RPC_DEPS}) -cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) +cc_library(heter_server SRCS heter_server.cc DEPS heter_client brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) set_source_files_properties(ps_service/graph_py_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_library(graph_py_service SRCS ps_service/graph_py_service.cc DEPS ps_service) diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc old mode 100644 new mode 100755 index 9674717ffc..d7d41d6bbd --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -55,6 +55,8 @@ DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num"); DEFINE_int32(pserver_sparse_table_shard_num, 1000, "sparse table shard for save & load"); 
+DEFINE_int32(heter_world_size, 100, "group size"); // configurable + namespace paddle { namespace framework { class Scope; diff --git a/paddle/fluid/distributed/ps/service/heter_client.cc b/paddle/fluid/distributed/ps/service/heter_client.cc index d6287cda6d..4ca25dac82 100644 --- a/paddle/fluid/distributed/ps/service/heter_client.cc +++ b/paddle/fluid/distributed/ps/service/heter_client.cc @@ -13,18 +13,14 @@ // limitations under the License. #include "paddle/fluid/distributed/ps/service/heter_client.h" + #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/platform/profiler.h" -#include "paddle/fluid/string/split.h" - -DECLARE_int32(rpc_deadline); -DECLARE_int32(pserver_timeout_ms); namespace paddle { namespace distributed { -std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL; -bool HeterClient::is_initialized_ = false; +std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr; int GetMicroId(const platform::DeviceContext& ctx, const framework::Scope* scope) { @@ -54,58 +50,21 @@ int GetMicroId(const platform::DeviceContext& ctx, return micro_id; } -void HeterClient::MainThread() { - while (running_) { - RpcProfilerControl(); - } -} - void HeterClient::Stop() { - running_ = false; - if (!is_initialized_) { - VLOG(3) << "HeterClient is not inited, do nothing"; - } else { - if (main_thread_) { - auto status = StopHeterWorker(); - status.wait(); - main_thread_->join(); - main_thread_.reset(nullptr); - } - VLOG(3) << "HeterClient Stop Done"; - } -} - -void HeterClient::FinalizeWorker() { - running_ = false; - if (!is_initialized_) { - VLOG(3) << "HeterClient is not inited, do nothing"; - } else { - if (main_thread_) { - main_thread_->join(); - main_thread_.reset(nullptr); - } - VLOG(3) << "HeterClient Stop Done"; - } + auto status = StopHeterWorker(); + status.wait(); } std::future<int32_t> HeterClient::StopHeterWorker() { return SendCmd(-1, PS_STOP_SERVER, {}); } -void HeterClient::RpcProfilerControl() { - if (trainer_id_ == 0) { - if (!do_server_profiler_ && platform::IsProfileEnabled()) { - // send profiler start flag - do_server_profiler_ = true; - auto start_status = StartProfiler(); - start_status.wait(); - } else if (do_server_profiler_ && !platform::IsProfileEnabled()) { - // send profiler end flag - auto stop_status = StopProfiler(); - stop_status.wait(); - do_server_profiler_ = false; - } - } +std::future<int32_t> HeterClient::StartProfiler() { + return SendCmd(-1, PS_START_PROFILER, {}); +} + +std::future<int32_t> HeterClient::StopProfiler() { + return SendCmd(-1, PS_STOP_PROFILER, {}); } void HeterClient::CreateClient2XpuConnection() { @@ -156,27 +115,24 @@ void HeterClient::SendAndRecvAsync( 1); const platform::DeviceContext* p_ctx = &ctx; const framework::Scope* p_scope = &scope; - const std::string message_name_val = message_name; const std::vector<std::string> send_var_name_val = send_var_name; const std::vector<std::string> recv_var_name_val = recv_var_name; - VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: " - << message_name_val; + VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: " << message_name; brpc::Channel* channel = nullptr; distributed::MultiVarMsg request; - OnHeterRpcDone* closure = new OnHeterRpcDone([p_ctx, p_scope](void* done) { + OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) { auto* closure = reinterpret_cast<OnHeterRpcDone*>(done); PADDLE_ENFORCE_NE( closure->cntl.Failed(), true, platform::errors::Unimplemented( "HeterClient::SendAndRecv meets brpc error, error message is %s", closure->cntl.ErrorText())); - VLOG(4) << "call heter_worker success"; });
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); auto& request_io_buffer = closure->cntl.request_attachment(); distributed::SerializeToMultiVarMsgAndIOBuf( - message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope, + message_name, send_var_name_val, recv_var_name_val, *p_ctx, p_scope, &request, &request_io_buffer); int micro_id = GetMicroId(ctx, p_scope); @@ -188,6 +144,19 @@ void HeterClient::SendAndRecvAsync( } else if (mode == "backward") { int num = minibatch_id % previous_xpu_channels_.size(); channel = previous_xpu_channels_[num].get(); + } else if (mode == "send_to_switch") { + VLOG(4) << "calling switch service"; + // auto promise = std::make_shared<std::promise<int32_t>>(); + // closure->add_promise(promise); + // std::future<int> fut = promise->get_future(); + // int idx = 1; // for test + // LOG(INFO) << "xpu_channels_ size: " << xpu_channels_.size(); + // channel = xpu_channels_[idx].get(); // to adapt to the send_and_recv op + // ::paddle::distributed::PsService_Stub stub(channel); + // stub.SendToSwitch(&closure->cntl, &request, &closure->response, + // closure); fut.wait(); + VLOG(4) << "calling switch service done"; + return; } ::paddle::distributed::PsService_Stub stub(channel); stub.SendAndRecvVariable(&closure->cntl, &request, &closure->response, @@ -229,13 +198,209 @@ std::future<int32_t> HeterClient::SendCmd( return fut; } -std::future<int32_t> HeterClient::StartProfiler() { - return SendCmd(-1, PS_START_PROFILER, {}); +int HeterClient::Send(const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& message_name, + const std::vector<std::string>& send_var_names) { + const framework::Scope* p_scope = &scope; // note: the scope is const + OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) { + auto* closure = reinterpret_cast<OnHeterRpcDone*>(done); + int ret = 0; + closure->set_promise_value(ret); + if (closure->cntl.Failed()) { + PADDLE_ENFORCE_NE( + closure->cntl.Failed(), true, + platform::errors::Unimplemented( + "HeterClient::SendToSwitch meets brpc error, error message is %s", + closure->cntl.ErrorText())); + } + }); + + closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); + auto& request_io_buffer = closure->cntl.request_attachment(); + + distributed::MultiVarMsg request; + // 1. set req message_name(string) + request.set_message_name(message_name); + + // 2. set req send_var_names() + for (auto& send_var_name : send_var_names) { + request.add_send_var_names(send_var_name); + } + + // 3.
set req var_messages() + for (auto& send_var_name : send_var_names) { + auto* send_var_msg = request.add_var_messages(); + send_var_msg->set_varname(send_var_name); + framework::Variable* var = p_scope->FindVar(send_var_name); + butil::IOBuf temp_iobuf; + if (var->IsType<framework::LoDTensor>()) { + SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf); + } else if (var->IsType<phi::SelectedRows>()) { + SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf); + } + request_io_buffer.append(temp_iobuf); + } + auto promise = std::make_shared<std::promise<int32_t>>(); + closure->add_promise(promise); + std::future<int> fut = promise->get_future(); + if (send_switch_channels_.empty()) { + LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]"; + if (xpu_channels_.empty()) { + LOG(ERROR) << "xpu_channels_ is null"; + } + send_switch_channels_.push_back(xpu_channels_[0]); + } + brpc::Channel* channel = send_switch_channels_[0].get(); + // brpc::Channel* channel = xpu_channels_[0].get(); + ::paddle::distributed::PsService_Stub stub(channel); + stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure); + + VLOG(4) << "waiting SendToSwitch response result......"; + fut.wait(); + VLOG(4) << "Send done"; + return 0; } -std::future<int32_t> HeterClient::StopProfiler() { - return SendCmd(-1, PS_STOP_PROFILER, {}); +int HeterClient::Send(int group_id, const std::vector<std::string>& var_names, + const std::vector<int64_t>& vars_len, void* data_ptr, + int64_t data_size) { + OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) { + auto* closure = reinterpret_cast<OnHeterRpcDone*>(done); + int ret = 0; + closure->set_promise_value(ret); + if (closure->cntl.Failed()) { + LOG(ERROR) << "Send meets brpc error, err msg is: " + << closure->cntl.ErrorText(); + } + }); + distributed::MultiVarMsg request; + closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); + std::string message_name = "send and save"; + request.set_message_name(message_name); + request.set_group_id(group_id); + for (auto& send_var_name : var_names) { + request.add_send_var_names(send_var_name); + } + for (auto var_len : vars_len) { + request.add_vars_len(var_len); + } + auto& request_buffer = closure->cntl.request_attachment(); + request_buffer.append(reinterpret_cast<char*>(data_ptr), + data_size * sizeof(float)); + auto promise = std::make_shared<std::promise<int32_t>>(); + closure->add_promise(promise); + std::future<int> fut = promise->get_future(); + if (send_switch_channels_.empty()) { + LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]"; + if (xpu_channels_.empty()) { + LOG(ERROR) << "xpu_channels_ is null"; + } + send_switch_channels_.push_back(xpu_channels_[0]); + } + brpc::Channel* channel = send_switch_channels_[0].get(); + ::paddle::distributed::PsService_Stub stub(channel); + stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure); + fut.wait(); + return 0; } -} // end namespace distributed +int HeterClient::Recv(const platform::DeviceContext& ctx, + framework::Scope& recv_scope, // NOLINT + const std::string& message_name, + const std::vector<std::string>& recv_var_names) { + OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) { + auto* closure = reinterpret_cast<OnHeterRpcDone*>(done); + VLOG(4) << "Recv service call done"; + int ret = 0; + closure->set_promise_value(ret); + if (closure->cntl.Failed()) { + VLOG(4) << "HeterClient::RecvFromSwitch meets " + "brpc error, error message is: " + << closure->cntl.ErrorText(); + } + }); + + closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); + + distributed::MultiVarMsg request; + // 1. set req message_name(string) + request.set_message_name(message_name); + + // 2.
set req recv_var_names() + for (auto& recv_var_name : recv_var_names) { + request.add_recv_var_names(recv_var_name); + } + auto promise = std::make_shared<std::promise<int32_t>>(); + closure->add_promise(promise); + std::future<int> fut = promise->get_future(); + if (recv_switch_channels_.empty()) { + LOG(ERROR) << "recv_switch_channels_ is empty, use xpu_channels_[1]"; + if (xpu_channels_.size() < 2) { + LOG(ERROR) << "xpu_channels_ has fewer than two channels"; + } + recv_switch_channels_.push_back(xpu_channels_[1]); + } + brpc::Channel* channel = recv_switch_channels_[0].get(); + ::paddle::distributed::PsService_Stub stub(channel); + stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure); + fut.wait(); + VLOG(4) << "RecvFromSwitch done"; + // save in worker + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::CPUPlace cpu_place; + auto& cpu_dev_ctx = *pool.Get(cpu_place); + auto& res_io_buffer = closure->cntl.response_attachment(); + VLOG(4) << "entering DeserializeFromMultiVarMsgAndIOBuf"; + distributed::DeserializeFromMultiVarMsgAndIOBuf( + closure->response, &res_io_buffer, cpu_dev_ctx, &recv_scope); + VLOG(4) << "Recv done"; + return 0; +} + +int HeterClient::Recv(int group_id, const std::vector<std::string>& var_names, + void* data_ptr, int64_t data_size) { + OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) { + auto* closure = reinterpret_cast<OnHeterRpcDone*>(done); + int ret = 0; + closure->set_promise_value(ret); + if (closure->cntl.Failed()) { + LOG(ERROR) << "Recv meets brpc error, err msg is: " + << closure->cntl.ErrorText(); + } + }); + closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); + + distributed::MultiVarMsg request; + std::string message_name = "query and recv"; + request.set_message_name(message_name); + request.set_group_id(group_id); + + for (auto& recv_var_name : var_names) { + request.add_recv_var_names(recv_var_name); + } + auto promise = std::make_shared<std::promise<int32_t>>(); + closure->add_promise(promise); + std::future<int> fut = promise->get_future(); + if (recv_switch_channels_.empty()) { + LOG(ERROR) << "recv_switch_channels_ is empty, use xpu_channels_[1]"; + if (xpu_channels_.size() < 2) { + LOG(ERROR) << "xpu_channels_ has fewer than two channels"; + } + recv_switch_channels_.push_back(xpu_channels_[1]); + } + brpc::Channel* channel = recv_switch_channels_[0].get(); + ::paddle::distributed::PsService_Stub stub(channel); + stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure); + fut.wait(); + VLOG(4) << "RecvFromSwitch done"; + // save in worker + auto& res_io_buffer = closure->cntl.response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + io_buffer_itr.copy_and_forward(reinterpret_cast<char*>(data_ptr), + data_size * sizeof(float)); + VLOG(4) << "Recv done"; + return 0; +} +} // namespace distributed } // end namespace paddle diff --git a/paddle/fluid/distributed/ps/service/heter_client.h b/paddle/fluid/distributed/ps/service/heter_client.h old mode 100644 new mode 100755 index 4f27ef75ea..006f87ddf5 --- a/paddle/fluid/distributed/ps/service/heter_client.h +++ b/paddle/fluid/distributed/ps/service/heter_client.h @@ -32,13 +32,14 @@ limitations under the License.
*/ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN +#include "paddle/fluid/string/split.h" namespace paddle { namespace framework { class Scope; } // namespace framework } // namespace paddle - +DECLARE_int32(pserver_timeout_ms); namespace paddle { namespace distributed { @@ -51,24 +52,72 @@ class OnHeterRpcDone : public google::protobuf::Closure { public: explicit OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {} virtual ~OnHeterRpcDone() {} - void Run() { - std::unique_ptr<OnHeterRpcDone> self_guard(this); - handler_(this); + void Run() { handler_(this); } + + void add_promise(std::shared_ptr<std::promise<int32_t>>& promise) { // NOLINT + _promises.push_back(promise); } + void set_promise_value(int value) { + for (auto& promise : _promises) { + promise->set_value(value); + } + } + int CheckResponse() { return 0; } + std::vector<std::shared_ptr<std::promise<int32_t>>> _promises; HeterRpcCallbackFunc handler_; + + MultiVariableMessage request; MultiVariableMessage response; + + PsResponseMessage ps_response; + brpc::Controller cntl; + // PsRequestMessage *request(size_t i) { return &_requests[i]; } + // PsResponseMessage *response(size_t i) { return &_responses[i]; } + // std::vector<PsRequestMessage> _requests; + // std::vector<PsResponseMessage> _responses; + // std::vector<std::shared_ptr<brpc::Controller>> _cntls; }; class HeterClient { public: virtual ~HeterClient() {} - HeterClient() { - running_ = true; - main_thread_.reset( - new std::thread(std::bind(&HeterClient::MainThread, this))); + void InitClientChannels(bool need_encrypt, + const std::vector<std::string>& node_list, + int32_t peer_role) { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.connection_type = "single"; + options.timeout_ms = FLAGS_pserver_timeout_ms; + std::vector<std::shared_ptr<brpc::Channel>>* client_channels = nullptr; + if (peer_role == PEER_ROLE_IS_SWITCH) { + options.ssl_options.enable = need_encrypt; + client_channels = &peer_switch_channels_; + } else if (peer_role == PEER_ROLE_IS_WORKER) { + client_channels = &peer_worker_channels_; + } else { + LOG(ERROR) << "init switch client failed, peer_role not valid"; + } + (*client_channels).resize(node_list.size()); + for (size_t i = 0; i < node_list.size(); ++i) { + (*client_channels)[i].reset(new brpc::Channel()); + if ((*client_channels)[i]->Init(node_list[i].c_str(), "", &options) != + 0) { + VLOG(0) << "client channel init failed! try again"; + auto ip_port = paddle::string::Split(node_list[i], ':'); + std::string ip = ip_port[0]; + int port = std::stoi(ip_port[1]); + std::string int_ip_port = GetIntTypeEndpoint(ip, port); + if ((*client_channels)[i]->Init(int_ip_port.c_str(), "", &options) != + 0) { + LOG(ERROR) << "client channel init failed!
peer ip_port = " + << int_ip_port; + } + } + } + VLOG(4) << "InitClientChannels success"; } void CreateClient2XpuConnection(); @@ -80,14 +129,28 @@ class HeterClient { const std::vector& recv_var_name, const std::string& mode = "forward"); + int Send(int group_id, const std::vector& var_names, + const std::vector& vars_len, void* data_ptr, int64_t data_size); + + int Send(const platform::DeviceContext& ctx, const framework::Scope& scope, + const std::string& message_name, + const std::vector& send_var_names); + + int Recv(int group_id, const std::vector& var_names, + void* data_ptr, int64_t data_size); + + int Recv(const platform::DeviceContext& ctx, + framework::Scope& recv_scope, // NOLINT + const std::string& message_name, + const std::vector& recv_var_names); + // HeterClient singleton static std::shared_ptr GetInstance( const std::vector& endpoint, const std::vector& previous_endpoint, const int& trainer_id) { if (NULL == s_instance_) { - is_initialized_ = true; - s_instance_.reset(new paddle::distributed::HeterClient()); + s_instance_.reset(new HeterClient()); s_instance_->SetXpuList(endpoint); s_instance_->SetPreviousXpuList(previous_endpoint); s_instance_->SetTrainerID(trainer_id); @@ -96,13 +159,29 @@ class HeterClient { return s_instance_; } - void Stop(); + // switch client singleton + static HeterClient& GetSwitchInstance( + const std::vector& peer_endpoints, int32_t peer_role) { + static HeterClient switch_s_instance_; + if (peer_endpoints.empty()) { + VLOG(4) << "init switch client failed, null peer_endpoints"; + } + VLOG(4) << "peer role is: " << peer_role + << ", addr is: " << peer_endpoints[0]; + switch_s_instance_.SetPeerSwitchList(peer_endpoints); + switch_s_instance_.InitClientChannels(false, peer_endpoints, peer_role); + return switch_s_instance_; + } - void FinalizeWorker(); + void SetPeerSwitchList(const std::vector& peer_endpoints) { + peer_switch_list_ = peer_endpoints; + } - void MainThread(); + void SetPeerWorkerList(const std::vector& worker_endpoints) { + peer_worker_list_ = worker_endpoints; + } - void RpcProfilerControl(); + void Stop(); std::future SendCmd(uint32_t table_id, int cmd_id, const std::vector& params); @@ -124,20 +203,32 @@ class HeterClient { void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; } + public: + std::vector send_switch_list_; + std::vector recv_switch_list_; + + std::vector peer_switch_list_; + std::vector peer_worker_list_; + std::vector> send_switch_channels_; + std::vector> recv_switch_channels_; + + std::vector> peer_switch_channels_; + std::vector> peer_worker_channels_; + private: + HeterClient() {} + HeterClient& operator=(const HeterClient&); + HeterClient(const HeterClient&); + static std::shared_ptr s_instance_; - static bool is_initialized_; - std::unique_ptr main_thread_{nullptr}; std::vector> xpu_channels_; std::vector> previous_xpu_channels_; - DISABLE_COPY_AND_ASSIGN(HeterClient); + // DISABLE_COPY_AND_ASSIGN(HeterClient); std::vector xpu_list_; std::vector previous_xpu_list_; - bool running_ = false; int trainer_id_; - bool do_server_profiler_ = false; }; } // end namespace distributed diff --git a/paddle/fluid/distributed/ps/service/heter_server.cc b/paddle/fluid/distributed/ps/service/heter_server.cc index 01afed3f12..e21bf093f1 100644 --- a/paddle/fluid/distributed/ps/service/heter_server.cc +++ b/paddle/fluid/distributed/ps/service/heter_server.cc @@ -13,21 +13,28 @@ // limitations under the License. 
#include "paddle/fluid/distributed/ps/service/heter_server.h" + #include "paddle/fluid/string/split.h" namespace paddle { namespace distributed { +// DEFINE_string(cert_path, "./cert.pem", "cert.pem path"); +// DEFINE_string(key_path, "./key.pem", "key.pem path"); -std::shared_ptr HeterServer::s_instance_ = NULL; +std::shared_ptr HeterServer::s_instance_ = nullptr; void HeterServer::RegisterServiceHandler(std::string message_name, HeterServiceHandler func) { service_.RegisterServiceHandler(message_name, func); } -void HeterServer::StartHeterService() { +void HeterServer::StartHeterService(bool neeed_encrypt) { server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE); brpc::ServerOptions options; + if (neeed_encrypt) { + options.ssl_options.default_cert.certificate = "/cert.pem"; + options.ssl_options.default_cert.private_key = "/key.pem"; + } if (server_.Start(endpoint_.c_str(), &options) != 0) { VLOG(0) << "HeterServer start fail. Try again."; auto ip_port = paddle::string::Split(endpoint_, ':'); @@ -47,16 +54,50 @@ void HeterServer::StartHeterService() { ready_ = 1; } condition_ready_.notify_all(); + VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_; std::unique_lock running_lock(mutex_); cv_.wait(running_lock, [&] { - VLOG(1) << "Heter Server is Stop? " << stoped_; + VLOG(4) << "Heter Server is Stop? " << stoped_; return stoped_; }); + VLOG(4) << "start service done"; } -void HeterServer::SetEndPoint(const std::string& endpoint) { - endpoint_ = endpoint; - service_.SetEndpoint(endpoint); +void HeterServer::StartHeterInterService(bool neeed_encrypt) { + server_inter_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE); + brpc::ServerOptions options; + if (neeed_encrypt) { + options.ssl_options.default_cert.certificate = "/cert.pem"; + options.ssl_options.default_cert.private_key = "/key.pem"; + } + if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) { + VLOG(4) << "switch inter server start fail. Try again."; + auto ip_port = paddle::string::Split(endpoint_inter_, ':'); + std::string ip = ip_port[0]; + int port = std::stoi(ip_port[1]); + std::string int_ip_port = GetIntTypeEndpoint(ip, port); + if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) { + LOG(ERROR) << "switch inter server start failed, ip_port= " + << int_ip_port; + } + } else { + VLOG(4) << "switch inter server server start success! listen on " + << endpoint_inter_; + } + + { + std::lock_guard lock(this->mutex_ready_); + stoped_ = false; + ready_ = 1; + } + condition_ready_.notify_all(); + VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_; + std::unique_lock running_lock(mutex_); + cv_.wait(running_lock, [&] { + VLOG(4) << "Heter Server is Stop? 
" << stoped_; + return stoped_; + }); + VLOG(4) << "start service done"; } void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); } @@ -64,35 +105,180 @@ void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); } void HeterServer::WaitServerReady() { std::unique_lock lock(this->mutex_ready_); condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); + while (!this->ready_) { + sleep(1); + } } -int32_t HeterService::stop_profiler(const PsRequestMessage& request, - PsResponseMessage& response, - brpc::Controller* cntl) { - platform::DisableProfiler( - platform::EventSortingKey::kDefault, - string::Sprintf("heter_worker_%s_profile", endpoint_)); +int SendAndRecvVariableHandler::SaveInSwitchWithShard( + const MultiVarMsg* request, PsResponseMessage* response, + brpc::Controller* cntl) { + VLOG(4) << "entering SaveInSwitchWithShard"; + int32_t group_id = request->group_id(); + auto& local_shard = _local_shards[group_id]; + auto& request_io_buffer = cntl->request_attachment(); + butil::IOBufBytesIterator io_buffer_itr(request_io_buffer); + for (int idx = 0; idx < request->send_var_names_size(); idx++) { + const auto& var_name = request->send_var_names(idx); + const auto& var_len = request->vars_len(idx); + auto itr = local_shard.find(var_name); + if (itr != local_shard.end()) { + LOG(INFO) << "var: " << var_name << "has not been consumed!" + << "check again"; + WaitForVarsConsumed(group_id, var_name); + } + auto& value = local_shard[var_name]; + value.resize(var_len); + io_buffer_itr.copy_and_forward(reinterpret_cast(value.data()), + var_len * sizeof(float)); + VLOG(4) << "saved data in shards: "; + for (uint32_t i = 0; i < local_shard[var_name].size(); i++) { + VLOG(4) << *(local_shard[var_name].data() + i); + } + } + VLOG(4) << "SaveInSwitchWithShard success"; return 0; } -int32_t HeterService::start_profiler(const PsRequestMessage& request, - PsResponseMessage& response, - brpc::Controller* cntl) { - platform::EnableProfiler(platform::ProfilerState::kAll); +int SendAndRecvVariableHandler::QueryInSwitchWithShard( + const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) { + VLOG(4) << "entering QueryInSwitchWithShard"; + int32_t group_id = request->group_id(); + VLOG(4) << "group id: " << group_id; + auto& local_shard = _local_shards[group_id]; + auto& response_io_buffer = cntl->response_attachment(); + auto req_var_nums = request->recv_var_names_size(); + std::vector req_var_names(req_var_nums); + for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) { + req_var_names[var_idx] = request->recv_var_names(var_idx); + } + auto msg_name = request->message_name(); + response->set_message_name(msg_name); + + for (auto& req_var_name : req_var_names) { + VLOG(4) << "req var name: " << req_var_name; + response->add_send_var_names(req_var_name); + auto itr = local_shard.find(req_var_name); + if (itr == local_shard.end()) { + LOG(INFO) << "var: " << req_var_name << " not found in shards"; + WaitForVarsProduced(group_id, req_var_name); + } + LOG(INFO) << "var: " << req_var_name << " found in shards"; + itr = local_shard.find(req_var_name); + auto& value = itr.value(); + response_io_buffer.append(value.data(), value.size() * sizeof(float)); + value.resize(0); // 标记位 + } + VLOG(4) << "heter server QueryInSwitchWithShard done"; return 0; } -int32_t HeterService::stop_heter_worker(const PsRequestMessage& request, - PsResponseMessage& response, - brpc::Controller* cntl) { - auto client_id = request.client_id(); - 
stop_cpu_worker_set_.insert(client_id); - if (stop_cpu_worker_set_.size() == fan_in_) { - is_exit_ = true; - VLOG(3) << "Stop heter Service done."; +int SendAndRecvVariableHandler::SaveInSwitchWithScope( + const MultiVarMsg* request, PsResponseMessage* response, + brpc::Controller* cntl) { + VLOG(4) << "entering SaveInSwitchWithScope"; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::CPUPlace cpu_place; + auto& cpu_dev_ctx = *pool.Get(cpu_place); + auto message_name = request->message_name(); + VLOG(4) << "message_name in heter server: " << message_name; + std::unique_lock lk(scope_mutex_); + auto local_scope = local_scope_ptr.get(); + if (!local_scope) { + LOG(ERROR) << "local_scope_ptr is null in SaveInSwitchWithScope"; + } + for (int idx = 0; idx < request->send_var_names_size(); idx++) { + const auto& msg = request->var_messages(idx); + std::string var_name = msg.varname(); + auto* var_exist_ptr = local_scope->FindVar(var_name); + if (!var_exist_ptr) { + VLOG(4) << "not find var: " << var_name << " in local_scope"; + } + vars_table[var_name] += 1; + VLOG(4) << "saved var_name: " << var_name + << ", cnt = " << vars_table[var_name]; + } + auto& request_io_buffer = cntl->request_attachment(); + distributed::DeserializeFromMultiVarMsgAndIOBuf(*request, &request_io_buffer, + cpu_dev_ctx, local_scope); + lk.unlock(); + while (true) { + int ret = 0; + for (int idx = 0; idx < request->send_var_names_size(); idx++) { + ret |= vars_table[request->var_messages(idx).varname()]; + } + if (!ret) { + VLOG(4) << "all saved vars consumed"; + break; + } + VLOG(4) << "waiting consume result......"; + sleep(1); } + VLOG(4) << "SaveInSwitchWithScope success"; return 0; } +int SendAndRecvVariableHandler::QueryInSwitchWithScope( + const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) { + VLOG(4) << "entering QueryInSwitchWithScope"; + auto local_scope = local_scope_ptr.get(); + if (!local_scope) { + LOG(INFO) << "local_scope is null"; + } + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::CPUPlace cpu_place; + auto& cpu_dev_ctx = *pool.Get(cpu_place); + + // get req message_name & req_var_names + auto msg_name = request->message_name(); + auto req_var_nums = request->recv_var_names_size(); + std::vector req_var_names(req_var_nums); + for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) { + req_var_names[var_idx] = request->recv_var_names(var_idx); + } + auto& response_io_buffer = cntl->response_attachment(); + + // 1. fill message_name(string) + response->set_message_name(msg_name); + + // 2. fill var_names(string) + for (auto& req_var_name : req_var_names) { + response->add_send_var_names(req_var_name); + } + + // 3. 
fill var_messages(VarMessage) + for (auto& req_var_name : req_var_names) { + LOG(INFO) << "query var_name: " << req_var_name; + auto* send_var_msg = response->add_var_messages(); + send_var_msg->set_varname(req_var_name); + + framework::Variable* var_ptr; + while (true) { + var_ptr = local_scope->FindVar(req_var_name); + if (!var_ptr) { + LOG(INFO) << "local_scope not find var: " << req_var_name; + } else { + break; + } + sleep(1); + } + butil::IOBuf temp_iobuf; + if (var_ptr->IsType()) { + SerializeLodTensor(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf); + } else if (var_ptr->IsType()) { + SerializeSelectedRows(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf); + } + response_io_buffer.append(temp_iobuf); + } + for (auto& req_var_name : req_var_names) { + std::unique_lock lk(scope_mutex_); + vars_table[req_var_name] -= 1; + VLOG(4) << "remained var: " << req_var_name + << ", cnt = " << vars_table[req_var_name]; + lk.unlock(); + } + VLOG(4) << "heter server QueryInSwitchWithScope done"; + return 0; +} } // end namespace distributed -} // end namespace paddle +} // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/heter_server.h b/paddle/fluid/distributed/ps/service/heter_server.h index a14fb5f6cc..624e76112c 100644 --- a/paddle/fluid/distributed/ps/service/heter_server.h +++ b/paddle/fluid/distributed/ps/service/heter_server.h @@ -22,11 +22,14 @@ limitations under the License. */ #include #include #include + #include "brpc/channel.h" #include "brpc/controller.h" #include "brpc/server.h" #include "paddle/fluid/distributed/ps/service/brpc_utils.h" +#include "paddle/fluid/distributed/ps/service/heter_client.h" #include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/ps/table/depends/feature_value.h" #include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/program_desc.h" @@ -51,108 +54,37 @@ class Scope; } // namespace paddle DECLARE_double(eager_delete_tensor_gb); +DECLARE_int32(pserver_timeout_ms); +DECLARE_int32(heter_world_size); namespace paddle { namespace distributed { -using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; -using VarMsg = ::paddle::distributed::VariableMessage; - -class HeterService; +using MultiVarMsg = MultiVariableMessage; +using VarMsg = VariableMessage; -typedef int32_t (HeterService::*serviceHandlerFunc)( +using serviceHandler = std::function; +using HeterServiceHandler = + std::function; -typedef std::function HeterRpcCallbackFunc; -typedef std::function - HeterServiceHandler; +using HeterRpcCallbackFunc = std::function; -class HeterService : public ::paddle::distributed::PsService { +class ServiceHandlerBase { public: - HeterService() { - _service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker; - _service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler; - _service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler; - } + ServiceHandlerBase() : dev_ctx_(nullptr), scope_(nullptr) {} - virtual ~HeterService() {} + virtual ~ServiceHandlerBase() {} - virtual void service(::google::protobuf::RpcController* controller, - const PsRequestMessage* request, - PsResponseMessage* response, - ::google::protobuf::Closure* done) { - brpc::ClosureGuard done_guard(done); - std::string log_label("ReceiveCmd-"); - - response->set_err_code(0); - response->set_err_msg(""); - brpc::Controller* cntl = static_cast(controller); - auto itr = _service_handler_map.find(request->cmd_id()); - if (itr == 
_service_handler_map.end()) { - std::string err_msg( - "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); - err_msg.append(std::to_string(request->cmd_id())); - return; - } - serviceHandlerFunc handler_func = itr->second; - int service_ret = (this->*handler_func)(*request, *response, cntl); - if (service_ret != 0) { - response->set_err_code(service_ret); - response->set_err_msg("server internal error"); - } - } - - void SendAndRecvVariable(::google::protobuf::RpcController* controller, - const MultiVarMsg* request, MultiVarMsg* response, - ::google::protobuf::Closure* done) { - brpc::ClosureGuard done_guard(done); - std::string message_name = request->message_name(); - auto itr = handler_map_.find(message_name); - brpc::Controller* cntl = static_cast(controller); - PADDLE_ENFORCE_NE( - itr, handler_map_.end(), - platform::errors::InvalidArgument( - "HeterService::SendAndRecvVariable Get illegal message_name: %s " - "which is not in HeterService::handler_map_", - message_name)); - itr->second(request, response, cntl); - } - - void RegisterServiceHandler(std::string message_name, - HeterServiceHandler func) { - handler_map_[message_name] = func; - } - - int32_t ForceExit() { - VLOG(3) << "heter service force exit"; - is_exit_ = true; - return 0; - } - - void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; } - void SetFanin(const int& fan_in) { fan_in_ = fan_in; } - bool IsExit() { return is_exit_; } - - private: - int32_t stop_profiler(const PsRequestMessage& request, - PsResponseMessage& response, // NOLINT - brpc::Controller* cntl); - - int32_t start_profiler(const PsRequestMessage& request, - PsResponseMessage& response, // NOLINT - brpc::Controller* cntl); + void SetScope(const framework::Scope* scope) { scope_ = scope; } + void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } - int32_t stop_heter_worker(const PsRequestMessage& request, - PsResponseMessage& response, // NOLINT - brpc::Controller* cntl); + virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) = 0; - private: - std::string endpoint_; - std::unordered_map handler_map_; - std::unordered_map _service_handler_map; - std::unordered_set stop_cpu_worker_set_; - int fan_in_; - bool is_exit_ = false; + protected: + const platform::DeviceContext* dev_ctx_; + const framework::Scope* scope_; }; using SharedMiniScope = @@ -163,31 +95,15 @@ using SharedTaskQueue = std::shared_ptr< std::unordered_map>>>>; -class HeterRequestHandler { - public: - HeterRequestHandler() : dev_ctx_(nullptr), scope_(nullptr) {} - - virtual ~HeterRequestHandler() {} - - void SetScope(const framework::Scope* scope) { scope_ = scope; } - void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } - - virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response, - brpc::Controller* cntl) = 0; - - protected: - const platform::DeviceContext* dev_ctx_; - const framework::Scope* scope_; -}; - -class RequestSendAndRecvHandler final : public HeterRequestHandler { +class SendAndRecvVariableHandler final : public ServiceHandlerBase { public: - RequestSendAndRecvHandler() { + SendAndRecvVariableHandler() { this->num_microbatch_ = 0; this->num_minibatch_ = 0; + _local_shards.reset(new shard_type[FLAGS_heter_world_size]); } - virtual ~RequestSendAndRecvHandler() {} + virtual ~SendAndRecvVariableHandler() {} void SetMiniScopes(SharedMiniScope mini_scopes) { mini_scopes_ = mini_scopes; @@ -209,11 +125,47 @@ class RequestSendAndRecvHandler final : 
public HeterRequestHandler { return (*task_queue_).size(); } + int SaveInSwitchWithScope(const MultiVarMsg* request, + PsResponseMessage* response, + brpc::Controller* cntl); + + void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) { + auto& local_shard = _local_shards[group_id]; + while (local_shard.find(var_name) != local_shard.end()) { + if (local_shard[var_name].size() == 0) { + break; + } + VLOG(4) << "waiting consume result......"; + sleep(1); + } + return; + } + + void WaitForVarsProduced(int32_t group_id, const std::string& var_name) { + auto& local_shard = _local_shards[group_id]; + while (local_shard.find(var_name) == local_shard.end()) { + VLOG(4) << "waiting produce result......"; + sleep(1); + } + return; + } + + int SaveInSwitchWithShard(const MultiVarMsg* request, + PsResponseMessage* response, + brpc::Controller* cntl); + + int QueryInSwitchWithShard(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl); + + int QueryInSwitchWithScope(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl); + void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; } int Handle(const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) override { - platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle", + LOG(INFO) << "entered Handle"; + platform::RecordEvent record_event("SendAndRecvVariableHandler->Handle", platform::TracerEventType::Communication, 1); FLAGS_eager_delete_tensor_gb = -1; @@ -241,7 +193,6 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler { auto* tensor = var->GetMutable(); auto data = reinterpret_cast(tensor->data()); auto micro_id = static_cast(data[0]); - int minibatch_index = micro_id / 10; int microbatch_index = micro_id % 10; @@ -249,10 +200,7 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler { std::unique_lock lk(scope_mutex_); if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) { lk.unlock(); - // PADDLE_ENFORCE_EQ( - // (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1, - // platform::errors::InvalidArgument( - // "minibatch index should in current trainer")); + PADDLE_ENFORCE_EQ( (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1, platform::errors::InvalidArgument( @@ -282,6 +230,7 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler { // blocking queue handles multi thread (*task_queue_)[minibatch_index]->Push( std::make_pair(message_name, microbatch_index)); + auto response_var_nums = request->recv_var_names_size(); std::vector response_var_names(response_var_nums), empty_var_names{}; @@ -295,6 +244,12 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler { return 0; } + public: + using shard_type = SparseTableShard; + std::shared_ptr local_scope_ptr; // for switch + std::unordered_map vars_table; + std::unique_ptr _local_shards; + private: // share with HeterPipelineTrainer SharedMiniScope mini_scopes_{nullptr}; @@ -310,15 +265,254 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler { SharedTaskQueue task_queue_; }; +class HeterService : public PsService { + public: + HeterService() { + _service_handler_map[PS_STOP_SERVER] = + std::bind(&HeterService::stop_heter_worker, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + _service_handler_map[PS_START_PROFILER] = + std::bind(&HeterService::start_profiler, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + 
_service_handler_map[PS_STOP_PROFILER] = + std::bind(&HeterService::stop_profiler, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3); + + service_handler_.local_scope_ptr = + std::make_shared<paddle::framework::Scope>(); + } + + virtual ~HeterService() {} + + virtual void service(::google::protobuf::RpcController* controller, + const PsRequestMessage* request, + PsResponseMessage* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + + response->set_err_code(0); + response->set_err_msg(""); + brpc::Controller* cntl = static_cast<brpc::Controller*>(controller); + auto itr = _service_handler_map.find(request->cmd_id()); + if (itr == _service_handler_map.end()) { + std::string err_msg( + "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); + err_msg.append(std::to_string(request->cmd_id())); + response->set_err_code(-1); + response->set_err_msg(err_msg); + return; + } + serviceHandler handler = itr->second; + int service_ret = handler(*request, *response, cntl); + VLOG(4) << "handler in service ret: " << service_ret; + if (service_ret != 0) { + response->set_err_code(service_ret); + response->set_err_msg("server internal error"); + } + } + + virtual void SendAndRecvVariable( + ::google::protobuf::RpcController* controller, const MultiVarMsg* request, + MultiVarMsg* response, ::google::protobuf::Closure* done) { + // This object helps you to call done->Run() in RAII style. If you need + // to process the request asynchronously, pass done_guard.release(). + brpc::ClosureGuard done_guard(done); + std::string message_name = request->message_name(); + VLOG(0) << "SendAndRecvVariable message_name: " << message_name; + auto itr = handler_map_.find(message_name); + brpc::Controller* cntl = static_cast<brpc::Controller*>(controller); + LOG(INFO) << "SendAndRecvVariable(client addr) = " << cntl->remote_side(); + PADDLE_ENFORCE_NE( + itr, handler_map_.end(), + platform::errors::InvalidArgument( + "HeterService::SendAndRecvVariable Get illegal message_name: %s " + "which is not in HeterService::handler_map_", + message_name)); + itr->second(request, response, cntl); + // If the response needed to be produced asynchronously, we would release + // the guard here:
+ // done_guard.release(); + } + + virtual void RecvFromSwitch(::google::protobuf::RpcController* controller, + const MultiVarMsg* request, MultiVarMsg* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + brpc::Controller* cntl = static_cast<brpc::Controller*>(controller); + // int ret = service_handler_.QueryInSwitchWithScope(request, response, + // cntl); + int ret = service_handler_.QueryInSwitchWithShard(request, response, cntl); + // std::string message_name = request->message_name(); + // auto itr = handler_map_.find(message_name); + // int ret = itr->second(request, response, cntl); + if (ret != 0) { + LOG(ERROR) << "QueryInSwitchWithShard failed!"; + } + // response->set_message_name(message_name); + } + + virtual void SendToSwitch(::google::protobuf::RpcController* controller, + const MultiVarMsg* request, + PsResponseMessage* response, + ::google::protobuf::Closure* done) { + VLOG(4) << "entering SendToSwitch"; + brpc::ClosureGuard done_guard(done); + auto& switch_client_ptr_ = + HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH); + if (switch_client_ptr_.peer_switch_channels_.empty()) { + LOG(ERROR) << "switch_client_ptr_.peer_switch_channels_ null"; + } + brpc::Channel* channel = switch_client_ptr_.peer_switch_channels_[0].get(); + brpc::Controller* cntl = static_cast<brpc::Controller*>(controller); + // proxy: create a new OnHeterRpcDone object (or reset one inside the + // OnHeterRpcDone class) + OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) { + auto* closure = reinterpret_cast<OnHeterRpcDone*>(done); + int ret = closure->CheckResponse(); + closure->set_promise_value(ret); + if (closure->cntl.Failed()) { + PADDLE_ENFORCE_NE( + closure->cntl.Failed(), true, + platform::errors::Unimplemented( + "HeterClient::SendS2S meets brpc error, error message is %s", + closure->cntl.ErrorText())); + } + }); + auto& std_cntl = closure2->cntl; + std_cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); + std_cntl.request_attachment().append(cntl->request_attachment().movable()); + + auto promise = std::make_shared<std::promise<int32_t>>(); + closure2->add_promise(promise); + std::future<int> fut = promise->get_future(); + // brpc::Controller std_cntl; + // std_cntl.request_attachment().append(cntl->request_attachment().movable()); + PsService_Stub stub(channel); + stub.SendS2S(&std_cntl, request, response, closure2); + fut.wait(); + cntl->response_attachment().append( + std_cntl.response_attachment().movable()); + VLOG(4) << "SendToSwitch done"; + } + + void SendS2S(::google::protobuf::RpcController* controller, + const MultiVarMsg* request, PsResponseMessage* response, + ::google::protobuf::Closure* done) { + VLOG(4) << "entering SendS2S"; + brpc::ClosureGuard done_guard(done); + brpc::Controller* cntl = static_cast<brpc::Controller*>(controller); + // int ret = service_handler_.SaveInSwitchWithScope(request, response, + // cntl); + int ret = service_handler_.SaveInSwitchWithShard(request, response, cntl); + // std::string message_name = request->message_name(); + // auto itr = handler_map_.find(message_name); + // if (itr == handler_map_.end()) { + // LOG(ERROR) << "can not find func handler"; + //} + // int ret = itr->second(request, response, cntl); + if (ret != 0) { + LOG(ERROR) << "SaveInSwitchWithShard failed"; + } + std::string err_msg = "ok"; + response->set_err_msg(err_msg.c_str()); + response->set_err_code(ret); + VLOG(4) << "heter server SendS2S done"; + } + + void SendToWorker(::google::protobuf::RpcController* controller, + const MultiVarMsg* request, PsResponseMessage* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard
done_guard(done); + brpc::Controller* cntl = static_cast(controller); + VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side(); + auto& switch_client_ptr_ = + HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER); + VLOG(4) << "in switch client, peer worker 0: " + << switch_client_ptr_.peer_worker_list_[0]; + brpc::Channel* channel = switch_client_ptr_.peer_worker_channels_[0].get(); + + auto* closure = reinterpret_cast(done); + PsService_Stub stub(channel); + stub.SendAndRecvVariable(controller, request, &closure->response, done); + // fill response content + std::string err_msg("pass to worker"); + response->set_err_msg(err_msg.c_str()); + response->set_err_code(0); + } + + void RegisterServiceHandler(std::string message_name, + HeterServiceHandler func) { + handler_map_[message_name] = func; + } + + void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; } + + void SetInterEndpoint(const std::string& end_point) { + endpoint_inter_ = end_point; + } + + void SetPeerEndPoints(const std::vector& peer_endpoints) { + peer_endpoints_ = peer_endpoints; + } + + void SetFanin(const int& fan_in) { fan_in_ = fan_in; } + + void ForceExit() { + VLOG(3) << "heter service force exit"; + is_exit_ = true; + return; + } + + bool IsExit() { return is_exit_; } + + private: + int32_t stop_profiler(const PsRequestMessage& request, + PsResponseMessage& response, // NOLINT + brpc::Controller* cntl) { + platform::DisableProfiler( + platform::EventSortingKey::kDefault, + string::Sprintf("heter_worker_%s_profile", endpoint_)); + return 0; + } + + int32_t start_profiler(const PsRequestMessage& request, + PsResponseMessage& response, // NOLINT + brpc::Controller* cntl) { + platform::EnableProfiler(platform::ProfilerState::kAll); + return 0; + } + + int32_t stop_heter_worker(const PsRequestMessage& request, + PsResponseMessage& response, // NOLINT + brpc::Controller* cntl) { + auto client_id = request.client_id(); + stop_cpu_worker_set_.insert(client_id); + if (stop_cpu_worker_set_.size() == fan_in_) { + is_exit_ = true; + } + return 0; + } + + private: + SendAndRecvVariableHandler service_handler_; + std::string endpoint_; + std::string endpoint_inter_; + // for switch + std::vector peer_endpoints_; + + std::unordered_map _service_handler_map; + std::unordered_map handler_map_; + std::unordered_set stop_cpu_worker_set_; + uint32_t fan_in_; + bool is_exit_ = false; +}; + class HeterServer { public: + HeterServer() : ready_(0) {} virtual ~HeterServer() {} - void Stop() { std::unique_lock lock(mutex_); if (stoped_ == true) return; - if (!IsExit()) service_.ForceExit(); - VLOG(3) << "HeterServer Stop()"; + if (!IsExit()) { + service_.ForceExit(); + } stoped_ = true; cv_.notify_all(); server_.Stop(1000); @@ -327,26 +521,42 @@ class HeterServer { bool IsStop() { std::unique_lock lock(mutex_); - if (stoped_ == true) - return true; - else - return false; + return stoped_; } bool IsExit() { return service_.IsExit(); } - HeterServer() : service_(), ready_(0) {} - void RegisterServiceHandler(std::string message_name, HeterServiceHandler func); - void StartHeterService(); + void StartHeterService(bool need_encrypt = false); + + void StartHeterInterService(bool need_encrypt = false); + + void SetEndPoint(const std::string& endpoint) { + this->endpoint_ = endpoint; + service_.SetEndpoint(endpoint); + } + + void SetLocalScope() { + request_handler_->local_scope_ptr = + std::make_shared(); + } + + void SetInterEndpoint(const std::string& endpoint) { + this->endpoint_inter_ = endpoint; + 
service_.SetInterEndpoint(endpoint); + } + + void SetPeerEndPoints(const std::vector& peer_endpoints) { + this->peer_endpoints_ = peer_endpoints; + service_.SetPeerEndPoints(peer_endpoints); + } - void SetEndPoint(const std::string& endpoint); void SetFanin(const int& fan_in); - void SetRequestHandler( - std::shared_ptr request_handler) { + void SetServiceHandler( + std::shared_ptr request_handler) { request_handler_ = request_handler; } @@ -381,11 +591,15 @@ class HeterServer { std::condition_variable condition_ready_; bool stoped_ = true; std::string endpoint_; + std::string endpoint_inter_; + // for switch + std::vector peer_endpoints_; protected: brpc::Server server_; + brpc::Server server_inter_; HeterService service_; - std::shared_ptr request_handler_; + std::shared_ptr request_handler_; DISABLE_COPY_AND_ASSIGN(HeterServer); std::mutex mutex_ready_; diff --git a/paddle/fluid/distributed/ps/service/sendrecv.proto b/paddle/fluid/distributed/ps/service/sendrecv.proto old mode 100644 new mode 100755 index 6dfaff1ffa..580f411c28 --- a/paddle/fluid/distributed/ps/service/sendrecv.proto +++ b/paddle/fluid/distributed/ps/service/sendrecv.proto @@ -59,6 +59,12 @@ enum PsCmdID { PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38; PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39; PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG = 40; + PEER_ROLE_IS_WORKER = 41; + PEER_ROLE_IS_SWITCH = 42; + PS_SAVE_WITH_SCOPE = 43; + PS_SAVE_WITH_SHARD = 44; + PS_QUERY_WITH_SCOPE = 45; + PS_QUERY_WITH_SHARD = 46; } message PsRequestMessage { @@ -117,9 +123,16 @@ message MultiVariableMessage { repeated string send_var_names = 2; repeated string recv_var_names = 3; repeated VariableMessage var_messages = 4; + optional bytes data = 5; + repeated int32 vars_len = 6; + optional int32 group_id = 7; }; service PsService { rpc service(PsRequestMessage) returns (PsResponseMessage); rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage); + rpc SendToWorker(MultiVariableMessage) returns (PsResponseMessage); + rpc SendToSwitch(MultiVariableMessage) returns (PsResponseMessage); + rpc SendS2S(MultiVariableMessage) returns (PsResponseMessage); + rpc RecvFromSwitch(MultiVariableMessage) returns (MultiVariableMessage); }; diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 09ced6bd0d..e92e160c7a 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -300,7 +300,7 @@ if(WITH_DISTRIBUTE) lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper data_feed_proto timer monitor heter_service_proto fleet_executor ${BRPC_DEP}) - set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses") if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") @@ -320,7 +320,7 @@ if(WITH_DISTRIBUTE) index_sampler index_wrapper sampler index_dataset_proto lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor) - set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor 
-Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses") if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") diff --git a/paddle/fluid/operators/pscore/CMakeLists.txt b/paddle/fluid/operators/pscore/CMakeLists.txt old mode 100644 new mode 100755 index baf82a9df3..863370540d --- a/paddle/fluid/operators/pscore/CMakeLists.txt +++ b/paddle/fluid/operators/pscore/CMakeLists.txt @@ -6,9 +6,9 @@ include(operators) set(DISTRIBUTE_DEPS "") -list(APPEND DISTRIBUTE_DEPS fleet ps_service brpc_utils heter_server heter_client ps_framework_proto framework_proto sendrecv_rpc brpc leveldb ssl crypto protobuf gflags glog zlib snappy device_context) +list(APPEND DISTRIBUTE_DEPS executor fleet ps_service brpc_utils heter_server heter_client ps_framework_proto framework_proto sendrecv_rpc brpc leveldb ssl crypto protobuf gflags glog zlib snappy device_context) -set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") +set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses") if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) set(DISTRIBUTE_COMPILE_FLAGS @@ -37,3 +37,6 @@ cc_test(send_and_recv_gpu_test SRCS send_and_recv_op_gpu_test.cc DEPS executor s set_source_files_properties(heter_listen_and_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(heter_listen_and_server_test SRCS heter_listen_and_server_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function) + +#set_source_files_properties(heter_cloud_comm_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function) diff --git a/paddle/fluid/operators/pscore/heter_cloud_comm_cpu_test.cc b/paddle/fluid/operators/pscore/heter_cloud_comm_cpu_test.cc new file mode 100644 index 0000000000..2340f443c4 --- /dev/null +++ b/paddle/fluid/operators/pscore/heter_cloud_comm_cpu_test.cc @@ -0,0 +1,247 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+#if defined PADDLE_WITH_PSCORE
+#include <stdlib.h>
+
+#include <memory>
+#include <random>
+#include <sstream>
+#include <string>
+#include <thread>  // NOLINT
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/distributed/ps/service/heter_client.h"
+#include "paddle/fluid/distributed/ps/service/heter_server.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+namespace distributed = paddle::distributed;
+
+using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
+
+void CreateVarsOnScope(framework::Scope* scope) {
+  auto var1 = scope->Var("w");
+  var1->GetMutable<phi::SelectedRows>();
+  auto var2 = scope->Var("x");
+  var2->GetMutable<framework::LoDTensor>();
+}
+
+void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
+                         int64_t rows_numel) {
+  CreateVarsOnScope(scope);
+
+  auto w = scope->Var("w")->GetMutable<phi::SelectedRows>();
+  auto w_value = w->mutable_value();
+  w_value->Resize({rows_numel, 10});
+  for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);
+
+  auto ptr = w_value->mutable_data<float>(*place);
+
+  for (int64_t i = 0; i < w_value->numel(); ++i) {
+    ptr[i] = static_cast<float>(i / 10);
+  }
+
+  auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
+  float* x_ptr =
+      x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
+  for (int64_t i = 0; i < rows_numel; ++i) {
+    x_ptr[i] = 1.0;
+  }
+}
+
+void StartSwitchServer(
+    std::shared_ptr<distributed::HeterServer>& switch_server_ptr,  // NOLINT
+    std::vector<std::string> endpoints,
+    std::vector<std::string> peer_endpoints) {
+  switch_server_ptr->SetPeerEndPoints(peer_endpoints);
+  switch_server_ptr->SetEndPoint(endpoints[0]);
+  /*
+  std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
+  b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
+  switch_server_ptr->SetServiceHandler(b_req_handler);
+
+  switch_server_ptr->SetLocalScope();
+
+  switch_server_ptr->RegisterServiceHandler(
+      std::to_string(distributed::PS_SAVE_WITH_SCOPE),
+      [&](const MultiVarMsg* request, MultiVarMsg* response,
+          brpc::Controller* cntl) -> int {
+        return b_req_handler->SaveInSwitchWithScope(request, response, cntl);
+      });
+
+  switch_server_ptr->RegisterServiceHandler(
+      std::to_string(distributed::PS_SAVE_WITH_SHARD),
+      [&](const MultiVarMsg* request, MultiVarMsg* response,
+          brpc::Controller* cntl) -> int {
+        return b_req_handler->SaveInSwitchWithShard(request, response, cntl);
+      });
+
+  switch_server_ptr->RegisterServiceHandler(
+      std::to_string(distributed::PS_QUERY_WITH_SCOPE),
+      [&](const MultiVarMsg* request, MultiVarMsg* response,
+          brpc::Controller* cntl) -> int {
+        return b_req_handler->QueryInSwitchWithScope(request, response, cntl);
+      });
+
+  switch_server_ptr->RegisterServiceHandler(
+      std::to_string(distributed::PS_QUERY_WITH_SHARD),
+      [&](const MultiVarMsg* request, MultiVarMsg* response,
+          brpc::Controller* cntl) -> int {
+        return b_req_handler->QueryInSwitchWithShard(request, response, cntl);
+      });
+  */
+  switch_server_ptr->StartHeterService(false);
+}
+
+void StartSwitchInterServer(
+    std::shared_ptr<distributed::HeterServer>& switch_server_ptr,  // NOLINT
+    std::vector<std::string> endpoints,
+    std::vector<std::string> peer_endpoints) {
+  switch_server_ptr->SetPeerEndPoints(peer_endpoints);
+  switch_server_ptr->SetInterEndpoint(endpoints[1]);
+  switch_server_ptr->StartHeterInterService(false);
+}
+
+TEST(HETERSENDANDRECV, CPU) {
+  setenv("http_proxy", "", 1);
+  setenv("https_proxy", "", 1);
+
+  // start switch server A & B
+  std::string switch_a_endpoint("127.0.0.1:6000");
+  std::string switch_a_endpoint_inter("127.0.0.1:6100");
+  std::string switch_b_endpoint_inter("127.0.0.1:7100");
+  std::string switch_b_endpoint("127.0.0.1:7000");
+
+  std::shared_ptr<distributed::HeterServer> switch_server_ptr_a =
+      std::make_shared<distributed::HeterServer>();
+  std::vector<std::string> end_points{switch_a_endpoint};
+  std::vector<std::string> peer_endpoints{switch_b_endpoint_inter};
+  std::thread switch_server_a_thread(StartSwitchServer,
+                                     std::ref(switch_server_ptr_a), end_points,
+                                     peer_endpoints);
+  switch_server_ptr_a->WaitServerReady();
+
+  std::shared_ptr<distributed::HeterServer> switch_server_ptr_b =
+      std::make_shared<distributed::HeterServer>();
+  end_points = {switch_b_endpoint, switch_b_endpoint_inter};
+  peer_endpoints = {};
+  std::thread switch_server_b_thread(StartSwitchServer,
+                                     std::ref(switch_server_ptr_b), end_points,
+                                     peer_endpoints);
+  switch_server_ptr_b->WaitServerReady();
+
+  end_points = {switch_b_endpoint, switch_b_endpoint_inter};
+  peer_endpoints = {};
+  std::thread switch_server_b_thread_inter(StartSwitchInterServer,
+                                           std::ref(switch_server_ptr_b),
+                                           end_points, peer_endpoints);
+  switch_server_ptr_b->WaitServerReady();
+
+  // get a HeterClient instance
+  std::shared_ptr<distributed::HeterClient> heter_client_ptr_ =
+      distributed::HeterClient::GetInstance(
+          {switch_a_endpoint, switch_b_endpoint}, {}, 0);
+
+  platform::CPUPlace place;
+  platform::CPUDeviceContext ctx(place);
+  framework::Executor exe(place);
+
+  framework::ProgramDesc program;
+  exe.Prepare(program, 0);  // solve undefined symbol: tensor_table.cc
+  std::shared_ptr<framework::Scope> send_scope_ptr =
+      std::make_shared<framework::Scope>();
+  int64_t rows_numel = 10;
+  InitTensorsOnClient(send_scope_ptr.get(), &place, rows_numel);
+  LOG(INFO) << "InitTensorsOnClient done";
+
+  auto send_async = [&]() -> void {
+    /*
+    // std::string message_name =
+    //     std::to_string(distributed::PS_SAVE_WITH_SCOPE);
+    std::string message_name = "send and save";
+    std::vector<std::string> send_var_names{"w", "x"};
+    int ret = heter_client_ptr_->Send(ctx, *send_scope_ptr, message_name,
+                                      send_var_names);
+    if (!ret) {
+      LOG(ERROR) << ">>>> worker send success";
+    }
+    */
+    ///*
+    std::vector<int64_t> vars_len{2, 4};
+    std::vector<float> values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+    int64_t data_size = 6;
+    std::vector<std::string> send_var_names{"w", "x"};
+    int group_id = 0;
+    int ret = heter_client_ptr_->Send(group_id, send_var_names, vars_len,
+                                      values.data(), data_size);
+    if (!ret) {
+      LOG(INFO) << ">>>> worker send success";
+    }
+    //*/
+  };
+  std::thread send_thread(send_async);
+  /*
+  std::string message_name = std::to_string(distributed::PS_QUERY_WITH_SCOPE);
+  std::vector<std::string> recv_var_names{"w", "x"};
+  std::shared_ptr<framework::Scope> recv_scope_ptr =
+      std::make_shared<framework::Scope>();
+  int ret = heter_client_ptr_->Recv(ctx, *recv_scope_ptr, message_name,
+                                    recv_var_names);
+  if (!ret && recv_scope_ptr->FindVar("w") && recv_scope_ptr->FindVar("x")) {
+    LOG(INFO) << ">>>> worker recv success";
+  } else {
+    LOG(INFO) << "worker recv failed";
+  }
+  */
+  ///*
+  int group_id = 0;
+  std::vector<std::string> recv_var_names{"w", "x"};
+  std::vector<float> values;
+  int data_size = 6;
+  values.resize(data_size);
+  int ret = heter_client_ptr_->Recv(group_id, recv_var_names, values.data(),
+                                    data_size);
+  if (!ret) {
+    VLOG(4) << "queried data is: ";
+    for (auto f : values) {
+      VLOG(4) << f << " ";
+    }
+    LOG(INFO) << ">>>> worker recv success";
+  }
+  //*/
+
+  send_thread.join();
+
+  switch_server_ptr_a->Stop();
+  LOG(INFO) << "switch server A stopped";
+
+  switch_server_ptr_b->Stop();
+  LOG(INFO) << "switch server B stopped";
+
+  switch_server_a_thread.join();
+  LOG(INFO) << "switch_server_a_thread joined";
+
+  switch_server_b_thread.join();
+  LOG(INFO) << "switch_server_b_thread joined";
+
+  switch_server_b_thread_inter.join();
+  LOG(INFO) << "switch_server_b_thread_inter joined";
+}
+#endif
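Distilled from the test above, and illustrative only: the raw-buffer round trip a worker performs once the two switch servers are up. The endpoints and the trailing 0 passed to GetInstance simply mirror the test; the Send/Recv overloads are the group-id/byte-buffer ones exercised there.

#include <cstdint>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/ps/service/heter_client.h"

// Sketch of the worker-side round trip, assuming the switch endpoints
// started by the test above are already serving.
void WorkerRoundTrip() {
  namespace distributed = paddle::distributed;
  auto client = distributed::HeterClient::GetInstance(
      {"127.0.0.1:6000", "127.0.0.1:7000"}, {}, 0);  // endpoints as in the test

  std::vector<std::string> var_names{"w", "x"};
  std::vector<int64_t> vars_len{2, 4};  // per-variable lengths, summing to 6
  std::vector<float> send_buf{1, 2, 3, 4, 5, 6};
  int ret = client->Send(/*group_id=*/0, var_names, vars_len, send_buf.data(),
                         /*data_size=*/6);
  if (ret != 0) return;  // the test treats a non-zero return as failure

  std::vector<float> recv_buf(6);
  client->Recv(/*group_id=*/0, var_names, recv_buf.data(), /*data_size=*/6);
}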
diff --git a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.cc b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.cc
index 2c443e8c63..2df0d7526a 100644
--- a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.cc
+++ b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.cc
@@ -88,21 +88,20 @@ void HeterListenAndServOp::RunAsyncLoop(framework::ProgramDesc *program) const {
   for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
     block_list.push_back(blkid);
   }
-
   for (size_t i = 0; i < block_list.size(); ++i) {
     auto blkid = block_list[i];
     auto it = message_to_block_id.find_value(blkid);
-    rpc_service_->RegisterServiceHandler(
+    heter_server_->RegisterServiceHandler(
         it->first, [&](const MultiVarMsg *request, MultiVarMsg *response,
                        brpc::Controller *cntl) -> int {
-          return request_send_and_recv_handler_->Handle(request, response,
-                                                        cntl);
+          return send_and_recv_variable_handler_->Handle(request, response,
+                                                         cntl);
         });
   }

   while (true) {
-    if (rpc_service_->IsExit() || rpc_service_->IsStop()) {
-      rpc_service_->Stop();
+    if (heter_server_->IsExit() || heter_server_->IsStop()) {
+      heter_server_->Stop();
       VLOG(0) << "get exit. rpc_processor stop!";
       break;
     }
@@ -110,8 +109,9 @@ void HeterListenAndServOp::RunAsyncLoop(framework::ProgramDesc *program) const {
   }  // while(true)
 }

-void RunServer(std::shared_ptr<HeterServer> service) {
-  service->StartHeterService();
+void RunServer(
+    std::shared_ptr<distributed::HeterServer> heter_server_ptr) {
+  heter_server_ptr->StartHeterService();
 }

 void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
@@ -126,16 +126,16 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
   auto fan_in = Attr<int>("fanin");
   auto inputs = Inputs("X");

-  PADDLE_ENFORCE_EQ(rpc_service_, nullptr,
+  PADDLE_ENFORCE_EQ(heter_server_, nullptr,
                     platform::errors::PreconditionNotMet(
                         "RPC service has been created unexpectedly."));
   std::string endpoint = Attr<std::string>("endpoint");
   VLOG(4) << "pserver_id: " << pserver_id << ", end_point:" << endpoint;

-  rpc_service_ = distributed::HeterServer::GetInstance();
-  rpc_service_->SetEndPoint(endpoint);
-  rpc_service_->SetFanin(fan_in);
+  heter_server_ = distributed::HeterServer::GetInstance();
+  heter_server_->SetEndPoint(endpoint);
+  heter_server_->SetFanin(fan_in);

   auto optimize_blocks =
       Attr<std::vector<framework::BlockDesc *>>("optimize_blocks");
@@ -146,20 +146,18 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
   auto *program = optimize_blocks[0]->Program();

-  request_send_and_recv_handler_.reset(
-      new distributed::RequestSendAndRecvHandler());
-  request_send_and_recv_handler_->SetScope(&scope);
-  request_send_and_recv_handler_->SetDevCtx(&dev_ctx);
-  rpc_service_->SetRequestHandler(request_send_and_recv_handler_);
+  send_and_recv_variable_handler_.reset(
+      new distributed::SendAndRecvVariableHandler());
+  send_and_recv_variable_handler_->SetScope(&scope);
+  send_and_recv_variable_handler_->SetDevCtx(&dev_ctx);
+  heter_server_->SetServiceHandler(send_and_recv_variable_handler_);

   VLOG(2) << "RunAsyncLoop";
-  auto message_to_block_id_str =
-      Attr<std::vector<std::string>>("message_to_block_id");

   // start the server listening after all member initialized.
-  server_thread_.reset(new std::thread(RunServer, rpc_service_));
+  server_thread_.reset(new std::thread(RunServer, heter_server_));
   VLOG(3) << "wait server thread to become ready...";
-  rpc_service_->WaitServerReady();
+  heter_server_->WaitServerReady();
   RunAsyncLoop(program);
   VLOG(3) << "Wait for Server_thread_ stop";
   (server_thread_.get())->join();
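The registration pattern above, a message name dispatched to a lambda that forwards to a handler, is the server's entire routing surface. Below is a minimal stand-alone sketch; the wiring is hypothetical and uses only methods visible in this patch (a real handler also needs SetScope/SetDevCtx before Handle can run, as the tests below show).

#include <memory>
#include <thread>

#include "paddle/fluid/distributed/ps/service/heter_server.h"

// Sketch: stand up a HeterServer and route one message name to a handler.
void StartForwardService() {
  auto server = paddle::distributed::HeterServer::GetInstance();
  server->SetEndPoint("127.0.0.1:8500");  // illustrative endpoint

  auto handler =
      std::make_shared<paddle::distributed::SendAndRecvVariableHandler>();
  server->SetServiceHandler(handler);

  server->RegisterServiceHandler(
      "forward",  // the message_name clients send
      [handler](const paddle::distributed::MultiVariableMessage* request,
                paddle::distributed::MultiVariableMessage* response,
                brpc::Controller* cntl) -> int {
        return handler->Handle(request, response, cntl);
      });

  // Serve on a background thread until Stop() is called elsewhere.
  std::thread server_thread([server]() { server->StartHeterService(); });
  server->WaitServerReady();
  server_thread.detach();
}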
diff --git a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
old mode 100644
new mode 100755
index 2d2d8abe70..3ecff083b0
--- a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
+++ b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
@@ -34,7 +34,7 @@ limitations under the License. */
 namespace paddle {
 namespace distributed {

-class HeterRequestHandler;
+class ServiceHandlerBase;
 class HeterServer;
 }  // namespace distributed
 }  // namespace paddle
@@ -82,10 +82,10 @@ class HeterListenAndServOp : public framework::OperatorBase {
               const platform::Place& dev_place) const override;

  protected:
-  mutable std::shared_ptr<distributed::HeterServer> rpc_service_;
+  mutable std::shared_ptr<distributed::HeterServer> heter_server_;
   mutable std::shared_ptr<std::thread> server_thread_;
-  mutable std::shared_ptr<distributed::RequestSendAndRecvHandler>
-      request_send_and_recv_handler_;
+  mutable std::shared_ptr<distributed::SendAndRecvVariableHandler>
+      send_and_recv_variable_handler_;
 };

 }  // namespace operators

diff --git a/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc b/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc
index b024fe76b0..ab2fcba510 100644
--- a/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc
+++ b/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc
@@ -142,7 +142,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
   CreateVarsOnScope(scope, place);
 }

-void StartHeterServer(std::string endpoint) {
+void RunHeterServerOp(std::string endpoint) {
   framework::ProgramDesc program;
   framework::Scope scope;
   platform::CPUPlace place;
@@ -167,10 +167,10 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
   std::string previous_endpoint = endpoint;
   LOG(INFO) << "before StartSendAndRecvServer";
   FLAGS_eager_delete_tensor_gb = -1;
-  std::thread server_thread(StartHeterServer, endpoint);
+  std::thread server_thread(RunHeterServerOp, endpoint);
   sleep(1);
-  auto b_rpc_service = distributed::HeterServer::GetInstance();
-  b_rpc_service->WaitServerReady();
+  auto heter_server_ptr_ = distributed::HeterServer::GetInstance();
+  heter_server_ptr_->WaitServerReady();
   using MicroScope =
       std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
   using MiniScope = std::unordered_map<int, framework::Scope*>;
@@ -185,8 +185,8 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
   (*micro_scope).push_back(micro_scope_0);
   (*micro_scope).push_back(micro_scope_1);
   (*micro_scopes)[0] = micro_scope;
-  b_rpc_service->SetMicroBatchScopes(micro_scopes);
-  b_rpc_service->SetMiniBatchScopes(mini_scopes);
+  heter_server_ptr_->SetMicroBatchScopes(micro_scopes);
+  heter_server_ptr_->SetMiniBatchScopes(mini_scopes);

   using TaskQueue =
       std::unordered_map<int,
                          std::shared_ptr<::paddle::framework::BlockingQueue<
                              std::pair<std::string, int>>>>;
   auto task_queue_ = std::make_shared<TaskQueue>();
-  b_rpc_service->SetTaskQueue(task_queue_);
+  heter_server_ptr_->SetTaskQueue(task_queue_);

   LOG(INFO) << "before HeterClient::GetInstance";
-  distributed::HeterClient* rpc_client =
+  distributed::HeterClient* heter_client_ptr_ =
       distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
           .get();

-  PADDLE_ENFORCE_NE(rpc_client, nullptr,
-                    platform::errors::InvalidArgument(
-                        "Client Start Fail, Check Your Code & Env"));
-
   framework::Scope* scope = (*micro_scope)[0];
   platform::CPUPlace place;
   platform::CPUDeviceContext ctx(place);
@@ -224,8 +220,8 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
   std::vector<std::string> recv_var = {};

   LOG(INFO) << "before SendAndRecvAsync";
-  rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var,
-                               "forward");
+  heter_client_ptr_->SendAndRecvAsync(ctx, *scope, in_var_name, send_var,
+                                      recv_var, "forward");
   auto task = (*task_queue_)[0]->Pop();
   PADDLE_ENFORCE_EQ(
       task.first, "x",
@@ -234,15 +230,15 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
   InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel);
   LOG(INFO) << "before SendAndRecvAsync 2";
-  rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name, send_var,
-                               recv_var, "backward");
+  heter_client_ptr_->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name,
+                                      send_var, recv_var, "backward");
   auto task2 = (*task_queue_)[0]->Pop();
   PADDLE_ENFORCE_EQ(
       task2.first, "x",
       platform::errors::InvalidArgument(
           "Recv message and Send message name not match, Check your Code"));

-  rpc_client->Stop();
+  heter_client_ptr_->Stop();
   LOG(INFO) << "end server Stop";
   server_thread.join();
   LOG(INFO) << "end server thread join";
diff --git a/paddle/fluid/operators/pscore/heter_server_test.cc b/paddle/fluid/operators/pscore/heter_server_test.cc
index 6ab4204b2f..d4ee00d10a 100644
--- a/paddle/fluid/operators/pscore/heter_server_test.cc
+++ b/paddle/fluid/operators/pscore/heter_server_test.cc
@@ -34,8 +34,6 @@ using VarMsg = ::paddle::distributed::VariableMessage;

 USE_OP_ITSELF(scale);

-std::shared_ptr<distributed::HeterServer> b_rpc_service;
-
 std::string get_ip_port() {
   std::mt19937 rng;
   rng.seed(std::random_device()());
@@ -171,31 +169,32 @@ void StartSendAndRecvServer(std::string endpoint) {
   InitTensorsOnServer(&scope, &place, 10);
   LOG(INFO) << "end InitTensorsOnServer";

-  std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
-  b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
+  std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
+  b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
   LOG(INFO) << "before SetDevCtx";
   b_req_handler->SetDevCtx(&ctx);
   LOG(INFO) << "before SetScope";
   b_req_handler->SetScope(&scope);
   LOG(INFO) << "before HeterServer::GetInstance";
-  b_rpc_service = distributed::HeterServer::GetInstance();
-  b_rpc_service->SetEndPoint(endpoint);
+  std::shared_ptr<distributed::HeterServer> heter_server_ptr_ =
+      distributed::HeterServer::GetInstance();
+  heter_server_ptr_->SetEndPoint(endpoint);
   LOG(INFO) << "before HeterServer::RegisterServiceHandler";
-  b_rpc_service->RegisterServiceHandler(
+  heter_server_ptr_->RegisterServiceHandler(
       in_var_name, [&](const MultiVarMsg* request, MultiVarMsg* response,
                        brpc::Controller* cntl) -> int {
         return b_req_handler->Handle(request, response, cntl);
       });
-  b_rpc_service->RegisterServiceHandler(
+  heter_server_ptr_->RegisterServiceHandler(
       in_var_name2, [&](const MultiVarMsg* request, MultiVarMsg* response,
                         brpc::Controller* cntl) -> int {
         return b_req_handler->Handle(request, response, cntl);
       });

-  b_rpc_service->SetRequestHandler(b_req_handler);
+  heter_server_ptr_->SetServiceHandler(b_req_handler);
   LOG(INFO) << "before HeterServer::RunServer";
-  RunServer(b_rpc_service);
-  // std::thread server_thread(std::bind(RunServer, b_rpc_service));
+  RunServer(heter_server_ptr_);
+  // std::thread server_thread(std::bind(RunServer, heter_server_ptr_));
   // server_thread.join();
 }
@@ -206,9 +205,10 @@ TEST(SENDANDRECV, CPU) {
   std::string endpoint = get_ip_port();
   std::string previous_endpoint = endpoint;
   LOG(INFO) << "before StartSendAndRecvServer";
-  b_rpc_service = distributed::HeterServer::GetInstance();
+  std::shared_ptr<distributed::HeterServer> heter_server_ptr_ =
+      distributed::HeterServer::GetInstance();
   std::thread server_thread(StartSendAndRecvServer, endpoint);
-  b_rpc_service->WaitServerReady();
+  heter_server_ptr_->WaitServerReady();
   using MicroScope =
       std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
   using MiniScope = std::unordered_map<int, framework::Scope*>;
@@ -223,8 +223,8 @@ TEST(SENDANDRECV, CPU) {
   (*micro_scope).push_back(micro_scope_0);
   (*micro_scope).push_back(micro_scope_1);
   (*micro_scopes)[0] = micro_scope;
-  b_rpc_service->SetMicroBatchScopes(micro_scopes);
-  b_rpc_service->SetMiniBatchScopes(mini_scopes);
+  heter_server_ptr_->SetMicroBatchScopes(micro_scopes);
+  heter_server_ptr_->SetMiniBatchScopes(mini_scopes);

   using TaskQueue =
       std::unordered_map<int,
                          std::shared_ptr<::paddle::framework::BlockingQueue<
                              std::pair<std::string, int>>>>;
   auto task_queue_ = std::make_shared<TaskQueue>();
-  b_rpc_service->SetTaskQueue(task_queue_);
+  heter_server_ptr_->SetTaskQueue(task_queue_);

   LOG(INFO) << "before HeterClient::GetInstance";
-  distributed::HeterClient* rpc_client =
+  distributed::HeterClient* heter_client_ptr_ =
       distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
           .get();

-  PADDLE_ENFORCE_NE(rpc_client, nullptr,
-                    platform::errors::InvalidArgument(
-                        "Client Start Fail, Check Your Code & Env"));
-
   framework::Scope* scope = (*micro_scope)[0];
   platform::CPUPlace place;
   platform::CPUDeviceContext ctx(place);
@@ -262,8 +258,8 @@ TEST(SENDANDRECV, CPU) {
   std::vector<std::string> recv_var = {};

   LOG(INFO) << "before SendAndRecvAsync";
-  rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var,
-                               "forward");
+  heter_client_ptr_->SendAndRecvAsync(ctx, *scope, in_var_name, send_var,
+                                      recv_var, "forward");

   LOG(INFO) << "client wait for Pop";
   auto task = (*task_queue_)[0]->Pop();
@@ -276,8 +272,8 @@ TEST(SENDANDRECV, CPU) {
   InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel);
   LOG(INFO) << "before SendAndRecvAsync 2";
   std::string in_var_name2("y");
-  rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name2,
-                               send_var, recv_var, "backward");
+  heter_client_ptr_->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name2,
+                                      send_var, recv_var, "backward");
   LOG(INFO) << "after SendAndRecvAsync 2";

   auto task2 = (*task_queue_)[0]->Pop();
@@ -286,8 +282,7 @@ TEST(SENDANDRECV, CPU) {
       platform::errors::InvalidArgument(
           "Recv message and Send message name not match, Check your Code"));

-  rpc_client->FinalizeWorker();
-  b_rpc_service->Stop();
+  heter_server_ptr_->Stop();
   LOG(INFO) << "end server Stop";
   server_thread.join();
   LOG(INFO) << "end server thread join";
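The same server-side wiring recurs in each test: per-micro-batch scopes plus a blocking task queue that the service handler feeds. A condensed sketch, assuming only the aliases and setters the tests themselves use:

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/ps/service/heter_server.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/scope.h"

namespace framework = paddle::framework;
namespace distributed = paddle::distributed;

// Aliases mirror the `using` declarations in the tests.
using MicroScope =
    std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
using TaskQueue =
    std::unordered_map<int, std::shared_ptr<framework::BlockingQueue<
                                std::pair<std::string, int>>>>;

void WireServerState() {
  auto server = distributed::HeterServer::GetInstance();
  auto micro_scopes = std::make_shared<MicroScope>();
  auto task_queue = std::make_shared<TaskQueue>();
  (*task_queue)[0] = std::make_shared<
      framework::BlockingQueue<std::pair<std::string, int>>>();
  server->SetMicroBatchScopes(micro_scopes);
  server->SetTaskQueue(task_queue);
  // A consumer then blocks on (*task_queue)[0]->Pop(), which yields the
  // <variable name, micro-batch id> pair pushed by the service handler.
}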
diff --git a/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc b/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc
old mode 100644
new mode 100755
index 26da0d3696..7c25d38d1e
--- a/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc
+++ b/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc
@@ -36,8 +36,6 @@ using VarMsg = ::paddle::distributed::VariableMessage;
 USE_OP_ITSELF(scale);
 USE_OP(send_and_recv);

-std::shared_ptr<distributed::HeterServer> b_rpc_service;
-
 std::string get_ip_port() {
   std::mt19937 rng;
   rng.seed(std::random_device()());
@@ -148,14 +146,15 @@ void StartSendAndRecvServer(std::string endpoint) {
   InitTensorsOnServer(&scope, &place, 10);
   LOG(INFO) << "end InitTensorsOnServer";

-  std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
-  b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
+  std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
+  b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
   LOG(INFO) << "before SetDevCtx";
   b_req_handler->SetDevCtx(&ctx);
   LOG(INFO) << "before SetScope";
   b_req_handler->SetScope(&scope);
   LOG(INFO) << "before HeterServer::GetInstance";
-  b_rpc_service = distributed::HeterServer::GetInstance();
+  std::shared_ptr<distributed::HeterServer> b_rpc_service =
+      distributed::HeterServer::GetInstance();
   b_rpc_service->SetEndPoint(endpoint);
   LOG(INFO) << "before HeterServer::RegisterServiceHandler";
   b_rpc_service->RegisterServiceHandler(
@@ -164,7 +163,7 @@ void StartSendAndRecvServer(std::string endpoint) {
         return b_req_handler->Handle(request, response, cntl);
       });

-  b_rpc_service->SetRequestHandler(b_req_handler);
+  b_rpc_service->SetServiceHandler(b_req_handler);
   LOG(INFO) << "before HeterServer::RunServer";
   RunServer(b_rpc_service);

@@ -179,7 +178,8 @@ TEST(SENDANDRECV, CPU) {
   std::string endpoint = get_ip_port();
   std::string previous_endpoint = endpoint;
   LOG(INFO) << "before StartSendAndRecvServer";
-  b_rpc_service = distributed::HeterServer::GetInstance();
+  std::shared_ptr<distributed::HeterServer> b_rpc_service =
+      distributed::HeterServer::GetInstance();
   std::thread server_thread(StartSendAndRecvServer, endpoint);
   b_rpc_service->WaitServerReady();
   using MicroScope =
@@ -292,7 +292,6 @@ TEST(SENDANDRECV, CPU) {
       platform::errors::InvalidArgument(
           "Recv message and Send message name not match, Check your Code"));

-  rpc_client->FinalizeWorker();
   b_rpc_service->Stop();
   LOG(INFO) << "end server Stop";
   server_thread.join();

diff --git a/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc b/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc
index a5e292a05e..4054846460 100644
--- a/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc
+++ b/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc
@@ -167,8 +167,8 @@ void StartSendAndRecvServer(std::string endpoint) {
   InitTensorsOnServer(&scope, &place, 10);
   LOG(INFO) << "end InitTensorsOnServer";

-  std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
-  b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
+  std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
+  b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
   LOG(INFO) << "before SetDevCtx";
   b_req_handler->SetDevCtx(&ctx);
   LOG(INFO) << "before SetScope";
@@ -183,7 +183,7 @@ void StartSendAndRecvServer(std::string endpoint) {
         return b_req_handler->Handle(request, response, cntl);
       });

-  b_rpc_service2->SetRequestHandler(b_req_handler);
+  b_rpc_service2->SetServiceHandler(b_req_handler);
   LOG(INFO) << "before HeterServer::RunServer";
   RunServer(b_rpc_service2);

@@ -228,13 +228,11 @@ TEST(SENDANDRECV, GPU) {
   b_rpc_service2->SetTaskQueue(task_queue_);

   LOG(INFO) << "before HeterClient::GetInstance";
-  distributed::HeterClient* rpc_client =
-      distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
-          .get();
-
-  PADDLE_ENFORCE_NE(rpc_client, nullptr,
-                    platform::errors::InvalidArgument(
-                        "Client Start Fail, Check Your Code & Env"));
+  std::shared_ptr<distributed::HeterClient> heter_client_ptr_ =
+      distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0);
+  if (heter_client_ptr_ == nullptr) {
+    LOG(ERROR) << "heter_client_ptr_ is null";
+  }

   framework::Scope* scope = (*micro_scope)[0];
   platform::CUDAPlace place;
@@ -316,7 +314,6 @@ TEST(SENDANDRECV, GPU) {
       platform::errors::InvalidArgument(
           "Recv message and Send message name not match, Check your Code"));

-  rpc_client->FinalizeWorker();
   b_rpc_service2->Stop();
   LOG(INFO) << "end server Stop";
   server_thread.join();

diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py
index f075439e54..5088ad3457 100755
--- a/tools/parallel_UT_rule.py
+++ b/tools/parallel_UT_rule.py
@@ -1174,6 +1174,7 @@ SIXTH_PARALLEL_JOB_NEW = [
 ]

 LOWEST_PARALLEL_JOB_NEW = [
+    'heter_cloud_comm_cpu_test',
     'heter_server_test',
     'test_scatter_op',
     'test_trt_convert_hard_sigmoid',
-- 
GitLab