From 8f7c77309d1ae5e34ecf51c6f5729ce2b1e63aa5 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Fri, 20 Apr 2018 13:21:36 +0800 Subject: [PATCH] refine listen_and_serv_op --- paddle/fluid/operators/detail/grpc_server.h | 2 +- paddle/fluid/operators/listen_and_serv_op.cc | 116 +++++++++---------- paddle/fluid/operators/listen_and_serv_op.h | 21 +++- paddle/fluid/operators/send_recv_op_test.cc | 2 +- 4 files changed, 75 insertions(+), 66 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index b6110f92ed4..3c113f4ffbc 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -67,7 +67,7 @@ class AsyncGRPCServer final { prefetch_ctx_ = prepared; } - int GetSelectedPort() { return selected_port_; } + int GetSelectedPort() const { return selected_port_; } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index a4c925b538e..ec00a959f52 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -27,20 +27,6 @@ void RunServer(std::shared_ptr service) { VLOG(4) << "RunServer thread end"; } -static void CreateTensorFromMessageType(framework::Variable *var, - sendrecv::VarType var_type) { - if (var_type == sendrecv::VarType::LOD_TENSOR) { - var->GetMutable(); - } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { - var->GetMutable(); - } else { - PADDLE_THROW( - "VariableMessage type %d is not in " - "[LoDTensor, SelectedRows]", - var_type); - } -} - static void ParallelExecuteBlocks( const std::vector ¶llel_blkids, framework::Executor *executor, const std::vector> @@ -77,59 +63,37 @@ void ListenAndServOp::Stop() { server_thread_->join(); } -void ListenAndServOp::RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - framework::Scope &recv_scope = scope.NewScope(); - - if (!rpc_service_) { - std::string endpoint = Attr("endpoint"); - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); - } +void ListenAndServOp::PreparePrefetchCtx( + framework::Executor *executor, framework::BlockDesc *prefetch_block, + framework::ProgramDesc *program) const { + // TODO(qiao) set proper fields for table lookup and update + rpc_service_->SetExecutor(executor); + VLOG(3) << "prefetch block id is " << prefetch_block->ID(); + auto prefetch_prepared = executor->Prepare(*program, prefetch_block->ID()); + rpc_service_->SetPrefetchBlkdId(prefetch_block->ID()); + rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get()); + prefetch_prepared.release(); +} - auto ins = Inputs("X"); +void ListenAndServOp::RunSyncUpdate( + framework::Executor *executor, framework::ProgramDesc *program, + framework::Scope *recv_scope, framework::BlockDesc *prefetch_block) const { auto fan_in = Attr("Fanin"); - auto *optimize_block = Attr(kOptimizeBlock); - auto *prefetch_block = Attr(kPrefetchBlock); - auto *program = optimize_block->Program(); + size_t num_blocks = program->Size(); PADDLE_ENFORCE_GE(num_blocks, 2, "server program should have at least 2 blocks"); - framework::Executor executor(dev_place); std::vector block_list; for (size_t blkid = 1; blkid < num_blocks; ++blkid) { - if (blkid != static_cast(prefetch_block->ID())) { - block_list.push_back(blkid); - } + block_list.push_back(blkid); } - auto optimize_prepared = executor.Prepare(*program, block_list); + auto optimize_prepared = executor->Prepare(*program, block_list); // Insert placeholder for block0 which holds current op itself. optimize_prepared.insert( optimize_prepared.begin(), std::shared_ptr(nullptr)); - rpc_service_->SetScope(&recv_scope); - rpc_service_->SetDevCtx(&dev_ctx); - // TODO(qiao) set proper fields for table lookup and update - rpc_service_->SetExecutor(&executor); - VLOG(3) << "prefetch block id is " << prefetch_block->ID(); - auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); - rpc_service_->SetPrefetchBlkdId(prefetch_block->ID()); - rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get()); - prefetch_prepared.release(); - rpc_service_->SetProgram(program); - // start the server listening after all member initialized. - server_thread_.reset(new std::thread(RunServer, rpc_service_)); - VLOG(3) << "wait server thread to become ready..."; - sleep(5); - // Write to a file of server selected port for python use. - std::ofstream port_file; - port_file.open("/tmp/paddle.selected_port"); - port_file << rpc_service_->GetSelectedPort(); - port_file.close(); - bool exit_flag = false; // Record received sparse variables, so that // we could reset those after execute optimize program @@ -170,7 +134,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, break; } - // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads + // NOTE: if is_gpu_place, CUDA kernels are launch by multiple threads // and this will still work. // The optimize blocks which have the same parent ID would run parallel @@ -182,16 +146,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, for (size_t blkid = 2; blkid < num_blocks; ++blkid) { if (blkid != static_cast(prefetch_block->ID())) { if (program->Block(blkid).Parent() != last_parent_blkid) { - ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared, - program, &recv_scope); + ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, + program, recv_scope); parallel_blkids.clear(); last_parent_blkid = program->Block(blkid).Parent(); } parallel_blkids.push_back(blkid); } } - ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared, - program, &recv_scope); + ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program, + recv_scope); VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; // Reset the received sparse variables, the sum operator would not @@ -209,6 +173,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, } // while(true) } +static void SavePort(std::shared_ptr rpc_service) { + std::ofstream port_file; + port_file.open("/tmp/paddle.selected_port"); + port_file << rpc_service->GetSelectedPort(); + port_file.close(); +} + +void ListenAndServOp::RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); + framework::Scope &recv_scope = scope.NewScope(); + + PADDLE_ENFORCE(!rpc_service_); + std::string endpoint = Attr("endpoint"); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + + auto *optimize_block = Attr(kOptimizeBlock); + auto *prefetch_block = Attr(kPrefetchBlock); + auto *program = optimize_block->Program(); + framework::Executor executor(dev_place); + + // prepare rpc_service + rpc_service_->SetScope(&recv_scope); + rpc_service_->SetDevCtx(&dev_ctx); + rpc_service_->SetProgram(program); + PreparePrefetchCtx(&executor, prefetch_block, program); + // start the server listening after all member initialized. + server_thread_.reset(new std::thread(RunServer, rpc_service_)); + VLOG(3) << "wait server thread to become ready..."; + sleep(5); + // Write to a file of server selected port for python use. + SavePort(rpc_service_); + RunSyncUpdate(&executor, program, &recv_scope, prefetch_block); +} + class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { public: ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker) diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index 9744921cef7..33d15c7d611 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -34,17 +34,26 @@ void RunServer(std::shared_ptr service); class ListenAndServOp : public framework::OperatorBase { public: - ListenAndServOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs); + ListenAndServOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs); int GetSelectedPort() const; + void PreparePrefetchCtx(framework::Executor* executor, + framework::BlockDesc* prefetch_block, + framework::ProgramDesc* program) const; + + void RunSyncUpdate(framework::Executor* executor, + framework::ProgramDesc* program, + framework::Scope* recv_scope, + framework::BlockDesc* prefetch_block) const; + void Stop() override; - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override; + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override; protected: mutable std::shared_ptr rpc_service_; diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc index a342874f974..81350fee38d 100644 --- a/paddle/fluid/operators/send_recv_op_test.cc +++ b/paddle/fluid/operators/send_recv_op_test.cc @@ -127,7 +127,7 @@ void StartServerNet(bool is_sparse) { const auto &root_block = program.Block(0); auto *optimize_block = program.AppendBlock(root_block); auto *prefetch_block = program.AppendBlock(root_block); - // X for server side tensors, RX for received tensers, must be of same shape. + // X for server side tensors, RX for received tensors, must be of same shape. AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); f::AttributeMap attrs; -- GitLab