Commit 374f1ca3 authored by Yancey, committed by gongweibao

Fix dist error with lr decay layer (#9489)

Parent commit: f0af1398
@@ -54,6 +54,24 @@ static void CreateTensorFromMessageType(framework::Variable *var,
   }
 }
 
+static void ParallelExecuteBlocks(const std::vector<size_t> &parallel_blkids,
+                                  framework::Executor *executor,
+                                  framework::ProgramDesc *program,
+                                  framework::Scope *scope) {
+  std::vector<std::future<void>> fs;
+  for (size_t idx : parallel_blkids) {
+    fs.push_back(framework::Async([&executor, &program, &scope, idx]() {
+      int run_block = idx;  // thread local
+      try {
+        executor->Run(*program, scope, run_block, false, false);
+      } catch (std::exception &e) {
+        LOG(ERROR) << "run sub program error " << e.what();
+      }
+    }));
+  }
+  for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
+}
+
 class ListenAndServOp : public framework::OperatorBase {
  public:
   ListenAndServOp(const std::string &type,
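
The added helper above fans a group of block IDs out to the framework's thread pool (`framework::Async`) and then joins on every future before returning. As a rough editorial Python sketch of the same fan-out/join pattern (not part of the commit; `run_block` is a hypothetical stand-in for `executor->Run`):

```python
from concurrent.futures import ThreadPoolExecutor, wait


def run_block(program, scope, block_id):
    # Stand-in for executor->Run(program, scope, block_id, ...); errors are
    # logged rather than propagated, mirroring the C++ try/catch.
    try:
        print("running block", block_id)
    except Exception as e:
        print("run sub program error", e)


def parallel_execute_blocks(parallel_blkids, program=None, scope=None):
    # Launch every block in the group concurrently, then wait for all of
    # them, like the vector<future<void>> + wait() loop in the C++ helper.
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(run_block, program, scope, blkid)
                   for blkid in parallel_blkids]
        wait(futures)


parallel_execute_blocks([1, 2, 3])
```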
@@ -135,34 +153,27 @@ class ListenAndServOp : public framework::OperatorBase {
         break;
       }
 
-      // put optimize blocks in the thread pool to start run, the last block
-      // should be global ops.
       // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
       // and this will still work.
-      std::vector<std::future<void>> fs;
+      // The optimize blocks which have the same parent ID would run parallel
+      // TODO(Yancey1989): need to use ParallelExecutor for future
+      size_t last_parent_blkid = program->Block(1).Parent();
+      std::vector<size_t> parallel_blkids;
+      parallel_blkids.push_back(1);
       double ts = detail::GetTimestamp();
-      // block0 contains only listen_and_serv op, start run from block1.
-      for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
-        fs.push_back(
-            framework::Async([&executor, &program, &recv_scope, blkid]() {
-              int run_block = blkid;  // thread local
-              try {
-                executor.Run(*program, &recv_scope, run_block, false, false);
-              } catch (std::exception &e) {
-                LOG(ERROR) << "run sub program error " << e.what();
-              }
-            }));
-      }
-      for (int i = 0; i < num_blocks - 2; ++i) fs[i].wait();
-      // Run global block at final step, or block1 if there are only 2 blocks
-      if (num_blocks >= 2) {
-        try {
-          executor.Run(*program, &recv_scope, num_blocks - 1, false, false);
-        } catch (std::exception &e) {
-          LOG(ERROR) << "run sub program error " << e.what();
+      for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
+        if (program->Block(blkid).Parent() != last_parent_blkid) {
+          for (size_t idx : parallel_blkids) VLOG(3) << idx;
+          ParallelExecuteBlocks(parallel_blkids, &executor, program,
+                                &recv_scope);
+          parallel_blkids.clear();
+          last_parent_blkid = program->Block(blkid).Parent();
         }
+        parallel_blkids.push_back(blkid);
       }
+      ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
+
       VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
 
       // Reset the received sparse variables, the sum operator would not
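
The rewritten dispatch loop above accumulates consecutive blocks that share a parent ID, hands each group to `ParallelExecuteBlocks` whenever the parent changes, and flushes once more after the loop for the trailing group; blocks with the same parent are independent per-parameter optimize blocks, so they can run concurrently. A hedged sketch of just that grouping logic (block parents supplied as a plain Python list; not the operator's actual code):

```python
def group_blocks_by_parent(parents):
    """Group consecutive block IDs that share a parent ID.

    parents[blkid] is the parent block ID of block blkid; block 0 holds only
    the listen_and_serv op, so grouping starts from block 1, mirroring the
    loop in the operator above. Requires at least two blocks.
    """
    groups = []
    current, last_parent = [1], parents[1]
    for blkid in range(2, len(parents)):
        if parents[blkid] != last_parent:
            groups.append(current)      # flush blocks that can run together
            current, last_parent = [], parents[blkid]
        current.append(blkid)
    groups.append(current)              # final flush after the loop
    return groups


# e.g. block 1 is the lr-decay block (parent 0) and blocks 2-4 are optimize
# blocks whose parent is block 1, so they form one parallel group.
print(group_blocks_by_parent([None, 0, 1, 1, 1]))  # [[1], [2, 3, 4]]
```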
@@ -178,10 +189,6 @@ class ListenAndServOp : public framework::OperatorBase {
       rpc_service_->WaitClientGet(fan_in);
       sparse_vars.clear();
     }  // while(true)
-
-    // for (int i = 0; i < num_blocks; ++i) {
-    //   delete blk_ctx_list[i];
-    // }
   }
 
  protected:
@@ -338,15 +338,24 @@ class DistributeTranspiler:
             else:
                 self._append_pserver_non_opt_ops(block, op)
 
+        append_block = optimize_block
+        # append lr decay ops to the child block if exists
+        lr_ops = self._get_lr_ops()
+        if len(lr_ops) > 0:
+            for _, op in enumerate(lr_ops):
+                self._append_pserver_non_opt_ops(append_block, op)
+
+            append_block = pserver_program.create_block(append_block.idx)
+
         # append op to the current block
-        per_opt_block = optimize_block
+        per_opt_block = append_block
         for _, opt_op in enumerate(opt_op_on_pserver):
             for _, op in enumerate(self.optimize_ops):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and \
                     op not in global_ops:
                     __append_optimize_op__(op, per_opt_block)
-            per_opt_block = pserver_program.create_block(0)
+            per_opt_block = pserver_program.create_block(append_block.idx)
 
         # append global ops
         for glb_op in global_ops:
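
The transpiler change above puts the lr decay ops into their own block and then creates every per-optimizer block with that block's index as parent, so the server side can recognize the optimizer blocks as one parallel group. An illustrative (hypothetical) layout of the resulting pserver program, with one lr decay op and three parameters; block indices and op names are placeholders, not transpiler output:

```python
pserver_blocks = [
    {"idx": 0, "parent": -1, "ops": ["listen_and_serv"]},
    {"idx": 1, "parent": 0,  "ops": ["lr decay ops"]},     # append_block
    {"idx": 2, "parent": 1,  "ops": ["sgd for param_0"]},
    {"idx": 3, "parent": 1,  "ops": ["sgd for param_1"]},
    {"idx": 4, "parent": 1,  "ops": ["sgd for param_2"]},
]

# Blocks 2-4 share parent 1, so ListenAndServOp runs block 1 (lr decay)
# first and then blocks 2-4 in parallel.
for blk in pserver_blocks:
    print(blk["idx"], "parent:", blk["parent"], blk["ops"])
```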
@@ -786,3 +795,33 @@ class DistributeTranspiler:
             else:
                 iomap[key] = vars
         return iomap
+
+    def _get_lr_ops(self):
+        lr_ops = []
+        # find learning rate variables by optimize op
+        lr_vars = set()
+        for op in self.optimize_ops:
+            if self._is_opt_op(op):
+                lr_vars.add(op.input("LearningRate")[0])
+
+        find_ops = []
+        # find ops whose output is the lr var
+        block = self.program.global_block()
+        for op in block.ops:
+            if set(op.output_arg_names) & lr_vars:
+                find_ops.append(op)
+        # make a union-find struct over the ops in default_main_program
+        ufind = UnionFind(block.ops)
+        for op1 in block.ops:
+            for op2 in block.ops:
+                # NOTE: we need to skip all optimize ops, since they are connected
+                # with forward/backward ops and lr ops; we only need the lr ops.
+                if op1 != op2 and self._is_op_connected(op1, op2) and \
+                    not self._is_opt_op(op1) and not self._is_opt_op(op2):
+                    ufind.union(op1, op2)
+        # find all ops which are related to the lr var
+        for op1 in block.ops:
+            for op2 in find_ops:
+                if ufind.is_connected(op1, op2):
+                    lr_ops.append(op1)
+        return lr_ops
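
`_get_lr_ops` relies on the transpiler's `UnionFind` helper only through `union` and `is_connected`. For readers unfamiliar with it, here is a generic disjoint-set sketch exposing the same two operations; the transpiler's own `UnionFind` implementation may differ:

```python
class SimpleUnionFind(object):
    """Disjoint-set over arbitrary objects with union and is_connected."""

    def __init__(self, elements):
        # Every element starts in its own set; id() keys let us track
        # objects that are not hashable (such as op descs).
        self._parent = {id(e): id(e) for e in elements}

    def _find(self, x):
        # Find the root, then apply path compression on the way back.
        root = x
        while self._parent[root] != root:
            root = self._parent[root]
        while self._parent[x] != root:
            self._parent[x], x = root, self._parent[x]
        return root

    def union(self, a, b):
        ra, rb = self._find(id(a)), self._find(id(b))
        if ra != rb:
            self._parent[ra] = rb

    def is_connected(self, a, b):
        return self._find(id(a)) == self._find(id(b))


ops = ["lr_decay", "scale", "sgd"]
uf = SimpleUnionFind(ops)
uf.union(ops[0], ops[1])
print(uf.is_connected(ops[0], ops[1]))  # True
print(uf.is_connected(ops[0], ops[2]))  # False
```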