未验证 提交 991cedb4 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #11702 from Yancey1989/fix_async_update_failed

Fix async update failed
...@@ -163,7 +163,8 @@ void ListenAndServOp::RunSyncLoop( ...@@ -163,7 +163,8 @@ void ListenAndServOp::RunSyncLoop(
} }
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program) const { framework::ProgramDesc *program,
framework::Scope *recv_scope) const {
// grad name to block id // grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id; std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad; std::unordered_map<int32_t, std::string> id_to_grad;
...@@ -190,6 +191,10 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -190,6 +191,10 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
block_list.push_back(blkid); block_list.push_back(blkid);
} }
auto optimize_prepared = executor->Prepare(*program, block_list); auto optimize_prepared = executor->Prepare(*program, block_list);
// execute global block if needed
if (block_list[0] == 1 && id_to_grad.count(1) == 0) {
executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
}
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>> std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx; grad_to_prepared_ctx;
...@@ -317,7 +322,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -317,7 +322,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list); RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
} else { } else {
RunAsyncLoop(&executor, program); RunAsyncLoop(&executor, program, &recv_scope);
} }
} }
......
...@@ -50,7 +50,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -50,7 +50,8 @@ class ListenAndServOp : public framework::OperatorBase {
const std::vector<int>& prefetch_block_id_list) const; const std::vector<int>& prefetch_block_id_list) const;
void RunAsyncLoop(framework::Executor* executor, void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program) const; framework::ProgramDesc* program,
framework::Scope* recv_scope) const;
void SavePort() const; void SavePort() const;
......
...@@ -1299,16 +1299,6 @@ class DistributeTranspiler(object): ...@@ -1299,16 +1299,6 @@ class DistributeTranspiler(object):
ufind.union(op1, op2) ufind.union(op1, op2)
return ufind return ufind
def _is_opt_role_op(self, op):
# NOTE: depend on oprole to find out whether this op is for
# optimize
op_maker = core.op_proto_and_checker_maker
optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
if op_maker.kOpRoleAttrName() in op.attrs and \
int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role):
return True
return False
def _is_optimizer_op(self, op): def _is_optimizer_op(self, op):
if "Param" in op.input_names and \ if "Param" in op.input_names and \
"LearningRate" in op.input_names: "LearningRate" in op.input_names:
...@@ -1399,7 +1389,10 @@ class DistributeTranspiler(object): ...@@ -1399,7 +1389,10 @@ class DistributeTranspiler(object):
params_grads = [] params_grads = []
origin_var_dict = self.origin_program.global_block().vars origin_var_dict = self.origin_program.global_block().vars
for op in block.ops: for op in block.ops:
if self._is_opt_role_op(op): # NOTE(Yancey1989): we can not use op role to distinguish an optimizer op
# or not, because all ops in optimizer sub-graph would
# sign the optimizer op role
if self._is_optimizer_op(op):
opt_ops.append(op) opt_ops.append(op)
# HACK(wuyi): if we find grad vars from input of optimize # HACK(wuyi): if we find grad vars from input of optimize
# ops, we may get the output of clip op. Use syntax "@GRAD" # ops, we may get the output of clip op. Use syntax "@GRAD"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册