Unverified commit 020d13c1, authored by Qiao Longfei, committed by GitHub

fix dist table send hang problem (#13259)

* fix dist table send hang problem

* revert sync_mode config

* fix async send table
Parent 2c31ea92
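
What the diff below does: the send op for the distributed lookup table's gradient used to run with "sync_mode": True and an empty output list, so it produced nothing that the trainer's control-dependency graph could order against. The commit promotes grad_name_to_send_dummy_out to an attribute (self.grad_name_to_send_dummy_out) so the table-send code path can reach it, creates a dummy control variable keyed by self.table_name in sync mode, lists that variable among send_barrier's inputs, and turns the table send into a non-blocking RPC whose dummy output feeds the barrier, so that the barrier, rather than the send itself, enforces completion.

A self-contained toy sketch of that dummy-control-variable pattern (illustration only, not PaddlePaddle code; the Op class and run scheduler are invented names): an op becomes runnable only once all of its declared input variables have been produced, so a barrier that lists every send's dummy output cannot run before every send, the table-gradient send included, has executed.

    from dataclasses import dataclass, field


    @dataclass
    class Op:
        name: str
        inputs: list = field(default_factory=list)   # variables this op waits on
        outputs: list = field(default_factory=list)  # variables this op produces


    def run(ops):
        """Run each op once all of its inputs exist; stuck ops mean a hang."""
        ready_vars, pending = set(), list(ops)
        while pending:
            runnable = [op for op in pending
                        if all(v in ready_vars for v in op.inputs)]
            if not runnable:
                raise RuntimeError("hang: ops waiting on unproduced vars: %s"
                                   % [op.name for op in pending])
            for op in runnable:
                print("run", op.name)
                ready_vars.update(op.outputs)
                pending.remove(op)


    # The barrier is listed first on purpose: execution order comes from the
    # dependency variables, not from list position. Every send, including the
    # table-gradient send (the one this fix adds), emits a dummy control
    # variable, and the barrier depends on all of them.
    ops = [
        Op("send_barrier", inputs=["ctrl_w", "ctrl_table"]),
        Op("send_w_grad", outputs=["ctrl_w"]),
        Op("send_table_grad", outputs=["ctrl_table"]),  # wired up by this fix
    ]
    run(ops)

The dummy variables carry no data; like the variables created by framework.generate_control_dev_var_name() in the diff, they exist only so that dependency analysis orders the RPC ops.
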
@@ -247,7 +247,7 @@ class DistributeTranspiler(object):
             np.random.seed(self.origin_program.random_seed)
             np.random.shuffle(grad_var_mapping_items)
 
-        grad_name_to_send_dummy_out = dict()
+        self.grad_name_to_send_dummy_out = dict()
         for grad_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)
@@ -271,7 +271,7 @@ class DistributeTranspiler(object):
             dummy_output = program.global_block().create_var(
                 name=framework.generate_control_dev_var_name())
-            grad_name_to_send_dummy_out[grad_varname] = dummy_output
+            self.grad_name_to_send_dummy_out[grad_varname] = dummy_output
 
             # get send op_role_var, if not splited, the grad should have .trainer suffix
             # if splited, grad should be the original grad var name (split_by_ref and send
@@ -297,7 +297,12 @@ class DistributeTranspiler(object):
         if self.sync_mode:
             send_barrier_out = program.global_block().create_var(
                 name=framework.generate_control_dev_var_name())
-            input_deps = grad_name_to_send_dummy_out.values()
+            if self.has_distributed_lookup_table:
+                self.grad_name_to_send_dummy_out[
+                    self.table_name] = program.global_block().create_var(
+                        name=framework.generate_control_dev_var_name())
+            input_deps = self.grad_name_to_send_dummy_out.values()
             program.global_block().append_op(
                 type="send_barrier",
                 inputs={"X": list(input_deps)},
@@ -329,7 +334,7 @@ class DistributeTranspiler(object):
                 recv_dep_in = send_barrier_out
             else:
                 # connect deps to send op in async mode
-                recv_dep_in = grad_name_to_send_dummy_out[
+                recv_dep_in = self.grad_name_to_send_dummy_out[
                     self.param_name_to_grad_name[param_varname]]
             all_recv_outputs.extend(splited_var)
 
             # get recv op_role_var, if not splited, the grad should have .trainer suffix
@@ -1046,9 +1051,13 @@ class DistributeTranspiler(object):
             index=op_index + 2,
             type="send",
             inputs={'X': self.trainer_side_table_grad_list},
-            outputs={'Out': []},
+            outputs={
+                'Out':
+                [self.grad_name_to_send_dummy_out[self.table_name]]
+                if self.sync_mode else []
+            },
             attrs={
-                "sync_mode": True,
+                "sync_mode": False,
                 "epmap": pserver_endpoints,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                 OP_ROLE_VAR_ATTR_NAME: [
...
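
Note on the hunk above: the table-gradient send becomes non-blocking ("sync_mode": False, presumably the "revert sync_mode config" step in the commit message), and in sync mode it now emits the dummy variable registered under self.table_name so that send_barrier waits for it; in async mode the op keeps an empty output list.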