diff --git a/paddle/fluid/operators/async_listen_and_serv_op.cc b/paddle/fluid/operators/async_listen_and_serv_op.cc deleted file mode 100644 index 093d44e2d1803a9ea4d33fca84f78c49451c2b11..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/async_listen_and_serv_op.cc +++ /dev/null @@ -1,214 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include // NOLINT -#include - -#include "paddle/fluid/operators/async_listen_and_serv_op.h" - -#include "paddle/utils/StringUtil.h" - -namespace paddle { -namespace operators { - -static void split(const std::string &str, char sep, - std::vector *pieces) { - pieces->clear(); - if (str.empty()) { - return; - } - size_t pos = 0; - size_t next = str.find(sep, pos); - while (next != std::string::npos) { - pieces->push_back(str.substr(pos, next - pos)); - pos = next + 1; - next = str.find(sep, pos); - } - if (!str.substr(pos).empty()) { - pieces->push_back(str.substr(pos)); - } -} - -void RunServer(std::shared_ptr service) { - service->RunAsyncUpdate(); - VLOG(4) << "RunServer thread end"; -} - -static void AsyncExecuteBlock(framework::Executor *executor, - framework::ExecutorPrepareContext *prepared, - framework::Scope *scope) { - framework::Async([&executor, &prepared, &scope]() { - try { - executor->RunPreparedContext(prepared, scope, false, false); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - }); -} - -AsyncListenAndServOp::AsyncListenAndServOp( - const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - -int AsyncListenAndServOp::GetSelectedPort() const { - return rpc_service_->GetSelectedPort(); -} - -void AsyncListenAndServOp::Stop() { - rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); - server_thread_->join(); -} - -void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - framework::Scope &recv_scope = scope.NewScope(); - - if (!rpc_service_) { - std::string endpoint = Attr("endpoint"); - rpc_service_.reset(new detail::SyncGRPCServer(endpoint)); - } - - // grad name to block id - std::unordered_map grad_to_id; - std::unordered_map id_to_grad; - - auto grad_map_str = Attr>("grad_to_id"); - for (auto &grad_and_id : grad_map_str) { - std::vector pieces; - split(grad_and_id, ' ', &pieces); - PADDLE_ENFORCE_EQ(pieces.size(), 2); - PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0); - int block_id = std::stoi(pieces[1]); - grad_to_id[pieces[0]] = block_id; - id_to_grad[block_id] = pieces[0]; - } - - auto *optimize_block = Attr(kOptimizeBlock); - auto *prefetch_block = Attr(kPrefetchBlock); - auto *program = optimize_block->Program(); - size_t num_blocks = program->Size(); - PADDLE_ENFORCE_GE(num_blocks, 2, - "server program should have at least 2 blocks"); - - framework::Executor executor(dev_place); - std::vector block_list; - for (size_t blkid = 1; blkid < num_blocks; ++blkid) { - if (blkid != static_cast(prefetch_block->ID())) { - block_list.push_back(blkid); - } - } - PADDLE_ENFORCE_EQ(grad_map_str.size(), block_list.size(), - "grad num should be equal to optimize block num"); - auto optimize_prepared = executor.Prepare(*program, block_list); - - std::unordered_map> - grad_to_prepared; - for (size_t i = 0; i < block_list.size(); ++i) { - grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i]; - } - - rpc_service_->SetScope(&recv_scope); - rpc_service_->SetDevCtx(&dev_ctx); - - // set proper fields for table lookup and update - rpc_service_->SetExecutor(&executor); - VLOG(3) << "prefetch block id is " << prefetch_block->ID(); - auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); - rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get()); - prefetch_prepared.release(); - rpc_service_->SetProgram(program); - - // start the server listening after all member initialized. - server_thread_.reset(new std::thread(RunServer, rpc_service_)); - VLOG(3) << "wait server thread to become ready..."; - sleep(5); - // Write to a file of server selected port for python use. - std::ofstream port_file; - port_file.open("/tmp/paddle.selected_port"); - port_file << rpc_service_->GetSelectedPort(); - port_file.close(); - - bool exit_flag = false; - while (!exit_flag) { - const detail::ReceivedMessage v = rpc_service_->Get(); - auto recv_var_name = v.first; - if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { - LOG(INFO) << "received terminate message and exit"; - exit_flag = true; - break; - } else { - VLOG(3) << "received grad: " << recv_var_name; - auto var = v.second->GetVar(); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << recv_var_name; - PADDLE_THROW("Can not find server side var"); - } - AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(), - &recv_scope); - // TODO(qiao): explain why - if (var->IsType()) { - var->GetMutable()->mutable_rows()->clear(); - } - } - - if (exit_flag) { - rpc_service_->ShutDown(); - break; - } - } // while(true) -} - -class AsyncListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { - public: - AsyncListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable(); - AddComment(R"DOC( -ListenAndServ operator - -This operator will start a RPC server which can receive variables -from send_op and send back variables to recv_op. -)DOC"); - AddAttr("endpoint", - "(string, default 127.0.0.1:6164)" - "IP address to listen on.") - .SetDefault("127.0.0.1:6164") - .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); - AddAttr>( - "grad_to_id(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])", - "a map from grad name to it's optimize block id") - .SetDefault({}); - AddAttr(kOptimizeBlock, - "BlockID to run on server side."); - AddAttr(kPrefetchBlock, - "prefetch block to run on server side."); - AddAttr("Fanin", "How many clients send to this server.") - .SetDefault(1); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(async_listen_and_serv, ops::AsyncListenAndServOp, - ops::AsyncListenAndServOpMaker); diff --git a/paddle/fluid/operators/async_listen_and_serv_op.h b/paddle/fluid/operators/async_listen_and_serv_op.h deleted file mode 100644 index 9df351b929a33f43ab47935fb2ec95232011d4c3..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/async_listen_and_serv_op.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/framework/executor.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/detail/async_grpc_server.h" - -namespace paddle { -namespace operators { - -constexpr char kOptimizeBlock[] = "OptimizeBlock"; -constexpr char kPrefetchBlock[] = "PrefetchBlock"; - -void RunServer(std::shared_ptr service); - -class AsyncListenAndServOp : public framework::OperatorBase { - public: - AsyncListenAndServOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs); - - int GetSelectedPort() const; - - void Stop() override; - - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override; - - protected: - mutable std::shared_ptr rpc_service_; - mutable std::shared_ptr server_thread_; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/detail/async_grpc_server.cc b/paddle/fluid/operators/detail/async_grpc_server.cc deleted file mode 100644 index fb0258f942c9ec1359750a2ca958405bdd8e1609..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/detail/async_grpc_server.cc +++ /dev/null @@ -1,311 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/detail/async_grpc_server.h" - -#include -#include - -using ::grpc::ServerAsyncResponseWriter; - -namespace paddle { -namespace operators { -namespace detail { - -enum CallStatus { PROCESS = 0, FINISH }; - -// reference: -// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server -class RequestBase { - public: - explicit RequestBase(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - const platform::DeviceContext* dev_ctx) - : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) { - PADDLE_ENFORCE(cq_); - } - virtual ~RequestBase() {} - virtual void Process() { assert(false); } - - CallStatus Status() { return status_; } - void SetStatus(CallStatus status) { status_ = status; } - virtual std::string GetReqName() { - assert(false); - return ""; - } - - protected: - ::grpc::ServerContext ctx_; - GrpcService::AsyncService* service_; - ::grpc::ServerCompletionQueue* cq_; - CallStatus status_; - const platform::DeviceContext* dev_ctx_; -}; - -class RequestSend final : public RequestBase { - public: - explicit RequestSend(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - framework::Scope* scope, ReceivedQueue* queue, - const platform::DeviceContext* dev_ctx) - : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { - request_.reset(new VariableResponse(true, scope, dev_ctx_)); - int method_id = static_cast(detail::GrpcMethod::kSendVariable); - service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, - cq_, cq_, this); - } - - virtual ~RequestSend() {} - - virtual std::string GetReqName() { return request_->Varname(); } - - virtual void Process() { - queue_->Push(std::make_pair(request_->Varname(), request_)); - - sendrecv::VoidMessage reply; - responder_.Finish(reply, ::grpc::Status::OK, this); - status_ = FINISH; - } - - protected: - std::shared_ptr request_; - ReceivedQueue* queue_; - ServerAsyncResponseWriter responder_; -}; - -class RequestGet final : public RequestBase { - public: - explicit RequestGet(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - framework::Scope* scope, - const platform::DeviceContext* dev_ctx) - : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope) { - auto method_id = static_cast(detail::GrpcMethod::kGetVariable); - service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, - cq_, this); - } - - virtual ~RequestGet() {} - - virtual std::string GetReqName() { return request_.varname(); } - - virtual void Process() { - // proc request. - std::string var_name = request_.varname(); - auto* var = scope_->FindVar(var_name); - - ::grpc::ByteBuffer reply; - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); - - responder_.Finish(reply, ::grpc::Status::OK, this); - status_ = FINISH; - } - - protected: - sendrecv::VariableMessage request_; - ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; - framework::Scope* scope_; -}; - -class RequestPrefetch final : public RequestBase { - public: - explicit RequestPrefetch(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - framework::Scope* scope, - const platform::DeviceContext* dev_ctx, - framework::Executor* executor, - framework::ProgramDesc* program, - framework::ExecutorPrepareContext* prefetch_ctx) - : RequestBase(service, cq, dev_ctx), - responder_(&ctx_), - scope_(scope), - executor_(executor), - program_(program), - prefetch_ctx_(prefetch_ctx) { - request_.reset(new VariableResponse(true, scope, dev_ctx_)); - int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); - service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, - cq_, cq_, this); - } - - virtual ~RequestPrefetch() {} - - virtual std::string GetReqName() { return request_->Varname(); } - - virtual void Process() { - // prefetch process... - ::grpc::ByteBuffer reply; - - std::string var_name = request_->OutVarname(); - VLOG(3) << "prefetch var " << var_name; - auto var_desc = program_->Block(0).FindVar(var_name); - framework::Scope* local_scope = &scope_->NewScope(); - auto* var = local_scope->FindVar(var_name); - InitializeVariable(var, var_desc->GetType()); - executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false); - - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); - - responder_.Finish(reply, ::grpc::Status::OK, this); - status_ = FINISH; - } - - protected: - std::shared_ptr request_; - ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; - framework::Scope* scope_; - framework::Executor* executor_; - framework::ProgramDesc* program_; - framework::ExecutorPrepareContext* prefetch_ctx_; -}; - -void SyncGRPCServer::RunAsyncUpdate() { - ::grpc::ServerBuilder builder; - builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(), - &selected_port_); - builder.SetMaxSendMessageSize(std::numeric_limits::max()); - builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); - builder.RegisterService(&service_); - - cq_send_ = builder.AddCompletionQueue(); - cq_get_ = builder.AddCompletionQueue(); - cq_prefetch_ = builder.AddCompletionQueue(); - - server_ = builder.BuildAndStart(); - LOG(INFO) << "Server listening on " << address_ - << " selected port: " << selected_port_; - - std::function send_register = - std::bind(&SyncGRPCServer::TryToRegisterNewSendOne, this); - std::function get_register = - std::bind(&SyncGRPCServer::TryToRegisterNewGetOne, this); - std::function prefetch_register = - std::bind(&SyncGRPCServer::TryToRegisterNewPrefetchOne, this); - - // TODO(wuyi): Run these "HandleRequest" in thread pool - t_send_.reset( - new std::thread(std::bind(&SyncGRPCServer::HandleRequest, this, - cq_send_.get(), "cq_send", send_register))); - t_get_.reset( - new std::thread(std::bind(&SyncGRPCServer::HandleRequest, this, - cq_get_.get(), "cq_get", get_register))); - t_prefetch_.reset(new std::thread( - std::bind(&SyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), - "cq_prefetch", prefetch_register))); - // wait server - server_->Wait(); - t_send_->join(); - t_get_->join(); - t_prefetch_->join(); -} - -void SyncGRPCServer::ShutdownQueue() { - std::unique_lock lock(cq_mutex_); - cq_send_->Shutdown(); - cq_get_->Shutdown(); - cq_prefetch_->Shutdown(); -} - -// This URL explains why shutdown is complicate: -void SyncGRPCServer::ShutDown() { - is_shut_down_ = true; - ShutdownQueue(); - server_->Shutdown(); -} - -void SyncGRPCServer::TryToRegisterNewSendOne() { - std::unique_lock lock(cq_mutex_); - if (is_shut_down_) { - VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; - return; - } - RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, - &var_recv_queue_, dev_ctx_); - VLOG(4) << "Create RequestSend status:" << send->Status(); -} - -void SyncGRPCServer::TryToRegisterNewGetOne() { - std::unique_lock lock(cq_mutex_); - if (is_shut_down_) { - VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; - return; - } - RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_); - VLOG(4) << "Create RequestGet status:" << get->Status(); -} - -void SyncGRPCServer::TryToRegisterNewPrefetchOne() { - std::unique_lock lock(cq_mutex_); - if (is_shut_down_) { - VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; - return; - } - RequestPrefetch* prefetch = - new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, - executor_, program_, prefetch_ctx_); - - VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); -} - -void SyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, - const std::string& cq_name, - std::function TryToRegisterNewOne) { - TryToRegisterNewOne(); - - void* tag = NULL; - bool ok = false; - - while (true) { - VLOG(3) << "HandleRequest for " << cq_name << " while in"; - if (!cq->Next(&tag, &ok)) { - LOG(INFO) << cq_name << " CompletionQueue shutdown!"; - break; - } - VLOG(3) << "HandleRequest for " << cq_name << " while after Next"; - - PADDLE_ENFORCE(tag); - - RequestBase* base = reinterpret_cast(tag); - // reference: - // https://github.com/tensorflow/tensorflow/issues/5596 - // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM - // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I - if (!ok) { - LOG(WARNING) << cq_name << " recv no regular event:argument name[" - << base->GetReqName() << "]"; - TryToRegisterNewOne(); - delete base; - continue; - } - - switch (base->Status()) { - case PROCESS: { - VLOG(4) << cq_name << " status:" << base->Status(); - TryToRegisterNewOne(); - base->Process(); - break; - } - case FINISH: { - VLOG(4) << cq_name << " status:" << base->Status(); - delete base; - break; - } - default: { assert(false); } - } - } -} - -} // namespace detail -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/detail/async_grpc_server.h b/paddle/fluid/operators/detail/async_grpc_server.h deleted file mode 100644 index 870b684cfadc89b877d1301d00b3a95bccb7e1b1..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/detail/async_grpc_server.h +++ /dev/null @@ -1,111 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include // NOLINT -#include - -#include "grpc++/grpc++.h" -#include "paddle/fluid/framework/executor.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/detail/grpc_service.h" -#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" -#include "paddle/fluid/operators/detail/send_recv.pb.h" -#include "paddle/fluid/operators/detail/sendrecvop_utils.h" -#include "paddle/fluid/operators/detail/simple_block_queue.h" - -namespace paddle { -namespace operators { -namespace detail { - -typedef std::pair> - ReceivedMessage; -typedef SimpleBlockQueue ReceivedQueue; - -typedef std::pair MessageWithName; -class RequestBase; - -class SyncGRPCServer final { - public: - explicit SyncGRPCServer(const std::string &address) : address_(address) {} - - void RunAsyncUpdate(); - - void SetScope(framework::Scope *scope) { scope_ = scope; } - - void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; } - - void SetProgram(framework::ProgramDesc *program) { program_ = program; } - - void SetExecutor(framework::Executor *executor) { executor_ = executor; } - - void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) { - prefetch_ctx_ = prepared; - } - - int GetSelectedPort() { return selected_port_; } - - const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } - - void Push(const std::string &msg_name) { - this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr)); - } - - void ShutDown(); - - protected: - void HandleRequest(::grpc::ServerCompletionQueue *cq, - const std::string &cq_name, - std::function TryToRegisterNewOne); - void TryToRegisterNewSendOne(); - void TryToRegisterNewGetOne(); - void TryToRegisterNewPrefetchOne(); - void ShutdownQueue(); - - private: - std::mutex cq_mutex_; - volatile bool is_shut_down_ = false; - std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; - std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; - std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_; - - GrpcService::AsyncService service_; - std::unique_ptr<::grpc::Server> server_; - - std::string address_; - framework::Scope *scope_; - const platform::DeviceContext *dev_ctx_; - - // client send variable to this queue. - ReceivedQueue var_recv_queue_; - - std::unique_ptr t_send_; - std::unique_ptr t_get_; - std::unique_ptr t_prefetch_; - - framework::ExecutorPrepareContext *prefetch_ctx_; - framework::ProgramDesc *program_; - framework::Executor *executor_; - int selected_port_; -}; - -}; // namespace detail -}; // namespace operators -}; // namespace paddle