From b8f4c8599e49cbe997d902d615705fad5a67c20f Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Fri, 16 Mar 2018 14:31:38 +0800
Subject: [PATCH] pserver runs in parallel

---
 paddle/fluid/operators/listen_and_serv_op.cc | 33 +++++++--
 python/paddle/fluid/distribute_transpiler.py | 77 +++++++++++++++++---
 2 files changed, 94 insertions(+), 16 deletions(-)

diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index 42533007884..6fb17470a3b 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/proto_desc.h"
+#include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/operators/detail/grpc_server.h"
 #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
 #include "paddle/fluid/operators/detail/simple_block_queue.h"
@@ -89,6 +90,7 @@ class ListenAndServOp : public framework::OperatorBase {
 
     auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
     auto *program = block->Program();
+    int num_blocks = program->Size();
     framework::Executor executor(dev_place);
 
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
@@ -132,12 +134,33 @@ class ListenAndServOp : public framework::OperatorBase {
         rpc_service_->ShutDown();
         break;
       }
-      try {
-        executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
-                     false /*create_local_scope*/, false /*create_vars*/);
-      } catch (std::exception &e) {
-        LOG(ERROR) << "run sub program error " << e.what();
+
+      // Put the optimize blocks into the thread pool to run; the last block
+      // should contain only the global ops.
+      // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
+      // and this will still work.
+      std::vector<std::future<void>> fs;
+      for (int blkid = 0; blkid < num_blocks - 1; ++blkid) {
+        fs.push_back(framework::Async([&, blkid]() {
+          try {
+            executor.Run(*program, &recv_scope, blkid,
+                         false /*create_local_scope*/, false /*create_vars*/);
+          } catch (std::exception &e) {
+            LOG(ERROR) << "run sub program error " << e.what();
+          }
+        }));
       }
+      for (int blkid = 0; blkid < num_blocks - 1; ++blkid) fs[blkid].wait();
+      // Run the global block at the final step.
+      if (num_blocks > 2) {
+        try {
+          executor.Run(*program, &recv_scope, num_blocks - 1,
+                       false /*create_local_scope*/, false /*create_vars*/);
+        } catch (std::exception &e) {
+          LOG(ERROR) << "run sub program error " << e.what();
+        }
+      }
+
       // Reset the received sparse variables, the sum operator would not
       // sum the input sparse variables which rows is empty at the next
       // mini-batch.
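
For reference, the scheduling pattern the listen_and_serv change implements can be sketched in a few lines of Python. This is only an illustration, not Paddle API: run_block is a hypothetical stand-in for executor.Run on one block, and concurrent.futures plays the role of framework::Async. Every block except the last is submitted to a thread pool; the last block, which holds the global ops, runs only after all of the parallel blocks have finished.

    # Sketch only: models the scheduling done above, not the Paddle executor.
    from concurrent.futures import ThreadPoolExecutor


    def run_block(block_id):
        # hypothetical stand-in for executor.Run(program, recv_scope, block_id, ...)
        print("running block", block_id)


    def run_optimize_blocks(num_blocks):
        # blocks [0, num_blocks - 2] hold per-optimizer ops and may run in parallel
        with ThreadPoolExecutor() as pool:
            futures = [pool.submit(run_block, blkid)
                       for blkid in range(num_blocks - 1)]
            for f in futures:
                f.result()  # wait, and re-raise any exception from the worker
        # the last block holds the global ops and must run only after every
        # parallel block has finished
        if num_blocks > 2:
            run_block(num_blocks - 1)


    run_optimize_blocks(4)
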
diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 3d3a6c116ee..cdde3512960 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -307,15 +307,58 @@ class DistributeTranspiler:
         # Iterate through the ops, and if an op and the optimize ops
         # which located on current pserver are in one set, then
         # append it into the sub program.
-        for _, op in enumerate(self.optimize_ops):
-            for _, opt_op in enumerate(opt_op_on_pserver):
-                if ufind.is_connected(op, opt_op):
-                    if self._is_opt_op(op):
-                        self._append_pserver_ops(optimize_block, op, endpoint,
-                                                 default_main_program())
-                    else:
-                        self._append_pserver_non_opt_ops(optimize_block, op)
-                    break
+
+        # We try to run the optimize program in parallel. Assume the
+        # optimize program always looks like:
+        #
+        # prevop -> prevop -> opt op -> following op -> following op; ->
+        # prevop -> prevop -> opt op -> following op -> following op; ->
+        # global op -> global op
+        #
+        # We put the operators that can run in parallel into separate program
+        # blocks; in the example above the blocks are separated by ";".
+        # Global ops must run after all the optimize ops have finished.
+
+        global_ops = []
+        # HACK: the only global optimize ops are the ones scaling beta1 and
+        # beta2; replace this hack with a dependency engine.
+        for op in self.optimize_ops:
+            if op.type == "scale":
+                for in_name in op.input_arg_names:
+                    if in_name.startswith("beta1_pow_acc") or \
+                            in_name.startswith("beta2_pow_acc"):
+                        global_ops.append(op)
+        print("##### global ops ", global_ops)
+
+        def __append_optimize_op__(op, block):
+            if self._is_opt_op(op):
+                self._append_pserver_ops(block, op, endpoint,
+                                         default_main_program())
+            else:
+                self._append_pserver_non_opt_ops(block, op)
+
+        # Append each op to the block of the optimizer op it is connected to.
+        per_opt_block = optimize_block
+        for _, opt_op in enumerate(opt_op_on_pserver):
+            for _, op in enumerate(self.optimize_ops):
+                # an optimizer op is connected to itself
+                if ufind.is_connected(op, opt_op) and \
+                        op not in global_ops:
+                    __append_optimize_op__(op, per_opt_block)
+            per_opt_block = pserver_program.create_block(0)
+
+        # append the global ops to the last block
+        for glb_op in global_ops:
+            __append_optimize_op__(glb_op, per_opt_block)
+
+        # NOT USED: single block version:
+        #
+        # for _, op in enumerate(self.optimize_ops):
+        #     for _, opt_op in enumerate(opt_op_on_pserver):
+        #         if ufind.is_connected(op, opt_op):
+        #             __append_optimize_op__(op, optimize_block)
+        #             break
+
         # step5 append the listen_and_serv op
         pserver_program.global_block().append_op(
             type="listen_and_serv",
@@ -660,10 +703,22 @@ class DistributeTranspiler:
         # If one op's input is another op's output or
         # one op's output is another op's input, we say
         # the two operator is connected.
-        op1_input_names = op1.desc.input_arg_names()
+        def _append_inname_remove_beta(varname_list):
+            op_input_names = []
+            for in_name in varname_list:
+                # HACK: skip the beta1/beta2 accumulators so they do not
+                # make all the ops connected to each other.
+                if in_name.startswith("beta2_pow_acc") or \
+                        in_name.startswith("beta1_pow_acc"):
+                    continue
+                else:
+                    op_input_names.append(in_name)
+            return op_input_names
+
+        op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
         op1_output_names = op1.desc.output_arg_names()
 
-        op2_input_names = op2.desc.input_arg_names()
+        op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
         op2_output_names = op2.desc.output_arg_names()
 
         if set(op1_output_names) & set(op2_input_names) or \
--
GitLab
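
For reference, the block-splitting done by the transpiler changes above can be sketched as follows. This is an illustration under simplified assumptions, not the transpiler's real API: Op, UnionFind, filtered_inputs and split_blocks are hypothetical names. Ops are joined with a union-find over shared input/output variable names, ignoring the beta1/beta2 power accumulators (which would otherwise connect every Adam op); each optimizer op on this pserver then gets its own block, and the beta-scaling "global" ops form one trailing block that must run last.

    # Sketch only: models the grouping idea, not the transpiler's data structures.
    from collections import namedtuple

    Op = namedtuple("Op", ["type", "inputs", "outputs"])

    BETA_PREFIXES = ("beta1_pow_acc", "beta2_pow_acc")


    class UnionFind:
        def __init__(self, n):
            self.parent = list(range(n))

        def find(self, x):
            while self.parent[x] != x:
                self.parent[x] = self.parent[self.parent[x]]  # path halving
                x = self.parent[x]
            return x

        def union(self, a, b):
            self.parent[self.find(a)] = self.find(b)

        def connected(self, a, b):
            return self.find(a) == self.find(b)


    def filtered_inputs(op):
        # drop the beta accumulators so they do not link every Adam op together
        return {n for n in op.inputs if not n.startswith(BETA_PREFIXES)}


    def split_blocks(optimize_ops, opt_ops_on_pserver):
        # join ops whose (filtered) inputs and outputs share a variable name
        uf = UnionFind(len(optimize_ops))
        for i, a in enumerate(optimize_ops):
            for j, b in enumerate(optimize_ops):
                if filtered_inputs(a) & set(b.outputs) or \
                        filtered_inputs(b) & set(a.outputs):
                    uf.union(i, j)

        # the only global ops are the scale ops that update the beta accumulators
        global_ops = [op for op in optimize_ops
                      if op.type == "scale" and
                      any(n.startswith(BETA_PREFIXES) for n in op.inputs)]

        blocks = []
        for opt_op in opt_ops_on_pserver:
            j = optimize_ops.index(opt_op)
            blocks.append([op for i, op in enumerate(optimize_ops)
                           if uf.connected(i, j) and op not in global_ops])
        blocks.append(global_ops)  # this block runs after all the others
        return blocks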