未验证 提交 020d13c1 编写于 作者: Qiao Longfei 提交者: GitHub

fix dist table send hang problem (#13259)

* fix dist table send hang problem

* revert sync_mode config

* fix async send table
上级 2c31ea92
......@@ -247,7 +247,7 @@ class DistributeTranspiler(object):
np.random.seed(self.origin_program.random_seed)
np.random.shuffle(grad_var_mapping_items)
grad_name_to_send_dummy_out = dict()
self.grad_name_to_send_dummy_out = dict()
for grad_varname, splited_vars in grad_var_mapping_items:
eplist = ps_dispatcher.dispatch(splited_vars)
......@@ -271,7 +271,7 @@ class DistributeTranspiler(object):
dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
grad_name_to_send_dummy_out[grad_varname] = dummy_output
self.grad_name_to_send_dummy_out[grad_varname] = dummy_output
# get send op_role_var, if not split, the grad should have .trainer suffix
# if split, grad should be the original grad var name (split_by_ref and send
......@@ -297,7 +297,12 @@ class DistributeTranspiler(object):
if self.sync_mode:
send_barrier_out = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
input_deps = grad_name_to_send_dummy_out.values()
if self.has_distributed_lookup_table:
self.grad_name_to_send_dummy_out[
self.table_name] = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
input_deps = self.grad_name_to_send_dummy_out.values()
program.global_block().append_op(
type="send_barrier",
inputs={"X": list(input_deps)},
......@@ -329,7 +334,7 @@ class DistributeTranspiler(object):
recv_dep_in = send_barrier_out
else:
# connect deps to send op in async mode
recv_dep_in = grad_name_to_send_dummy_out[
recv_dep_in = self.grad_name_to_send_dummy_out[
self.param_name_to_grad_name[param_varname]]
all_recv_outputs.extend(splited_var)
# get recv op_role_var, if not split, the grad should have .trainer suffix
......@@ -1046,9 +1051,13 @@ class DistributeTranspiler(object):
index=op_index + 2,
type="send",
inputs={'X': self.trainer_side_table_grad_list},
outputs={'Out': []},
outputs={
'Out':
[self.grad_name_to_send_dummy_out[self.table_name]]
if self.sync_mode else []
},
attrs={
"sync_mode": True,
"sync_mode": False,
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册