From 79a1a7cda0d45df31f6d4fa60bdc3f4e206e31bc Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 18 Apr 2018 17:48:53 +0800 Subject: [PATCH] init async gprc server --- paddle/fluid/operators/CMakeLists.txt | 2 + .../operators/async_listen_and_serv_op.cc | 217 ++++++++++++ .../operators/async_listen_and_serv_op.h | 55 ++++ paddle/fluid/operators/detail/CMakeLists.txt | 2 +- .../operators/detail/async_grpc_server.cc | 311 ++++++++++++++++++ .../operators/detail/async_grpc_server.h | 111 +++++++ 6 files changed, 697 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/async_listen_and_serv_op.cc create mode 100644 paddle/fluid/operators/async_listen_and_serv_op.h create mode 100644 paddle/fluid/operators/detail/async_grpc_server.cc create mode 100644 paddle/fluid/operators/detail/async_grpc_server.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 7d6781c2c38..f723652a8fa 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -189,6 +189,8 @@ if(WITH_DISTRIBUTE) set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + op_library(async_listen_and_serv_op DEPS ${DISTRIBUTE_DEPS}) + set_source_files_properties(async_listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS}) diff --git a/paddle/fluid/operators/async_listen_and_serv_op.cc b/paddle/fluid/operators/async_listen_and_serv_op.cc new file mode 100644 index 00000000000..4d66f9853a9 --- /dev/null +++ b/paddle/fluid/operators/async_listen_and_serv_op.cc @@ -0,0 +1,217 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include // NOLINT +#include + +#include "paddle/fluid/operators/async_listen_and_serv_op.h" + +namespace paddle { +namespace operators { + +void RunServer(std::shared_ptr service) { + service->RunAsyncUpdate(); + VLOG(4) << "RunServer thread end"; +} + +static void CreateTensorFromMessageType(framework::Variable *var, + sendrecv::VarType var_type) { + if (var_type == sendrecv::VarType::LOD_TENSOR) { + var->GetMutable(); + } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { + var->GetMutable(); + } else { + PADDLE_THROW( + "VariableMessage type %d is not in " + "[LoDTensor, SelectedRows]", + var_type); + } +} + +static void ParallelExecuteBlocks( + const std::vector ¶llel_blkids, framework::Executor *executor, + const std::vector> + &prepared, + framework::ProgramDesc *program, framework::Scope *scope) { + std::vector> fs; + for (size_t idx : parallel_blkids) { + fs.push_back( + framework::Async([&executor, &prepared, &program, &scope, idx]() { + int run_block = idx; // thread local + try { + executor->RunPreparedContext(prepared[run_block].get(), scope, + false, false); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + })); + } + for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); +} + +static void AsyncExecuteBlock( + size_t block_id, framework::Executor *executor, + std::shared_ptr ctx, + framework::ProgramDesc *program, framework::Scope *scope) {} + +AsyncListenAndServOp::AsyncListenAndServOp( + const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + +int AsyncListenAndServOp::GetSelectedPort() const { + return rpc_service_->GetSelectedPort(); +} + +void AsyncListenAndServOp::Stop() { + rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); + server_thread_->join(); +} + +void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); + framework::Scope &recv_scope = scope.NewScope(); + + if (!rpc_service_) { + std::string endpoint = Attr("endpoint"); + rpc_service_.reset(new detail::SyncGRPCServer(endpoint)); + } + + auto *optimize_block = Attr(kOptimizeBlock); + auto *prefetch_block = Attr(kPrefetchBlock); + auto *program = optimize_block->Program(); + size_t num_blocks = program->Size(); + PADDLE_ENFORCE_GE(num_blocks, 2, + "server program should have at least 2 blocks"); + + framework::Executor executor(dev_place); + std::vector block_list; + for (size_t blkid = 1; blkid < num_blocks; ++blkid) { + if (blkid != static_cast(prefetch_block->ID())) { + block_list.push_back(blkid); + } + } + auto optimize_prepared = executor.Prepare(*program, block_list); + // Insert placeholder for block0 which holds current op itself. + optimize_prepared.insert( + optimize_prepared.begin(), + std::shared_ptr(nullptr)); + + rpc_service_->SetScope(&recv_scope); + rpc_service_->SetDevCtx(&dev_ctx); + // TODO(qiao) set proper fields for table lookup and update + rpc_service_->SetExecutor(&executor); + VLOG(3) << "prefetch block id is " << prefetch_block->ID(); + auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); + rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get()); + prefetch_prepared.release(); + rpc_service_->SetProgram(program); + // start the server listening after all member initialized. + server_thread_.reset(new std::thread(RunServer, rpc_service_)); + VLOG(3) << "wait server thread to become ready..."; + sleep(5); + // Write to a file of server selected port for python use. + std::ofstream port_file; + port_file.open("/tmp/paddle.selected_port"); + port_file << rpc_service_->GetSelectedPort(); + port_file.close(); + + bool exit_flag = false; + // Record received sparse variables, so that + // we could reset those after execute optimize program + std::vector sparse_vars; + while (!exit_flag) { + const detail::ReceivedMessage v = rpc_service_->Get(); + auto recv_var_name = v.first; + if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { + LOG(INFO) << "received terminate message and exit"; + exit_flag = true; + break; + } else { + VLOG(3) << "received grad: " << recv_var_name; + auto var = v.second->GetVar(); + if (var == nullptr) { + LOG(ERROR) << "Can not find server side var: " << recv_var_name; + PADDLE_THROW("Can not find server side var"); + } + if (var->IsType()) { + sparse_vars.push_back(var); + } + AsyncExecuteBlock(); + } + + if (exit_flag) { + rpc_service_->ShutDown(); + break; + } + + // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads + // and this will still work. + + // The optimize blocks which have the same parent ID would run parallel + // TODO(Yancey1989): need to use ParallelExecutor for future + int32_t last_parent_blkid = program->Block(1).Parent(); + + VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; + + // Reset the received sparse variables, the sum operator would not + // sum the input sparse variables which rows is empty at the next + // mini-batch. + // TODO(Yancey1989): move the reset action into an operator, we couldn't + // have any hide logic in the operator. + for (auto &var : sparse_vars) { + var->GetMutable()->mutable_rows()->clear(); + } + // FIXME(typhoonzero): use another condition to sync wait clients get. + sparse_vars.clear(); + } // while(true) +} + +class AsyncListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AsyncListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable(); + AddComment(R"DOC( +ListenAndServ operator + +This operator will start a RPC server which can receive variables +from send_op and send back variables to recv_op. +)DOC"); + AddAttr("endpoint", + "(string, default 127.0.0.1:6164)" + "IP address to listen on.") + .SetDefault("127.0.0.1:6164") + .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); + AddAttr(kOptimizeBlock, + "BlockID to run on server side."); + AddAttr(kPrefetchBlock, + "prefetch block to run on server side."); + AddAttr("Fanin", "How many clients send to this server.") + .SetDefault(1); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(async_listen_and_serv, ops::AsyncListenAndServOp, + ops::AsyncListenAndServOpMaker); diff --git a/paddle/fluid/operators/async_listen_and_serv_op.h b/paddle/fluid/operators/async_listen_and_serv_op.h new file mode 100644 index 00000000000..9df351b929a --- /dev/null +++ b/paddle/fluid/operators/async_listen_and_serv_op.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/detail/async_grpc_server.h" + +namespace paddle { +namespace operators { + +constexpr char kOptimizeBlock[] = "OptimizeBlock"; +constexpr char kPrefetchBlock[] = "PrefetchBlock"; + +void RunServer(std::shared_ptr service); + +class AsyncListenAndServOp : public framework::OperatorBase { + public: + AsyncListenAndServOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs); + + int GetSelectedPort() const; + + void Stop() override; + + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override; + + protected: + mutable std::shared_ptr rpc_service_; + mutable std::shared_ptr server_thread_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index 719a7465b8d..059099ee0ab 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_DISTRIBUTE) grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc - grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) + grpc_server.cc async_grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr diff --git a/paddle/fluid/operators/detail/async_grpc_server.cc b/paddle/fluid/operators/detail/async_grpc_server.cc new file mode 100644 index 00000000000..7e45bc6b2f5 --- /dev/null +++ b/paddle/fluid/operators/detail/async_grpc_server.cc @@ -0,0 +1,311 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detail/async_grpc_server.h" + +#include +#include + +using ::grpc::ServerAsyncResponseWriter; + +namespace paddle { +namespace operators { +namespace detail { + +enum CallStatus { PROCESS = 0, FINISH }; + +// reference: +// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server +class RequestBase { + public: + explicit RequestBase(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + const platform::DeviceContext* dev_ctx) + : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) { + PADDLE_ENFORCE(cq_); + } + virtual ~RequestBase() {} + virtual void Process() { assert(false); } + + CallStatus Status() { return status_; } + void SetStatus(CallStatus status) { status_ = status; } + virtual std::string GetReqName() { + assert(false); + return ""; + } + + protected: + ::grpc::ServerContext ctx_; + GrpcService::AsyncService* service_; + ::grpc::ServerCompletionQueue* cq_; + CallStatus status_; + const platform::DeviceContext* dev_ctx_; +}; + +class RequestSend final : public RequestBase { + public: + explicit RequestSend(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + framework::Scope* scope, ReceivedQueue* queue, + const platform::DeviceContext* dev_ctx) + : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { + request_.reset(new VariableResponse(scope, dev_ctx_)); + int method_id = static_cast(detail::GrpcMethod::kSendVariable); + service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, + cq_, cq_, this); + } + + virtual ~RequestSend() {} + + virtual std::string GetReqName() { return request_->Varname(); } + + virtual void Process() { + queue_->Push(std::make_pair(request_->Varname(), request_)); + + sendrecv::VoidMessage reply; + responder_.Finish(reply, ::grpc::Status::OK, this); + status_ = FINISH; + } + + protected: + std::shared_ptr request_; + ReceivedQueue* queue_; + ServerAsyncResponseWriter responder_; +}; + +class RequestGet final : public RequestBase { + public: + explicit RequestGet(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + framework::Scope* scope, + const platform::DeviceContext* dev_ctx) + : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope) { + int method_id = static_cast(detail::GrpcMethod::kGetVariable); + service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, + cq_, this); + } + + virtual ~RequestGet() {} + + virtual std::string GetReqName() { return request_.varname(); } + + virtual void Process() { + // proc request. + std::string var_name = request_.varname(); + auto* var = scope_->FindVar(var_name); + + ::grpc::ByteBuffer reply; + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); + + responder_.Finish(reply, ::grpc::Status::OK, this); + status_ = FINISH; + } + + protected: + sendrecv::VariableMessage request_; + ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; + framework::Scope* scope_; +}; + +class RequestPrefetch final : public RequestBase { + public: + explicit RequestPrefetch(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + framework::Scope* scope, + const platform::DeviceContext* dev_ctx, + framework::Executor* executor, + framework::ProgramDesc* program, + framework::ExecutorPrepareContext* prefetch_ctx) + : RequestBase(service, cq, dev_ctx), + responder_(&ctx_), + scope_(scope), + executor_(executor), + program_(program), + prefetch_ctx_(prefetch_ctx) { + request_.reset(new VariableResponse(scope, dev_ctx_)); + int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); + service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, + cq_, cq_, this); + } + + virtual ~RequestPrefetch() {} + + virtual std::string GetReqName() { return request_->Varname(); } + + virtual void Process() { + // prefetch process... + ::grpc::ByteBuffer reply; + + std::string var_name = request_->OutVarname(); + VLOG(3) << "prefetch var " << var_name; + auto var_desc = program_->Block(0).FindVar(var_name); + framework::Scope* local_scope = &scope_->NewScope(); + auto* var = local_scope->FindVar(var_name); + InitializeVariable(var, var_desc->GetType()); + executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false); + + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); + + responder_.Finish(reply, ::grpc::Status::OK, this); + status_ = FINISH; + } + + protected: + std::shared_ptr request_; + ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; + framework::Scope* scope_; + framework::Executor* executor_; + framework::ProgramDesc* program_; + framework::ExecutorPrepareContext* prefetch_ctx_; +}; + +void SyncGRPCServer::RunAsyncUpdate() { + ::grpc::ServerBuilder builder; + builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(), + &selected_port_); + builder.SetMaxSendMessageSize(std::numeric_limits::max()); + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); + builder.RegisterService(&service_); + + cq_send_ = builder.AddCompletionQueue(); + cq_get_ = builder.AddCompletionQueue(); + cq_prefetch_ = builder.AddCompletionQueue(); + + server_ = builder.BuildAndStart(); + LOG(INFO) << "Server listening on " << address_ + << " selected port: " << selected_port_; + + std::function send_register = + std::bind(&SyncGRPCServer::TryToRegisterNewSendOne, this); + std::function get_register = + std::bind(&SyncGRPCServer::TryToRegisterNewGetOne, this); + std::function prefetch_register = + std::bind(&SyncGRPCServer::TryToRegisterNewPrefetchOne, this); + + // TODO(wuyi): Run these "HandleRequest" in thread pool + t_send_.reset( + new std::thread(std::bind(&SyncGRPCServer::HandleRequest, this, + cq_send_.get(), "cq_send", send_register))); + t_get_.reset( + new std::thread(std::bind(&SyncGRPCServer::HandleRequest, this, + cq_get_.get(), "cq_get", get_register))); + t_prefetch_.reset(new std::thread( + std::bind(&SyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), + "cq_prefetch", prefetch_register))); + // wait server + server_->Wait(); + t_send_->join(); + t_get_->join(); + t_prefetch_->join(); +} + +void SyncGRPCServer::ShutdownQueue() { + std::unique_lock lock(cq_mutex_); + cq_send_->Shutdown(); + cq_get_->Shutdown(); + cq_prefetch_->Shutdown(); +} + +// This URL explains why shutdown is complicate: +void SyncGRPCServer::ShutDown() { + is_shut_down_ = true; + ShutdownQueue(); + server_->Shutdown(); +} + +void SyncGRPCServer::TryToRegisterNewSendOne() { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; + return; + } + RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, + &var_recv_queue_, dev_ctx_); + VLOG(4) << "Create RequestSend status:" << send->Status(); +} + +void SyncGRPCServer::TryToRegisterNewGetOne() { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; + return; + } + RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_); + VLOG(4) << "Create RequestGet status:" << get->Status(); +} + +void SyncGRPCServer::TryToRegisterNewPrefetchOne() { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; + return; + } + RequestPrefetch* prefetch = + new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, + executor_, program_, prefetch_ctx_); + + VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); +} + +void SyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, + const std::string& cq_name, + std::function TryToRegisterNewOne) { + TryToRegisterNewOne(); + + void* tag = NULL; + bool ok = false; + + while (true) { + VLOG(3) << "HandleRequest for " << cq_name << " while in"; + if (!cq->Next(&tag, &ok)) { + LOG(INFO) << cq_name << " CompletionQueue shutdown!"; + break; + } + VLOG(3) << "HandleRequest for " << cq_name << " while after Next"; + + PADDLE_ENFORCE(tag); + + RequestBase* base = reinterpret_cast(tag); + // reference: + // https://github.com/tensorflow/tensorflow/issues/5596 + // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM + // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I + if (!ok) { + LOG(WARNING) << cq_name << " recv no regular event:argument name[" + << base->GetReqName() << "]"; + TryToRegisterNewOne(); + delete base; + continue; + } + + switch (base->Status()) { + case PROCESS: { + VLOG(4) << cq_name << " status:" << base->Status(); + TryToRegisterNewOne(); + base->Process(); + break; + } + case FINISH: { + VLOG(4) << cq_name << " status:" << base->Status(); + delete base; + break; + } + default: { assert(false); } + } + } +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/async_grpc_server.h b/paddle/fluid/operators/detail/async_grpc_server.h new file mode 100644 index 00000000000..870b684cfad --- /dev/null +++ b/paddle/fluid/operators/detail/async_grpc_server.h @@ -0,0 +1,111 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include // NOLINT +#include + +#include "grpc++/grpc++.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/detail/grpc_service.h" +#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" +#include "paddle/fluid/operators/detail/send_recv.pb.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/detail/simple_block_queue.h" + +namespace paddle { +namespace operators { +namespace detail { + +typedef std::pair> + ReceivedMessage; +typedef SimpleBlockQueue ReceivedQueue; + +typedef std::pair MessageWithName; +class RequestBase; + +class SyncGRPCServer final { + public: + explicit SyncGRPCServer(const std::string &address) : address_(address) {} + + void RunAsyncUpdate(); + + void SetScope(framework::Scope *scope) { scope_ = scope; } + + void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; } + + void SetProgram(framework::ProgramDesc *program) { program_ = program; } + + void SetExecutor(framework::Executor *executor) { executor_ = executor; } + + void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) { + prefetch_ctx_ = prepared; + } + + int GetSelectedPort() { return selected_port_; } + + const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } + + void Push(const std::string &msg_name) { + this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr)); + } + + void ShutDown(); + + protected: + void HandleRequest(::grpc::ServerCompletionQueue *cq, + const std::string &cq_name, + std::function TryToRegisterNewOne); + void TryToRegisterNewSendOne(); + void TryToRegisterNewGetOne(); + void TryToRegisterNewPrefetchOne(); + void ShutdownQueue(); + + private: + std::mutex cq_mutex_; + volatile bool is_shut_down_ = false; + std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; + std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; + std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_; + + GrpcService::AsyncService service_; + std::unique_ptr<::grpc::Server> server_; + + std::string address_; + framework::Scope *scope_; + const platform::DeviceContext *dev_ctx_; + + // client send variable to this queue. + ReceivedQueue var_recv_queue_; + + std::unique_ptr t_send_; + std::unique_ptr t_get_; + std::unique_ptr t_prefetch_; + + framework::ExecutorPrepareContext *prefetch_ctx_; + framework::ProgramDesc *program_; + framework::Executor *executor_; + int selected_port_; +}; + +}; // namespace detail +}; // namespace operators +}; // namespace paddle -- GitLab