diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 4253300788462a3704076fc79241a864f2f130a0..a594de67e05acd28ffedc5407beecfaea1281444 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/proto_desc.h" +#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/simple_block_queue.h" @@ -89,6 +90,10 @@ class ListenAndServOp : public framework::OperatorBase { auto *block = Attr(kOptimizeBlock); auto *program = block->Program(); + int num_blocks = program->Size(); + PADDLE_ENFORCE_GE(num_blocks, 2, + "server program should have at least 2 blocks"); + framework::Executor executor(dev_place); // TODO(typhoonzero): change this to a while_op for every cluster-batch. @@ -132,12 +137,36 @@ class ListenAndServOp : public framework::OperatorBase { rpc_service_->ShutDown(); break; } - try { - executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ - false /*create_local_scope*/, false /*create_vars*/); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); + + // put optimize blocks in the thread pool to start run, the last block + // should be global ops. + // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads + // and this will still work. + std::vector> fs; + // block0 contains only listen_and_serv op, start run from block1. + for (int blkid = 1; blkid < num_blocks - 1; ++blkid) { + fs.push_back(framework::Async([&executor, &program, &recv_scope, + blkid]() { + int run_block = blkid; // thread local + try { + executor.Run(*program, &recv_scope, run_block, + false /*create_local_scope*/, false /*create_vars*/); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + })); + } + for (int i = 0; i < num_blocks - 2; ++i) fs[i].wait(); + // Run global block at final step, or block1 if there are only 2 blocks + if (num_blocks >= 2) { + try { + executor.Run(*program, &recv_scope, num_blocks - 1, + false /*create_local_scope*/, false /*create_vars*/); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } } + // Reset the received sparse variables, the sum operator would not // sum the input sparse variables which rows is empty at the next // mini-batch. diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 3d3a6c116eeb39fb7236d0e9707415cdd6b828bd..ad655ee96cee0744e7bedb17163faf7d8d1d8877 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -307,15 +307,57 @@ class DistributeTranspiler: # Iterate through the ops, and if an op and the optimize ops # which located on current pserver are in one set, then # append it into the sub program. - for _, op in enumerate(self.optimize_ops): - for _, opt_op in enumerate(opt_op_on_pserver): - if ufind.is_connected(op, opt_op): - if self._is_opt_op(op): - self._append_pserver_ops(optimize_block, op, endpoint, - default_main_program()) - else: - self._append_pserver_non_opt_ops(optimize_block, op) - break + + # We try to put optimization program run parallelly, assume + # optimization program always looks like: + # + # prevop -> prevop -> opt op -> following op -> following op; -> + # prevop -> prevop -> opt op -> following op -> following op; -> + # global op -> global op + # + # we put operators that can run parallelly to many program blocks. + # in above example, we seperate ops by the ";". Global ops must run + # after all the optimize ops finished. + + global_ops = [] + # HACK: optimization global ops only used to scale beta1 and beta2 + # replace it with dependency engine. + for op in self.optimize_ops: + if op.type == "scale": + for in_name in op.input_arg_names: + if in_name.startswith("beta1_pow_acc") or\ + in_name.startswith("beta2_pow_acc"): + global_ops.append(op) + + def __append_optimize_op__(op, block): + if self._is_opt_op(op): + self._append_pserver_ops(block, op, endpoint, + default_main_program()) + else: + self._append_pserver_non_opt_ops(block, op) + + # append op to the current block + per_opt_block = optimize_block + for _, opt_op in enumerate(opt_op_on_pserver): + for _, op in enumerate(self.optimize_ops): + # optimizer is connected to itself + if ufind.is_connected(op, opt_op) and \ + op not in global_ops: + __append_optimize_op__(op, per_opt_block) + per_opt_block = pserver_program.create_block(0) + + # append global ops + for glb_op in global_ops: + __append_optimize_op__(glb_op, per_opt_block) + + # NOT USED: single block version: + # + # for _, op in enumerate(self.optimize_ops): + # for _, opt_op in enumerate(opt_op_on_pserver): + # if ufind.is_connected(op, opt_op): + # __append_optimize_op__(glb_op, optimize_block) + # break + # step5 append the listen_and_serv op pserver_program.global_block().append_op( type="listen_and_serv", @@ -660,10 +702,22 @@ class DistributeTranspiler: # If one op's input is another op's output or # one op's output is another op's input, we say # the two operator is connected. - op1_input_names = op1.desc.input_arg_names() + def _append_inname_remove_beta(varname_list): + op_input_names = [] + for in_name in varname_list: + # HACK: remove beta1 and beta2 to avoid let all + # ops connected. + if in_name.startswith("beta2_pow_acc") or \ + in_name.startswith("beta1_pow_acc"): + continue + else: + op_input_names.append(in_name) + return op_input_names + + op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names()) op1_output_names = op1.desc.output_arg_names() - op2_input_names = op2.desc.input_arg_names() + op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names()) op2_output_names = op2.desc.output_arg_names() if set(op1_output_names) & set(op2_input_names) or \