diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 64c06687b6b905186d4efcc8441d3abef6323d53..16a118090ba9cfd50b4b03484983f9fc73cf7973 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -279,6 +279,21 @@ std::unique_ptr Executor::Prepare( return std::unique_ptr(ctx); } +std::vector> Executor::Prepare( + const ProgramDesc& program, const std::vector& block_ids) { + std::vector> result; + for (auto& bid : block_ids) { + auto* ctx = new ExecutorPrepareContext(program, bid); + PADDLE_ENFORCE_LT(static_cast(bid), program.Size()); + auto& block = program.Block(bid); + for (auto& op_desc : block.AllOps()) { + ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); + } + result.push_back(std::shared_ptr(ctx)); + } + return result; +} + void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope, bool create_vars) { auto& block = ctx->prog_.Block(ctx->block_id_); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 7173c51c95e04ad3095f01bb24923a7a3341c517..d7c99165f0c9d3b1ae11a3b4753a61e8118f7b52 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -61,6 +61,9 @@ class Executor { static std::unique_ptr Prepare( const ProgramDesc& program, int block_id); + static std::vector> Prepare( + const ProgramDesc& program, const std::vector& block_ids); + void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope = true, bool create_vars = true); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 6cc468e449e3b6dae8110b00a1cc4dff2f5df7e7..91a1f226cd0c96f675bdd59dca809c43b0cedd4f 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -45,20 +45,23 @@ static void CreateTensorFromMessageType(framework::Variable *var, } } -static void ParallelExecuteBlocks(const std::vector ¶llel_blkids, - framework::Executor *executor, - framework::ProgramDesc *program, - framework::Scope *scope) { +static void ParallelExecuteBlocks( + const std::vector ¶llel_blkids, framework::Executor *executor, + const std::vector> + &prepared, + framework::ProgramDesc *program, framework::Scope *scope) { std::vector> fs; for (size_t idx : parallel_blkids) { - fs.push_back(framework::Async([&executor, &program, &scope, idx]() { - int run_block = idx; // thread local - try { - executor->Run(*program, scope, run_block, false, false); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - })); + fs.push_back( + framework::Async([&executor, &prepared, &program, &scope, idx]() { + int run_block = idx; // thread local + try { + executor->RunPreparedContext(prepared[run_block].get(), scope, + false, false); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + })); } for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); } @@ -101,6 +104,13 @@ class ListenAndServOp : public framework::OperatorBase { "server program should have at least 2 blocks"); framework::Executor executor(dev_place); + std::vector block_list; + for (size_t blkid = 1; blkid < num_blocks; ++blkid) + block_list.push_back(blkid); + auto prepared = executor.Prepare(*program, block_list); + prepared.insert( + prepared.begin(), + std::shared_ptr(nullptr)); // TODO(qiao) set proper fields for table lookup and update rpc_service_->SetExecutor(&executor); @@ -160,14 +170,15 @@ class ListenAndServOp : public framework::OperatorBase { for (size_t blkid = 2; blkid < num_blocks; ++blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) { for (size_t idx : parallel_blkids) VLOG(3) << idx; - ParallelExecuteBlocks(parallel_blkids, &executor, program, + ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program, &recv_scope); parallel_blkids.clear(); last_parent_blkid = program->Block(blkid).Parent(); } parallel_blkids.push_back(blkid); } - ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope); + ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program, + &recv_scope); VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; @@ -181,7 +192,8 @@ class ListenAndServOp : public framework::OperatorBase { var->GetMutable()->mutable_rows()->clear(); } rpc_service_->SetCond(1); - // FIXME(typhoonzero): use another condition to sync wait clients get. + // NOTE: does not consider barrier request retry in here, we may use + // global barrier id to resolve this. rpc_service_->WaitClientGet(fan_in); sparse_vars.clear(); } // while(true)