From 1ab4fcb5e705c5e03c0fea3fbca3c00b5e67ff85 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 26 Mar 2018 19:55:46 +0800 Subject: [PATCH] prepare pserver executor --- paddle/fluid/framework/executor.cc | 15 +++++++++++++ paddle/fluid/framework/executor.h | 3 +++ paddle/fluid/operators/listen_and_serv_op.cc | 22 ++++++++++++-------- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 0b171e1dcfa..5279eb42cd5 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -279,6 +279,21 @@ std::unique_ptr Executor::Prepare( return std::unique_ptr(ctx); } +std::vector> Prepare( + const ProgramDesc& program, const std::vector& block_ids) { + std::vector> result; + for (auto& bid : block_ids) { + auto* ctx = new ExecutorPrepareContext(program, bid); + PADDLE_ENFORCE_LT(static_cast(bid), program.Size()); + auto& block = program.Block(bid); + for (auto& op_desc : block.AllOps()) { + ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); + } + result.push_back(std::shared_ptr(ctx)); + } + return result; +} + void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope, bool create_vars) { auto& block = ctx->prog_.Block(ctx->block_id_); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index d8dd82469af..756f3c7e5ad 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -60,6 +60,9 @@ class Executor { static std::unique_ptr Prepare( const ProgramDesc& program, int block_id); + static std::vector> Prepare( + const ProgramDesc& program, const std::vector& block_ids); + void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope = true, bool create_vars = true); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 08b83375dd5..6bae993f612 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -93,6 +93,10 @@ class ListenAndServOp : public framework::OperatorBase { "server program should have at least 2 blocks"); framework::Executor executor(dev_place); + std::vector block_list; + for (int blkid = 1; blkid < num_blocks; ++blkid) + block_list.push_back(blkid); + auto prepared = executor.Prepare(*program, block_list); // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; @@ -143,11 +147,12 @@ class ListenAndServOp : public framework::OperatorBase { std::vector> fs; // block0 contains only listen_and_serv op, start run from block1. for (int blkid = 1; blkid < num_blocks - 1; ++blkid) { - fs.push_back( - framework::Async([&executor, &program, &recv_scope, blkid]() { + fs.push_back(framework::Async( + [&executor, &program, &recv_scope, &prepared, blkid]() { int run_block = blkid; // thread local try { - executor.Run(*program, &recv_scope, run_block, false, false); + executor.RunPreparedContext(prepared[run_block].get(), + &recv_scope, false, false); } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } @@ -157,7 +162,9 @@ class ListenAndServOp : public framework::OperatorBase { // Run global block at final step, or block1 if there are only 2 blocks if (num_blocks >= 2) { try { - executor.Run(*program, &recv_scope, num_blocks - 1, false, false); + // executor.Run(program, &recv_scope, num_blocks - 1, false, false); + executor.RunPreparedContext(prepared[num_blocks - 1].get(), + &recv_scope, false, false); } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } @@ -172,14 +179,11 @@ class ListenAndServOp : public framework::OperatorBase { var->GetMutable()->mutable_rows()->clear(); } rpc_service_->SetCond(1); - // FIXME(typhoonzero): use another condition to sync wait clients get. + // NOTE: does not consider barrier request retry in here, we may use + // global barrier id to resolve this. rpc_service_->WaitClientGet(fan_in); sparse_vars.clear(); } // while(true) - - // for (int i = 0; i < num_blocks; ++i) { - // delete blk_ctx_list[i]; - // } } protected: -- GitLab