diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index cfad35ffe1bfad4e7a66b0ccc9a0cf8b96d0c2ea..54aaa138a18c9cd943c4ab145ea6477e27946e8b 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -388,49 +388,84 @@ class DistributeTranspiler(object):
                 sparse_update_ops.append(op)
         return sparse_update_ops
 
-    def _update_remote_sparse_update_op(self, program, param_varname,
-                                        height_sections, endpoints,
-                                        table_names):
-
-        ops = []
-        op_type = ""
-
-        for op in self.sparse_update_ops:
-            if param_varname in op.input_arg_names and op_type == "":
-                op_type = op.type
-                ops.append(op)
-
-            elif param_varname in op.input_arg_names and op_type == op.type:
-                ops.append(op)
-
-        if op_type == "lookup_table":
-            all_ops = program.global_block().ops
-            op_idxs = [all_ops.index(op) for op in ops]
-            inputs = [
-                program.global_block().vars[op.input("Ids")[0]] for op in ops
-            ]
-            w = program.global_block().vars[ops[0].input("W")[0]]
-            padding_idx = ops[0].attr("padding_idx")
-            outputs = [
-                program.global_block().vars[op.output("Out")[0]] for op in ops
-            ]
+    def _update_remote_sparse_update_op(self, program,
+                                        need_sparse_update_params):
+
+        for param_varname, attrs in need_sparse_update_params.items():
+            height_sections = self.sparse_param_to_height_sections[
+                param_varname]
+            endpoints = attrs[0]
+            table_names = attrs[1]
+
+            ops = []
+            op_type = ""
+            used_ops = []
+
+            for idx, op in enumerate(self.sparse_update_ops):
+                if param_varname in op.input_arg_names and op_type == "":
+                    op_type = op.type
+                    ops.append(op)
+                    used_ops.append(idx)
+
+                elif param_varname in op.input_arg_names and op_type == op.type:
+                    ops.append(op)
+                    used_ops.append(idx)
+
+            if op_type == "lookup_table":
+                all_ops = program.global_block().ops
+                op_idxs = [all_ops.index(op) for op in ops]
+                inputs = [
+                    program.global_block().vars[op.input("Ids")[0]]
+                    for op in ops
+                ]
+                w = program.global_block().vars[ops[0].input("W")[0]]
+                padding_idx = ops[0].attr("padding_idx")
+                outputs = [
+                    program.global_block().vars[op.output("Out")[0]]
+                    for op in ops
+                ]
 
-            for idx in op_idxs[::-1]:
-                program.global_block()._remove_op(idx)
+                for idx in op_idxs[::-1]:
+                    program.global_block()._remove_op(idx)
+
+                inputs_idxs = [-1] * len(inputs)
+                outputs_idxs = [-1] * len(outputs)
+
+                for idx, op in enumerate(program.global_block().ops):
+                    for i in range(0, len(op.output_names)):
+                        outs = op.output(op.output_names[i])
+                        for in_id, in_var in enumerate(inputs):
+                            if in_var.name in outs:
+                                inputs_idxs[in_id] = idx
+                    for i in range(0, len(op.input_names)):
+                        ins = op.input(op.input_names[i])
+                        for out_id, out_var in enumerate(outputs):
+                            if out_var.name in ins:
+                                outputs_idxs[out_id] = idx
+
+                if min(outputs_idxs) - max(inputs_idxs) >= 1:
+                    distributed_idx = max(inputs_idxs) + 1
+
+                    program.global_block()._insert_op(
+                        index=distributed_idx,
+                        type="distributed_lookup_table",
+                        inputs={"Ids": inputs,
+                                'W': w},
+                        outputs={"Outputs": outputs},
+                        attrs={
+                            "table_names": table_names,
+                            "height_sections": height_sections,
+                            "endpoints": endpoints,
+                            "padding_idx": padding_idx,
+                            "trainer_id": self.trainer_id
+                        })
+                else:
+                    raise ValueError(
+                        "something wrong with distribute_transpiler, submit a issue is recommended"
+                    )
 
-            program.global_block()._insert_op(
-                index=op_idxs[0],
-                type="distributed_lookup_table",
-                inputs={"Ids": inputs,
-                        'W': w},
-                outputs={"Outputs": outputs},
-                attrs={
-                    "table_names": table_names,
-                    "height_sections": height_sections,
-                    "endpoints": endpoints,
-                    "padding_idx": padding_idx,
-                    "trainer_id": self.trainer_id
-                })
+            for idx in used_ops[::-1]:
+                self.sparse_update_ops.pop(idx)
 
     def _is_input_of_remote_sparse_update_op(self, param_name):
         for op in self.sparse_update_ops:
@@ -681,6 +716,8 @@ class DistributeTranspiler(object):
                     recv_vars[i].name)
                 distributed_var.endpoint = ep
 
+        need_sparse_update_params = {}
+
         # step4: Concat the parameters splits together after recv.
         all_recv_outputs = []
         for param_varname, splited_var in six.iteritems(self.param_var_mapping):
@@ -712,10 +749,7 @@ class DistributeTranspiler(object):
                         table_name)
                     distributed_var.vtype = "RemotePrefetch"
 
-                height_sections = self.sparse_param_to_height_sections[
-                    param_varname]
-                self._update_remote_sparse_update_op(
-                    program, param_varname, height_sections, eps, table_names)
+                need_sparse_update_params[param_varname] = (eps, table_names)
             else:
                 recv_varnames = []
                 if self.config.runtime_split_send_recv:
@@ -764,6 +798,9 @@ class DistributeTranspiler(object):
                     RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
                 })
 
+        self._update_remote_sparse_update_op(program,
+                                             need_sparse_update_params)
+
         self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
 
         if self.has_distributed_lookup_table: