提交 b8f4c859 编写于 作者: T typhoonzero

pserver runs in parallel

上级 128adf53
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h" #include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h" #include "paddle/fluid/operators/detail/simple_block_queue.h"
...@@ -89,6 +90,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -89,6 +90,7 @@ class ListenAndServOp : public framework::OperatorBase {
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program(); auto *program = block->Program();
int num_blocks = program->Size();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
...@@ -132,12 +134,33 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -132,12 +134,33 @@ class ListenAndServOp : public framework::OperatorBase {
rpc_service_->ShutDown(); rpc_service_->ShutDown();
break; break;
} }
try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ // put optimize blocks in the thread pool to start run, the last block
false /*create_local_scope*/, false /*create_vars*/); // should be global ops.
} catch (std::exception &e) { // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
LOG(ERROR) << "run sub program error " << e.what(); // and this will still work.
std::vector<std::future<void>> fs;
for (int blkid = 0; blkid < num_blocks - 1; ++blkid) {
fs.push_back(framework::Async([&]() {
try {
executor.Run(*program, &recv_scope, blkid,
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}));
} }
for (int blkid = 0; blkid < num_blocks - 1; ++blkid) fs[blkid].wait();
// Run global block at final step
if (num_blocks > 2) {
try {
executor.Run(*program, &recv_scope, num_blocks - 1,
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}
// Reset the received sparse variables, the sum operator would not // Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next // sum the input sparse variables which rows is empty at the next
// mini-batch. // mini-batch.
......
...@@ -307,15 +307,58 @@ class DistributeTranspiler: ...@@ -307,15 +307,58 @@ class DistributeTranspiler:
# Iterate through the ops, and if an op and the optimize ops # Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then # which located on current pserver are in one set, then
# append it into the sub program. # append it into the sub program.
for _, op in enumerate(self.optimize_ops):
for _, opt_op in enumerate(opt_op_on_pserver): # We try to put optimization program run parallelly, assume
if ufind.is_connected(op, opt_op): # optimization program always looks like:
if self._is_opt_op(op): #
self._append_pserver_ops(optimize_block, op, endpoint, # prevop -> prevop -> opt op -> following op -> following op; ->
default_main_program()) # prevop -> prevop -> opt op -> following op -> following op; ->
else: # global op -> global op
self._append_pserver_non_opt_ops(optimize_block, op) #
break # we put operators that can run parallelly to many program blocks.
# in above example, we seperate ops by the ";". Global ops must run
# after all the optimize ops finished.
global_ops = []
# HACK: optimization global ops only used to scale beta1 and beta2
# replace it with dependency engine.
for op in self.optimize_ops:
if op.type == "scale":
for in_name in op.input_arg_names:
if in_name.startswith("beta1_pow_acc") or\
in_name.startswith("beta2_pow_acc"):
global_ops.append(op)
print("##### global ops ", global_ops)
def __append_optimize_op__(op, block):
if self._is_opt_op(op):
self._append_pserver_ops(block, op, endpoint,
default_main_program())
else:
self._append_pserver_non_opt_ops(block, op)
# append op to the current block
per_opt_block = optimize_block
for _, opt_op in enumerate(opt_op_on_pserver):
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and \
op not in global_ops:
__append_optimize_op__(op, per_opt_block)
per_opt_block = pserver_program.create_block(0)
# append global ops
for glb_op in global_ops:
__append_optimize_op__(glb_op, per_opt_block)
# NOT USED: single block version:
#
# for _, op in enumerate(self.optimize_ops):
# for _, opt_op in enumerate(opt_op_on_pserver):
# if ufind.is_connected(op, opt_op):
# __append_optimize_op__(glb_op, optimize_block)
# break
# step5 append the listen_and_serv op # step5 append the listen_and_serv op
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
type="listen_and_serv", type="listen_and_serv",
...@@ -660,10 +703,22 @@ class DistributeTranspiler: ...@@ -660,10 +703,22 @@ class DistributeTranspiler:
# If one op's input is another op's output or # If one op's input is another op's output or
# one op's output is another op's input, we say # one op's output is another op's input, we say
# the two operator is connected. # the two operator is connected.
op1_input_names = op1.desc.input_arg_names() def _append_inname_remove_beta(varname_list):
op_input_names = []
for in_name in varname_list:
# HACK: remove beta1 and beta2 to avoid let all
# ops connected.
if in_name.startswith("beta2_pow_acc") or \
in_name.startswith("beta1_pow_acc"):
continue
else:
op_input_names.append(in_name)
return op_input_names
op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
op1_output_names = op1.desc.output_arg_names() op1_output_names = op1.desc.output_arg_names()
op2_input_names = op2.desc.input_arg_names() op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
op2_output_names = op2.desc.output_arg_names() op2_output_names = op2.desc.output_arg_names()
if set(op1_output_names) & set(op2_input_names) or \ if set(op1_output_names) & set(op2_input_names) or \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册