diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc
index 57867aad4d679f75ea790b65b5773a73586fd96e..5c2979222a8121aeaa82222085aebdcd51f6c603 100644
--- a/paddle/fluid/operators/detail/grpc_server.cc
+++ b/paddle/fluid/operators/detail/grpc_server.cc
@@ -155,16 +155,18 @@ class RequestPrefetch final : public RequestBase {
 
   void Process() override {
     // prefetch process...
-    std::string varname = request_->OutVarname();
-    VLOG(3) << "RequestPrefetch " << varname;
+    std::string in_var_name = request_->Varname();
+    std::string out_var_name = request_->OutVarname();
+    VLOG(3) << "in_var_name: " << in_var_name
+            << " RequestPrefetch: " << out_var_name;
 
     auto scope = request_->GetMutableLocalScope();
-    auto invar = scope->FindVar(varname);
+    auto invar = scope->FindVar(in_var_name);
     framework::Variable* outvar = nullptr;
 
-    request_handler_->Handle(varname, scope, invar, &outvar);
+    request_handler_->Handle(in_var_name, scope, invar, &outvar);
 
-    SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
+    SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
                           &reply_);
 
     responder_.Finish(reply_, ::grpc::Status::OK,
                       reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc
index a1f9ba15e656a686c4bb0d81cf00dea120b8c0ad..e6a33903ad856a0b38115af48c30d92a39d9b65c 100644
--- a/paddle/fluid/operators/detail/grpc_server_test.cc
+++ b/paddle/fluid/operators/detail/grpc_server_test.cc
@@ -99,11 +99,17 @@ void StartServer() {
   framework::Executor exe(place);
   platform::CPUDeviceContext ctx(place);
   auto* block = AppendPrefetchBlcok(&program);
-  auto prepared = exe.Prepare(program, block->ID());
+  std::string in_var_name("ids");
+  std::vector<int> prefetch_block_ids{block->ID()};
+  auto prepared = exe.Prepare(program, prefetch_block_ids);
   InitTensorsOnServer(&scope, &place, 10);
 
+  std::unordered_map<std::string,
+                     std::shared_ptr<framework::ExecutorPrepareContext>>
+      prefetch_var_name_to_prepared;
+  prefetch_var_name_to_prepared[in_var_name] = prepared[0];
   g_req_handler->SetProgram(&program);
-  g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
+  g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
   g_req_handler->SetDevCtx(&ctx);
   g_req_handler->SetScope(&scope);
   g_req_handler->SetExecutor(&exe);
diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h
index d74206aaba6a79ee06475985e642221bd84d9382..373a6aaa09c06899e36fdcab8c2c745a377924a3 100644
--- a/paddle/fluid/operators/detail/request_handler.h
+++ b/paddle/fluid/operators/detail/request_handler.h
@@ -57,9 +57,12 @@ class RequestHandler {
   void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
   void SetProgram(framework::ProgramDesc* program) { program_ = program; }
   void SetExecutor(framework::Executor* executor) { executor_ = executor; }
+
+  // Used for dist lookup table prefetch
   void SetPrefetchPreparedCtx(
-      std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
-    prefetch_ctx_.reset(prepared.release());
+      std::unordered_map<
+          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
+    prefetch_var_name_to_prepared_ctx_ = g;
   }
 
   // Used for async.
@@ -75,9 +78,6 @@ class RequestHandler {
   bool sync_mode() { return sync_mode_; }
   framework::Scope* scope() { return scope_; }
   const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
-  framework::ExecutorPrepareContext* prefetch_ctx() {
-    return prefetch_ctx_.get();
-  }
   framework::ProgramDesc* program() { return program_; }
   framework::Executor* executor() { return executor_; }
 
@@ -106,12 +106,17 @@ class RequestHandler {
   framework::Executor* executor_;
   framework::Scope* scope_;
   framework::ProgramDesc* program_;
-  std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
+
+  // used for distribute lookup table prefetch
+  std::unordered_map<std::string,
+                     std::shared_ptr<framework::ExecutorPrepareContext>>*
+      prefetch_var_name_to_prepared_ctx_;
 
   // Used for async.
   std::unordered_map<std::string,
                      std::shared_ptr<framework::ExecutorPrepareContext>>*
       grad_to_prepared_ctx_;
 
+
   RPCServer* rpc_server_;
 };
diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc
index 9473dce55029f2a4e0987ab8f6f5e7205d7fff47..dc28740bf04e5b17727b02741a20a8c09d2bbc88 100644
--- a/paddle/fluid/operators/detail/request_handler_impl.cc
+++ b/paddle/fluid/operators/detail/request_handler_impl.cc
@@ -111,7 +111,8 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
   auto var_desc = program_->Block(0).FindVar(varname);
   *outvar = scope->FindVar(varname);
   InitializeVariable(*outvar, var_desc->GetType());
-  executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
+  executor_->RunPreparedContext(
+      (*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
 
   return true;
 }
diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index 272d7905ebe81977975de723100da0b3f1c9823f..4d35caff2c733fd0496fa4d6530a4e1c90572682 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -89,16 +89,19 @@ void ListenAndServOp::SavePort() const {
   rpc_service_->SavePort();
 }
 
-void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
-                                  framework::ProgramDesc *program,
-                                  framework::Scope *recv_scope,
-                                  framework::BlockDesc *prefetch_block) const {
+void ListenAndServOp::RunSyncLoop(
+    framework::Executor *executor, framework::ProgramDesc *program,
+    framework::Scope *recv_scope,
+    const std::vector<int> &prefetch_block_id_list) const {
+  // FIXME(qiao): RunSyncLoop should not run the blocks that do prefetch;
+  // currently the prefetch blocks can only be the last blocks of the
+  // program.
   size_t num_blocks = program->Size();
   PADDLE_ENFORCE_GE(num_blocks, 2,
                     "server program should have at least 2 blocks");
 
   std::vector<int> block_list;
-  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
+  for (size_t blkid = 1; blkid < prefetch_block_id_list[0]; ++blkid) {
     block_list.push_back(blkid);
   }
   auto optimize_prepared = executor->Prepare(*program, block_list);
@@ -128,16 +131,14 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
     std::vector<size_t> parallel_blkids;
     parallel_blkids.push_back(1);
     double ts = detail::GetTimestamp();
-    for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
-      if (blkid != static_cast<size_t>(prefetch_block->ID())) {
-        if (program->Block(blkid).Parent() != last_parent_blkid) {
-          ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
-                                program, recv_scope);
-          parallel_blkids.clear();
-          last_parent_blkid = program->Block(blkid).Parent();
-        }
-        parallel_blkids.push_back(blkid);
+    for (size_t blkid = 2; blkid < prefetch_block_id_list[0]; ++blkid) {
+      if (program->Block(blkid).Parent() != last_parent_blkid) {
+        ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
+                              program, recv_scope);
+        parallel_blkids.clear();
+        last_parent_blkid = program->Block(blkid).Parent();
       }
+      parallel_blkids.push_back(blkid);
     }
     ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
                           recv_scope);
@@ -203,18 +204,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
   }  // while(true)
 }
 
-static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
-                           platform::DeviceContext *dev_ctx,
-                           framework::Executor *executor,
-                           framework::ProgramDesc *program,
-                           framework::ExecutorPrepareContext *prefetch_ctx,
-                           detail::RPCServer *rpc_server) {
+static void FillRequestCtx(
+    detail::RequestHandler *h, framework::Scope *scope,
+    platform::DeviceContext *dev_ctx, framework::Executor *executor,
+    framework::ProgramDesc *program,
+    std::unordered_map<std::string,
+                       std::shared_ptr<framework::ExecutorPrepareContext>>
+        *prefetch_ctx,
+    detail::RPCServer *rpc_server) {
   h->SetScope(scope);
   h->SetDevCtx(dev_ctx);
   h->SetExecutor(executor);
   h->SetProgram(program);
-  h->SetPrefetchPreparedCtx(
-      std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
+  h->SetPrefetchPreparedCtx(prefetch_ctx);
   h->SetRPCServer(rpc_server);
 }
 
@@ -248,18 +250,41 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
                             request_prefetch_handler_.get());
 
   auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
-  auto grad_to_block_id_str = Attr<std::vector<std::string>>(kPrefetchBlock);
-  framework::BlockDesc *prefetch_block = nullptr;
   auto *program = optimize_block->Program();
   framework::Executor executor(dev_place);
 
   // prepare for prefetch
-  VLOG(3) << "prefetch block id is " << prefetch_block->ID();
-  auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
+  std::vector<int> prefetch_block_id_list;
+  std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
+
+  auto prefetch_var_name_to_block_id_str =
+      Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId);
+  for (const auto &prefetch_var_name_and_id :
+       prefetch_var_name_to_block_id_str) {
+    std::vector<std::string> pieces;
+    split(prefetch_var_name_and_id, ':', &pieces);
+    VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
+    PADDLE_ENFORCE_EQ(pieces.size(), 2);
+
+    int block_id = std::stoi(pieces[1]);
+    prefetch_block_id_list.push_back(block_id);
+    block_id_to_prefetch_var_name[block_id] = pieces[0];
+  }
+
+  auto prefetch_prepared = executor.Prepare(*program, prefetch_block_id_list);
+
+  std::unordered_map<std::string,
+                     std::shared_ptr<framework::ExecutorPrepareContext>>
+      prefetch_var_name_to_prepared_ctx;
+  for (int i = 0; i < prefetch_block_id_list.size(); ++i) {
+    auto block_id = prefetch_block_id_list[i];
+    auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
+    prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
+  }
 
   auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
-                     &dev_ctx, &executor, program, prefetch_prepared.release(),
-                     rpc_service_.get());
+                     &dev_ctx, &executor, program,
+                     &prefetch_var_name_to_prepared_ctx, rpc_service_.get());
 
   f(request_send_handler_.get());
   f(request_get_handler_.get());
