From 1a43828780943569b558043bac4c0170e5d962a1 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 18 Apr 2018 23:42:52 +0800 Subject: [PATCH] implement main logic --- .../operators/async_listen_and_serv_op.cc | 109 +++++++----------- 1 file changed, 42 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/operators/async_listen_and_serv_op.cc b/paddle/fluid/operators/async_listen_and_serv_op.cc index 4d66f9853a9..ec0ddedf3d1 100644 --- a/paddle/fluid/operators/async_listen_and_serv_op.cc +++ b/paddle/fluid/operators/async_listen_and_serv_op.cc @@ -19,6 +19,8 @@ limitations under the License. */ #include "paddle/fluid/operators/async_listen_and_serv_op.h" +#include "paddle/utils/StringUtil.h" + namespace paddle { namespace operators { @@ -27,46 +29,18 @@ void RunServer(std::shared_ptr service) { VLOG(4) << "RunServer thread end"; } -static void CreateTensorFromMessageType(framework::Variable *var, - sendrecv::VarType var_type) { - if (var_type == sendrecv::VarType::LOD_TENSOR) { - var->GetMutable(); - } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { - var->GetMutable(); - } else { - PADDLE_THROW( - "VariableMessage type %d is not in " - "[LoDTensor, SelectedRows]", - var_type); - } -} - -static void ParallelExecuteBlocks( - const std::vector ¶llel_blkids, framework::Executor *executor, - const std::vector> - &prepared, - framework::ProgramDesc *program, framework::Scope *scope) { - std::vector> fs; - for (size_t idx : parallel_blkids) { - fs.push_back( - framework::Async([&executor, &prepared, &program, &scope, idx]() { - int run_block = idx; // thread local - try { - executor->RunPreparedContext(prepared[run_block].get(), scope, - false, false); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - })); - } - for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); +static void AsyncExecuteBlock(framework::Executor *executor, + framework::ExecutorPrepareContext *prepared, + framework::Scope *scope) { + framework::Async([&executor, &prepared, &scope]() { + try { + executor->RunPreparedContext(prepared, scope, false, false); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + }); } -static void AsyncExecuteBlock( - size_t block_id, framework::Executor *executor, - std::shared_ptr ctx, - framework::ProgramDesc *program, framework::Scope *scope) {} - AsyncListenAndServOp::AsyncListenAndServOp( const std::string &type, const framework::VariableNameMap &inputs, const framework::VariableNameMap &outputs, @@ -93,6 +67,21 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_.reset(new detail::SyncGRPCServer(endpoint)); } + // grad name to block id + std::unordered_map grad_to_id; + std::unordered_map id_to_grad; + + auto grad_map_str = Attr>("grad_map"); + for (auto &grad_and_id : grad_map_str) { + std::vector pieces; + paddle::str::split(grad_and_id, ' ', &pieces); + PADDLE_ENFORCE_EQ(pieces.size(), 2); + PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0); + int block_id = std::stoi(pieces[1]); + grad_to_id[pieces[0]] = block_id; + id_to_grad[block_id] = pieces[0]; + } + auto *optimize_block = Attr(kOptimizeBlock); auto *prefetch_block = Attr(kPrefetchBlock); auto *program = optimize_block->Program(); @@ -108,10 +97,13 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, } } auto optimize_prepared = executor.Prepare(*program, block_list); - // Insert placeholder for block0 which holds current op itself. - optimize_prepared.insert( - optimize_prepared.begin(), - std::shared_ptr(nullptr)); + + std::unordered_map> + grad_to_prepared; + for (size_t i = 0; i < block_list.size(); ++i) { + grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i]; + } rpc_service_->SetScope(&recv_scope); rpc_service_->SetDevCtx(&dev_ctx); @@ -122,6 +114,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get()); prefetch_prepared.release(); rpc_service_->SetProgram(program); + // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); VLOG(3) << "wait server thread to become ready..."; @@ -133,9 +126,6 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, port_file.close(); bool exit_flag = false; - // Record received sparse variables, so that - // we could reset those after execute optimize program - std::vector sparse_vars; while (!exit_flag) { const detail::ReceivedMessage v = rpc_service_->Get(); auto recv_var_name = v.first; @@ -150,36 +140,17 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, LOG(ERROR) << "Can not find server side var: " << recv_var_name; PADDLE_THROW("Can not find server side var"); } + AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(), + &recv_scope); if (var->IsType()) { - sparse_vars.push_back(var); + var->GetMutable()->mutable_rows()->clear(); } - AsyncExecuteBlock(); } if (exit_flag) { rpc_service_->ShutDown(); break; } - - // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads - // and this will still work. - - // The optimize blocks which have the same parent ID would run parallel - // TODO(Yancey1989): need to use ParallelExecutor for future - int32_t last_parent_blkid = program->Block(1).Parent(); - - VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; - - // Reset the received sparse variables, the sum operator would not - // sum the input sparse variables which rows is empty at the next - // mini-batch. - // TODO(Yancey1989): move the reset action into an operator, we couldn't - // have any hide logic in the operator. - for (auto &var : sparse_vars) { - var->GetMutable()->mutable_rows()->clear(); - } - // FIXME(typhoonzero): use another condition to sync wait clients get. - sparse_vars.clear(); } // while(true) } @@ -199,6 +170,10 @@ from send_op and send back variables to recv_op. "IP address to listen on.") .SetDefault("127.0.0.1:6164") .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); + AddAttr>( + "grad_map(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])", + "a map from grad name to it's optimize block id") + .SetDefault({}); AddAttr(kOptimizeBlock, "BlockID to run on server side."); AddAttr(kPrefetchBlock, -- GitLab