diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index 49e191b7bb9e97c91ae2e536550aebd6e91c6fce..10064357eace09e9efe5678fdf23267fe1653824 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -27,6 +27,36 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   VLOG(4) << "RunServer thread end";
 }
 
+static void split(const std::string &str, char sep,
+                  std::vector<std::string> *pieces) {
+  pieces->clear();
+  if (str.empty()) {
+    return;
+  }
+  size_t pos = 0;
+  size_t next = str.find(sep, pos);
+  while (next != std::string::npos) {
+    pieces->push_back(str.substr(pos, next - pos));
+    pos = next + 1;
+    next = str.find(sep, pos);
+  }
+  if (!str.substr(pos).empty()) {
+    pieces->push_back(str.substr(pos));
+  }
+}
+
+static void AsyncExecuteBlock(framework::Executor *executor,
+                              framework::ExecutorPrepareContext *prepared,
+                              framework::Scope *scope) {
+  framework::Async([&executor, &prepared, &scope]() {
+    try {
+      executor->RunPreparedContext(prepared, scope, false, false);
+    } catch (std::exception &e) {
+      LOG(ERROR) << "run sub program error " << e.what();
+    }
+  });
+}
+
 static void ParallelExecuteBlocks(
     const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
     const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -168,12 +198,82 @@ void ListenAndServOp::RunSyncUpdate(
   }  // while(true)
 }
 
+void ListenAndServOp::RunAsyncUpdate(
+    framework::Executor *executor, framework::ProgramDesc *program,
+    framework::Scope *recv_scope, framework::BlockDesc *prefetch_block) const {
+  // grad name to block id
+  std::unordered_map<std::string, int32_t> grad_to_id;
+  std::unordered_map<int32_t, std::string> id_to_grad;
+
+  auto grad_map_str = Attr<std::vector<std::string>>("grad_map");
+  for (auto &grad_and_id : grad_map_str) {
+    std::vector<std::string> pieces;
+    split(grad_and_id, ':', &pieces);
+    PADDLE_ENFORCE_EQ(pieces.size(), 2);
+    PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
+    int block_id = std::stoi(pieces[1]);
+    grad_to_id[pieces[0]] = block_id;
+    id_to_grad[block_id] = pieces[0];
+  }
+
+  size_t num_blocks = program->Size();
+  PADDLE_ENFORCE_GE(num_blocks, 2,
+                    "server program should have at least 2 blocks");
+  std::vector<size_t> block_list;
+  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
+    if (blkid != static_cast<size_t>(prefetch_block->ID())) {
+      block_list.push_back(blkid);
+    }
+  }
+  PADDLE_ENFORCE_EQ(grad_map_str.size(), block_list.size(),
+                    "grad num should be equal to optimize block num");
+  auto optimize_prepared = executor->Prepare(*program, block_list);
+
+  std::unordered_map<std::string,
+                     std::shared_ptr<framework::ExecutorPrepareContext>>
+      grad_to_prepared;
+  for (size_t i = 0; i < block_list.size(); ++i) {
+    grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
+  }
+
+  bool exit_flag = false;
+  while (!exit_flag) {
+    const detail::ReceivedMessage v = rpc_service_->Get();
+    auto recv_var_name = v.first;
+    if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
+      LOG(INFO) << "received terminate message and exit";
+      exit_flag = true;
+      break;
+    } else {
+      VLOG(3) << "received grad: " << recv_var_name;
+      auto var = v.second->GetVar();
+      if (var == nullptr) {
+        LOG(ERROR) << "Can not find server side var: " << recv_var_name;
+        PADDLE_THROW("Can not find server side var");
+      }
+      AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(),
+                        recv_scope);
+      // TODO(qiao): explain why
+      if (var->IsType<framework::SelectedRows>()) {
+        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
+      }
+    }
+
+    if (exit_flag) {
+      rpc_service_->ShutDown();
+      break;
+    }
+  }  // while(true)
+}
+
 void ListenAndServOp::RunImpl(const framework::Scope &scope,
                               const platform::Place &dev_place) const {
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto &dev_ctx = *pool.Get(dev_place);
   framework::Scope &recv_scope = scope.NewScope();
 
+  bool sync_mode = Attr<bool>("sync_mode");
+
   PADDLE_ENFORCE(!rpc_service_);
   std::string endpoint = Attr<std::string>("endpoint");
   rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
@@ -201,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   sleep(5);
   // Write to a file of server selected port for python use.
   SavePort(rpc_service_);
-  RunSyncUpdate(&executor, program, &recv_scope, prefetch_block);
+  if (sync_mode) {
+    RunSyncUpdate(&executor, program, &recv_scope, prefetch_block);
+  } else {
+    RunAsyncUpdate(&executor, program, &recv_scope, prefetch_block);
+  }
 }
 
 class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -220,6 +324,12 @@ from send_op and send back variables to recv_op.
                          "IP address to listen on.")
         .SetDefault("127.0.0.1:6164")
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
+    AddAttr<std::vector<std::string>>(
+        "grad_map", "a map from grad name to its optimize block id, e.g. "
+                    "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2']")
+        .SetDefault({});
+    AddAttr<bool>("sync_mode", "whether to run in sync mode or not")
+        .SetDefault(false);
     AddAttr<int>(kOptimizeBlock, "BlockID to run on server side.");
     AddAttr<int>(kPrefetchBlock,
diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h
index 68c8da4d382221f662932789e67ce2aacf18e1a3..9b915a594105483a5e72c3396287928337562d5c 100644
--- a/paddle/fluid/operators/listen_and_serv_op.h
+++ b/paddle/fluid/operators/listen_and_serv_op.h
@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase {
                      framework::Scope* recv_scope,
                      framework::BlockDesc* prefetch_block) const;
 
+  void RunAsyncUpdate(framework::Executor* executor,
+                      framework::ProgramDesc* program,
+                      framework::Scope* recv_scope,
+                      framework::BlockDesc* prefetch_block) const;
+
   void Stop() override;
 
   void RunImpl(const framework::Scope& scope,
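For illustration, here is a minimal standalone sketch of how `grad_map` entries in the `grad_name:block_id` format get parsed into the `grad_to_id` / `id_to_grad` tables built in `RunAsyncUpdate`. The `split` logic is lifted from the patch; `main` and the sample entries are hypothetical:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Same splitting logic as the static helper added in the patch:
// cuts str at each sep and drops a trailing empty piece.
static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  pieces->clear();
  if (str.empty()) return;
  size_t pos = 0;
  size_t next = str.find(sep, pos);
  while (next != std::string::npos) {
    pieces->push_back(str.substr(pos, next - pos));
    pos = next + 1;
    next = str.find(sep, pos);
  }
  if (!str.substr(pos).empty()) pieces->push_back(str.substr(pos));
}

int main() {
  // Hypothetical attribute values in the "grad_name:block_id" format
  // documented on the grad_map attribute.
  std::vector<std::string> grad_map_str = {"param1@GRAD.block0:1",
                                           "param2@GRAD.blockn:2"};
  std::unordered_map<std::string, int> grad_to_id;
  std::unordered_map<int, std::string> id_to_grad;
  for (auto &grad_and_id : grad_map_str) {
    std::vector<std::string> pieces;
    split(grad_and_id, ':', &pieces);
    int block_id = std::stoi(pieces[1]);
    grad_to_id[pieces[0]] = block_id;
    id_to_grad[block_id] = pieces[0];
  }
  std::cout << grad_to_id["param1@GRAD.block0"] << "\n";  // prints 1
  std::cout << id_to_grad[2] << "\n";  // prints param2@GRAD.blockn
  return 0;
}
```

Each received gradient name is then a direct key into `grad_to_prepared`, so the dispatch loop never scans blocks per message.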
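The key behavioral difference between the two loops: the sync path (`ParallelExecuteBlocks`, whose body is not in this hunk) runs optimize blocks in a batch before replying, while `AsyncExecuteBlock` hands the prepared block to `framework::Async` and returns straight to `rpc_service_->Get()`. A rough standalone analogue of that fire-and-forget dispatch, using `std::async` as a stand-in for `framework::Async` (whose exact semantics are not shown in this patch):

```cpp
#include <future>
#include <iostream>
#include <stdexcept>
#include <string>

// Stand-in for executor->RunPreparedContext(prepared, scope, false, false).
void RunBlock(const std::string &grad_name) {
  std::cout << "optimizing with " << grad_name << "\n";
}

// Mirrors AsyncExecuteBlock: exceptions are caught and logged inside the
// task, so a failing optimize block cannot unwind into the server loop.
std::future<void> AsyncExecuteBlock(const std::string &grad_name) {
  return std::async(std::launch::async, [grad_name]() {
    try {
      RunBlock(grad_name);
    } catch (const std::exception &e) {
      std::cerr << "run sub program error " << e.what() << "\n";
    }
  });
}

int main() {
  // The server loop dispatches and goes straight back to Get();
  // nothing orders these two updates relative to each other.
  auto f1 = AsyncExecuteBlock("param1@GRAD.block0");
  auto f2 = AsyncExecuteBlock("param2@GRAD.blockn");
  f1.wait();
  f2.wait();
  return 0;
}
```

This is why per-parameter updates may interleave arbitrarily in async mode: there is no barrier equivalent to the sync loop's batch wait.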