@@ -277,7 +302,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   // Write to a file of server selected port for python use.
   SavePort();
   if (sync_mode) {
-    RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
+    RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
   } else {
     RunAsyncLoop(&executor, program);
   }
@@ -303,7 +328,7 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
     AddAttr<framework::BlockDesc *>(kOptimizeBlock,
                                     "BlockID to run on server side.");
-    AddAttr<std::vector<std::string>>(kPrefetchBlock,
+    AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
                                       "prefetch block to run on server side.");
     AddAttr<int>("Fanin", "How many clients send to this server.")
         .SetDefault(1);
diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h
index db3582d9b45746fe47bff4b8f3f7ddca0762121a..46c3a19e20b3f2dd970a672bb99f98e83d3e25bf 100644
--- a/paddle/fluid/operators/listen_and_serv_op.h
+++ b/paddle/fluid/operators/listen_and_serv_op.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <atomic>
 #include <set>
 #include <string>
+#include <vector>
 
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -30,7 +31,7 @@ namespace paddle {
 namespace operators {
 
 constexpr char kOptimizeBlock[] = "OptimizeBlock";
-constexpr char kPrefetchBlock[] = "prefetch_var_name_to_block_id";
+constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
 
 void RunServer(std::shared_ptr<detail::RPCServer> service);
 
@@ -46,7 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
   void RunSyncLoop(framework::Executor* executor,
                    framework::ProgramDesc* program,
                    framework::Scope* recv_scope,
-                   framework::BlockDesc* prefetch_block) const;
+                   const std::vector<int>& prefetch_block_id_list) const;
 
   void RunAsyncLoop(framework::Executor* executor,
                     framework::ProgramDesc* program) const;