diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 53c9cbe23dd82af866658fe46d1d631b0a3b26f3..e070ea8d428831c490348fedbf1f8865fdb9910c 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -247,7 +247,7 @@ class DistributeTranspiler(object): np.random.seed(self.origin_program.random_seed) np.random.shuffle(grad_var_mapping_items) - grad_name_to_send_dummy_out = dict() + self.grad_name_to_send_dummy_out = dict() for grad_varname, splited_vars in grad_var_mapping_items: eplist = ps_dispatcher.dispatch(splited_vars) @@ -271,7 +271,7 @@ class DistributeTranspiler(object): dummy_output = program.global_block().create_var( name=framework.generate_control_dev_var_name()) - grad_name_to_send_dummy_out[grad_varname] = dummy_output + self.grad_name_to_send_dummy_out[grad_varname] = dummy_output # get send op_role_var, if not splited, the grad should have .trainer suffix # if splited, grad should be the original grad var name (split_by_ref and send @@ -297,7 +297,12 @@ class DistributeTranspiler(object): if self.sync_mode: send_barrier_out = program.global_block().create_var( name=framework.generate_control_dev_var_name()) - input_deps = grad_name_to_send_dummy_out.values() + if self.has_distributed_lookup_table: + self.grad_name_to_send_dummy_out[ + self.table_name] = program.global_block().create_var( + name=framework.generate_control_dev_var_name()) + input_deps = self.grad_name_to_send_dummy_out.values() + program.global_block().append_op( type="send_barrier", inputs={"X": list(input_deps)}, @@ -329,7 +334,7 @@ class DistributeTranspiler(object): recv_dep_in = send_barrier_out else: # connect deps to send op in async mode - recv_dep_in = grad_name_to_send_dummy_out[ + recv_dep_in = self.grad_name_to_send_dummy_out[ self.param_name_to_grad_name[param_varname]] all_recv_outputs.extend(splited_var) # get recv op_role_var, if not splited, the grad should have .trainer suffix @@ -1046,9 +1051,13 @@ class DistributeTranspiler(object): index=op_index + 2, type="send", inputs={'X': self.trainer_side_table_grad_list}, - outputs={'Out': []}, + outputs={ + 'Out': + [self.grad_name_to_send_dummy_out[self.table_name]] + if self.sync_mode else [] + }, attrs={ - "sync_mode": True, + "sync_mode": False, "epmap": pserver_endpoints, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, OP_ROLE_VAR_ATTR_NAME: [