diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 96d9b49c874ad7b82c35863e97b2d7d45ab9adfd..16a118090ba9cfd50b4b03484983f9fc73cf7973 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -279,7 +279,7 @@ std::unique_ptr Executor::Prepare( return std::unique_ptr(ctx); } -std::vector> Prepare( +std::vector> Executor::Prepare( const ProgramDesc& program, const std::vector& block_ids) { std::vector> result; for (auto& bid : block_ids) { diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 6094c066f94512dd282efe581d24d4771ee632c6..d4b0fa3aa18d6147030f3a92928f0de5369daf79 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -54,20 +54,24 @@ static void CreateTensorFromMessageType(framework::Variable *var, } } -static void ParallelExecuteBlocks(const std::vector ¶llel_blkids, - framework::Executor *executor, - framework::ProgramDesc *program, - framework::Scope *scope) { +static void ParallelExecuteBlocks( + const std::vector ¶llel_blkids, framework::Executor *executor, + const std::vector> + &prepared, + framework::ProgramDesc *program, framework::Scope *scope) { std::vector> fs; for (size_t idx : parallel_blkids) { - fs.push_back(framework::Async([&executor, &program, &scope, idx]() { - int run_block = idx; // thread local - try { - executor->Run(*program, scope, run_block, false, false); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - })); + fs.push_back( + framework::Async([&executor, &prepared, &program, &scope, idx]() { + int run_block = idx; // thread local + try { + // executor->Run(*program, scope, run_block, false, false); + executor->RunPreparedContext(prepared[run_block].get(), scope, + false, false); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + })); } for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); } @@ -105,15 +109,18 @@ class ListenAndServOp : public framework::OperatorBase { auto *block = Attr(kOptimizeBlock); auto *program = block->Program(); - int num_blocks = program->Size(); + size_t num_blocks = program->Size(); PADDLE_ENFORCE_GE(num_blocks, 2, "server program should have at least 2 blocks"); framework::Executor executor(dev_place); std::vector block_list; - for (int blkid = 1; blkid < num_blocks; ++blkid) + for (size_t blkid = 1; blkid < num_blocks; ++blkid) block_list.push_back(blkid); auto prepared = executor.Prepare(*program, block_list); + prepared.insert( + prepared.begin(), + std::shared_ptr(nullptr)); // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; @@ -161,21 +168,22 @@ class ListenAndServOp : public framework::OperatorBase { // The optimize blocks which have the same parent ID would run parallel // TODO(Yancey1989): need to use ParallelExecutor for future - size_t last_parent_blkid = program->Block(1).Parent(); + int32_t last_parent_blkid = program->Block(1).Parent(); std::vector parallel_blkids; parallel_blkids.push_back(1); double ts = detail::GetTimestamp(); for (size_t blkid = 2; blkid < num_blocks; ++blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) { for (size_t idx : parallel_blkids) VLOG(3) << idx; - ParallelExecuteBlocks(parallel_blkids, &executor, program, + ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program, &recv_scope); parallel_blkids.clear(); last_parent_blkid = program->Block(blkid).Parent(); } parallel_blkids.push_back(blkid); } - ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope); + ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program, + &recv_scope); VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;