diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index fcd767f53b3c457ba0e1364d1065840b18117a5d..5cb7eb98dee8a924e3db060639e7f0d1101fc8e9 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -901,28 +901,20 @@ class Completer: self._array_nodes[array_var_name].append(node.outputs[0]) if node.is_var() and node.var() is not None: if node.node.graph_id() != 0: - for before_node in reversed(all_nodes[:idx]): - if ( - before_node.is_var() - and before_node.var() is not None - and before_node.node.graph_id() - == node.node.graph_id() - 1 - and before_node.var().name() == node.var().name() - ): - self._node_pairs_between_graphs.append( - (before_node, node) - ) - for after_node in all_nodes[idx + 1 :]: - if ( - after_node.is_var() - and after_node.var() is not None - and after_node.node.graph_id() - == node.node.graph_id() - 1 - and after_node.var().name() == node.var().name() - ): + parent_nodes = ( + self._dist_context._tensor_nodes_with_same_name[ + node.node.graph_id() - 1 + ].get(node.var().name(), None) + ) + if parent_nodes is not None: + sorted_parent_nodes = sorted( + parent_nodes, key=lambda x: x[0] + ) + for _, parent_node in sorted_parent_nodes: self._node_pairs_between_graphs.append( - (after_node, node) + (parent_node, node) ) + self._has_prepared = True def complete_forward_annotation(self, serial_main_program=None): diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index b8a7bec6dad51d9f5804f7945daaa7274e21261b..d7f23e2c565ef567fdfc4cc0e0348d088d6c108d 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -634,21 +634,17 @@ class DistributedContext: ) def _order_nodes_by_program_order(self): - def _contains(nodes, target_node): - for node in nodes: - if _node_id(node) == _node_id(target_node): - return True - return False - serial_ordered_tensor_nodes = [] serial_ordered_op_nodes = [] all_nodes = [] + visited = {} for idx, graph in enumerate(self._serial_graph.all_sub_graphs()): for node in graph.all_nodes(): all_nodes.append(node) for node in all_nodes: if node.is_var() and node.var() is not None: serial_ordered_tensor_nodes.append(node) + visited[_node_id(node)] = False if node.is_op() and node.op() is not None: serial_ordered_op_nodes.append(node) serial_ordered_tensor_nodes.sort( @@ -670,10 +666,12 @@ class DistributedContext: if ( tensor_node.is_var() and tensor_node.var() is not None - and not _contains(new_serial_ordered_nodes, tensor_node) + and not visited[_node_id(tensor_node)] ): tensor_nodes.append(tensor_node) new_serial_ordered_tensor_nodes.append(tensor_node) + visited[_node_id(tensor_node)] = True + tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) new_serial_ordered_nodes.extend(tensor_nodes) new_serial_ordered_nodes.append(op_node) @@ -683,10 +681,11 @@ class DistributedContext: if ( tensor_node.is_var() and tensor_node.var() is not None - and not _contains(new_serial_ordered_nodes, tensor_node) + and not visited[_node_id(tensor_node)] ): tensor_nodes.append(tensor_node) new_serial_ordered_tensor_nodes.append(tensor_node) + visited[_node_id(tensor_node)] = True tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) new_serial_ordered_nodes.extend(tensor_nodes) new_serial_ordered_tensor_nodes.sort( @@ -701,9 +700,28 @@ class DistributedContext: assert len(self._serial_ordered_nodes) == len( self._serial_ordered_tensor_nodes ) + len(self._serial_ordered_op_nodes) + # graph_id -> tensor->name -> node_lists + self._tensor_nodes_with_same_name = defaultdict(dict) + for idx, node in enumerate(self._serial_ordered_nodes): + if node.is_var() and node.var() is not None: + graph_id = node.node.graph_id() + tensor_name = node.var().name() + if ( + self._tensor_nodes_with_same_name[graph_id].get( + tensor_name, None + ) + is None + ): + self._tensor_nodes_with_same_name[graph_id][ + tensor_name + ] = [] + self._tensor_nodes_with_same_name[graph_id][tensor_name].append( + (idx, node) + ) + self._serial_orphan_tensor_nodes = [] for tensor_node in serial_ordered_tensor_nodes: - if not _contains(self._serial_ordered_tensor_nodes, tensor_node): + if not visited[_node_id(tensor_node)]: self._serial_orphan_tensor_nodes.append(tensor_node) if len(self._serial_ordered_nodes) != num_nodes_before: print( @@ -713,23 +731,30 @@ class DistributedContext: def _init_dist_attr_for_graph(self): # Convert program to graph and initialize the distributed attributes self._order_nodes_by_program_order() + self._tensor_original_id_to_id = {} + self._op_original_id_to_id = {} + for tensor_id, tensor in self._dist_tensors_for_program.items(): + original_id = tensor.serial_tensor.desc.original_id() + self._tensor_original_id_to_id[original_id] = tensor_id + for op_id, op in self._dist_ops_for_program.items(): + original_id = op.serial_op.desc.original_id() + self._op_original_id_to_id[original_id] = op_id for node in self.serial_ordered_nodes: if node.is_var() and node.var() is not None: dist_tensor = None tensor_id = node.node.original_desc_id() - for ( - cur_tensor_id, - cur_dist_tensor, - ) in self._dist_tensors_for_program.items(): - if ( - tensor_id == cur_tensor_id - or tensor_id - == cur_dist_tensor.serial_tensor.desc.original_id() - ): - dist_tensor = cur_dist_tensor - self._node_id_to_tensor_id[ - _node_id(node) - ] = cur_tensor_id + cur_dist_tensor = self._dist_tensors_for_program.get( + tensor_id, None + ) + if cur_dist_tensor is not None: + cur_tensor_id = tensor_id + else: + cur_tensor_id = self._tensor_original_id_to_id[tensor_id] + cur_dist_tensor = self._dist_tensors_for_program.get( + cur_tensor_id, None + ) + dist_tensor = cur_dist_tensor + self._node_id_to_tensor_id[_node_id(node)] = cur_tensor_id assert ( dist_tensor is not None ), "Tensor must have a distributed tensor after the initialization for program." @@ -743,16 +768,16 @@ class DistributedContext: if node.is_op() and node.op() is not None: dist_op = None op_id = node.node.original_desc_id() - for ( - cur_op_id, - cur_dist_op, - ) in self._dist_ops_for_program.items(): - if ( - op_id == cur_op_id - or op_id == cur_dist_op.serial_op.desc.original_id() - ): - dist_op = cur_dist_op - self._node_id_to_op_id[_node_id(node)] = cur_op_id + cur_dist_op = self._dist_ops_for_program.get(op_id, None) + if cur_dist_op is not None: + cur_op_id = op_id + else: + cur_op_id = self._op_original_id_to_id[op_id] + cur_dist_op = self._dist_ops_for_program.get( + cur_op_id, None + ) + dist_op = cur_dist_op + self._node_id_to_op_id[_node_id(node)] = cur_op_id assert ( dist_op is not None ), "Operator must have a distributed operator after the initialization for program." @@ -775,16 +800,17 @@ class DistributedContext: if node.is_var() and node.var() is not None: dist_tensor = None tensor_id = node.node.original_desc_id() - for ( - cur_tensor_id, - cur_dist_tensor, - ) in self._dist_tensors_for_program.items(): - if ( - tensor_id == cur_tensor_id - or tensor_id - == cur_dist_tensor.serial_tensor.desc.original_id() - ): - dist_tensor = cur_dist_tensor + cur_dist_tensor = self._dist_tensors_for_program.get( + tensor_id, None + ) + if cur_dist_tensor is not None: + cur_tensor_id = tensor_id + else: + cur_tensor_id = self._tensor_original_id_to_id[tensor_id] + cur_dist_tensor = self._dist_tensors_for_program.get( + cur_tensor_id, None + ) + dist_tensor = cur_dist_tensor assert ( dist_tensor is not None ), "Tensor must have a distributed tensor after the initialization for program." @@ -798,15 +824,15 @@ class DistributedContext: if node.is_op() and node.op() is not None: dist_op = None op_id = node.node.original_desc_id() - for ( - cur_op_id, - cur_dist_op, - ) in self._dist_ops_for_program.items(): - if ( - op_id == cur_op_id - or op_id == cur_dist_op.serial_op.desc.original_id() - ): - dist_op = cur_dist_op + cur_dist_op = self._dist_ops_for_program.get(op_id, None) + if cur_dist_op is not None: + cur_op_id = op_id + else: + cur_op_id = self._op_original_id_to_id[op_id] + cur_dist_op = self._dist_ops_for_program.get( + cur_op_id, None + ) + dist_op = cur_dist_op assert ( dist_op is not None ), "Operator must have a distributed operator after the initialization for program." @@ -1014,6 +1040,7 @@ class DistributedContext: "_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", "_pass_context", + "_tensor_nodes_with_same_name", ]: setattr(result, k, v) else: