Unverified · Commit 61996aec authored by Yulong Ao, committed by GitHub

[Auto Parallel] Speedup the completion process (#51035)

* [Auto Parallel] Speedup the completion process

* [Auto Parallel] Skip the property of dist_context when deepcopying

* [Auto Parallel] Remove the unnecessary print
Parent 419597de
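Editor's note: the hunks below replace repeated linear scans over all graph nodes with precomputed dictionaries (a `visited` map keyed by node id, a per-graph tensor-name index, and original-id lookup tables), so each query during completion becomes a dict lookup instead of a walk over the whole node list. The standalone sketch below illustrates that general pattern only; the toy `nodes` list and `contains_linear` helper are made up for this note and are not Paddle code.

```python
# Illustrative sketch: replacing a repeated linear membership scan
# with an index built once up front.
from collections import defaultdict

nodes = [("g0", "x"), ("g1", "x"), ("g1", "y"), ("g0", "x")]  # (graph, name)

# Before: every query walks the whole list, O(n) per query.
def contains_linear(seen, target):
    for item in seen:
        if item == target:
            return True
    return False

# After: build the index once, then each query is a dict lookup.
index = defaultdict(list)          # (graph, name) -> [position, ...]
for pos, node in enumerate(nodes):
    index[node].append(pos)

assert contains_linear(nodes, ("g1", "y")) is True
assert index[("g1", "y")] == [2]   # stored positions preserve program order
```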
```diff
@@ -901,28 +901,20 @@ class Completer:
                     self._array_nodes[array_var_name].append(node.outputs[0])
             if node.is_var() and node.var() is not None:
                 if node.node.graph_id() != 0:
-                    for before_node in reversed(all_nodes[:idx]):
-                        if (
-                            before_node.is_var()
-                            and before_node.var() is not None
-                            and before_node.node.graph_id()
-                            == node.node.graph_id() - 1
-                            and before_node.var().name() == node.var().name()
-                        ):
-                            self._node_pairs_between_graphs.append(
-                                (before_node, node)
-                            )
-                    for after_node in all_nodes[idx + 1 :]:
-                        if (
-                            after_node.is_var()
-                            and after_node.var() is not None
-                            and after_node.node.graph_id()
-                            == node.node.graph_id() - 1
-                            and after_node.var().name() == node.var().name()
-                        ):
-                            self._node_pairs_between_graphs.append(
-                                (after_node, node)
-                            )
+                    parent_nodes = (
+                        self._dist_context._tensor_nodes_with_same_name[
+                            node.node.graph_id() - 1
+                        ].get(node.var().name(), None)
+                    )
+                    if parent_nodes is not None:
+                        sorted_parent_nodes = sorted(
+                            parent_nodes, key=lambda x: x[0]
+                        )
+                        for _, parent_node in sorted_parent_nodes:
+                            self._node_pairs_between_graphs.append(
+                                (parent_node, node)
+                            )
         self._has_prepared = True

     def complete_forward_annotation(self, serial_main_program=None):
```
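In the rewritten `Completer._prepare` above, a node in a subgraph no longer scans every node before and after it; it asks the parent graph's name index for candidate nodes directly and sorts them by their stored position so pairing still follows program order. The sketch below mimics that lookup with stand-in data (plain dicts and strings, not the real `dist_context`):

```python
# Hedged sketch of the lookup shape used above:
# graph_id -> tensor_name -> [(position, node), ...]; "node" is a string here.
tensor_nodes_with_same_name = {
    0: {"x": [(3, "x@graph0")], "y": [(5, "y@graph0")]},
    1: {"x": [(9, "x@graph1")]},
}

def pairs_with_parent_graph(graph_id, name, child_node):
    parents = tensor_nodes_with_same_name.get(graph_id - 1, {}).get(name)
    if parents is None:
        return []
    # Sort by stored position so pairs keep program order.
    return [(p, child_node) for _, p in sorted(parents, key=lambda x: x[0])]

print(pairs_with_parent_graph(1, "x", "x@graph1"))
# [('x@graph0', 'x@graph1')]
```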
```diff
@@ -634,21 +634,17 @@ class DistributedContext:
         )

     def _order_nodes_by_program_order(self):
-        def _contains(nodes, target_node):
-            for node in nodes:
-                if _node_id(node) == _node_id(target_node):
-                    return True
-            return False
-
         serial_ordered_tensor_nodes = []
         serial_ordered_op_nodes = []
         all_nodes = []
+        visited = {}
         for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
             for node in graph.all_nodes():
                 all_nodes.append(node)
         for node in all_nodes:
             if node.is_var() and node.var() is not None:
                 serial_ordered_tensor_nodes.append(node)
+                visited[_node_id(node)] = False
             if node.is_op() and node.op() is not None:
                 serial_ordered_op_nodes.append(node)
         serial_ordered_tensor_nodes.sort(
```
```diff
@@ -670,10 +666,12 @@ class DistributedContext:
                 if (
                     tensor_node.is_var()
                     and tensor_node.var() is not None
-                    and not _contains(new_serial_ordered_nodes, tensor_node)
+                    and not visited[_node_id(tensor_node)]
                 ):
                     tensor_nodes.append(tensor_node)
                     new_serial_ordered_tensor_nodes.append(tensor_node)
+                    visited[_node_id(tensor_node)] = True
             tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
             new_serial_ordered_nodes.extend(tensor_nodes)
             new_serial_ordered_nodes.append(op_node)
```
```diff
@@ -683,10 +681,11 @@ class DistributedContext:
                 if (
                     tensor_node.is_var()
                     and tensor_node.var() is not None
-                    and not _contains(new_serial_ordered_nodes, tensor_node)
+                    and not visited[_node_id(tensor_node)]
                 ):
                     tensor_nodes.append(tensor_node)
                     new_serial_ordered_tensor_nodes.append(tensor_node)
+                    visited[_node_id(tensor_node)] = True
             tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
             new_serial_ordered_nodes.extend(tensor_nodes)
         new_serial_ordered_tensor_nodes.sort(
```
```diff
@@ -701,9 +700,28 @@ class DistributedContext:
         assert len(self._serial_ordered_nodes) == len(
             self._serial_ordered_tensor_nodes
         ) + len(self._serial_ordered_op_nodes)
+        # graph_id -> tensor_name -> list of (idx, node)
+        self._tensor_nodes_with_same_name = defaultdict(dict)
+        for idx, node in enumerate(self._serial_ordered_nodes):
+            if node.is_var() and node.var() is not None:
+                graph_id = node.node.graph_id()
+                tensor_name = node.var().name()
+                if (
+                    self._tensor_nodes_with_same_name[graph_id].get(
+                        tensor_name, None
+                    )
+                    is None
+                ):
+                    self._tensor_nodes_with_same_name[graph_id][
+                        tensor_name
+                    ] = []
+                self._tensor_nodes_with_same_name[graph_id][tensor_name].append(
+                    (idx, node)
+                )
         self._serial_orphan_tensor_nodes = []
         for tensor_node in serial_ordered_tensor_nodes:
-            if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
+            if not visited[_node_id(tensor_node)]:
                 self._serial_orphan_tensor_nodes.append(tensor_node)
         if len(self._serial_ordered_nodes) != num_nodes_before:
             print(
```
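The hunk above builds `_tensor_nodes_with_same_name`, a two-level mapping `graph_id -> tensor_name -> [(idx, node), ...]` filled in while walking the already-ordered node list (it relies on `defaultdict` from `collections`, which the file is assumed to already import). The compact sketch below reproduces that bookkeeping on toy data; `setdefault` stands in for the "create the list on first sight" branch:

```python
# Sketch of the index-building step on toy data (not the Paddle classes).
from collections import defaultdict

# (graph_id, tensor_name) pairs for a hypothetical ordered node list.
ordered = [(0, "a"), (0, "b"), (1, "a"), (0, "a")]

tensor_nodes_with_same_name = defaultdict(dict)
for idx, (graph_id, name) in enumerate(ordered):
    tensor_nodes_with_same_name[graph_id].setdefault(name, []).append((idx, name))

assert tensor_nodes_with_same_name[0]["a"] == [(0, "a"), (3, "a")]
assert tensor_nodes_with_same_name[1]["a"] == [(2, "a")]
```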
```diff
@@ -713,23 +731,30 @@ class DistributedContext:
     def _init_dist_attr_for_graph(self):
         # Convert program to graph and initialize the distributed attributes
         self._order_nodes_by_program_order()
+        self._tensor_original_id_to_id = {}
+        self._op_original_id_to_id = {}
+        for tensor_id, tensor in self._dist_tensors_for_program.items():
+            original_id = tensor.serial_tensor.desc.original_id()
+            self._tensor_original_id_to_id[original_id] = tensor_id
+        for op_id, op in self._dist_ops_for_program.items():
+            original_id = op.serial_op.desc.original_id()
+            self._op_original_id_to_id[original_id] = op_id
         for node in self.serial_ordered_nodes:
             if node.is_var() and node.var() is not None:
                 dist_tensor = None
                 tensor_id = node.node.original_desc_id()
-                for (
-                    cur_tensor_id,
-                    cur_dist_tensor,
-                ) in self._dist_tensors_for_program.items():
-                    if (
-                        tensor_id == cur_tensor_id
-                        or tensor_id
-                        == cur_dist_tensor.serial_tensor.desc.original_id()
-                    ):
-                        dist_tensor = cur_dist_tensor
-                        self._node_id_to_tensor_id[
-                            _node_id(node)
-                        ] = cur_tensor_id
+                cur_dist_tensor = self._dist_tensors_for_program.get(
+                    tensor_id, None
+                )
+                if cur_dist_tensor is not None:
+                    cur_tensor_id = tensor_id
+                else:
+                    cur_tensor_id = self._tensor_original_id_to_id[tensor_id]
+                    cur_dist_tensor = self._dist_tensors_for_program.get(
+                        cur_tensor_id, None
+                    )
+                dist_tensor = cur_dist_tensor
+                self._node_id_to_tensor_id[_node_id(node)] = cur_tensor_id
                 assert (
                     dist_tensor is not None
                 ), "Tensor must have a distributed tensor after the initialization for program."
```
```diff
@@ -743,14 +768,14 @@ class DistributedContext:
             if node.is_op() and node.op() is not None:
                 dist_op = None
                 op_id = node.node.original_desc_id()
-                for (
-                    cur_op_id,
-                    cur_dist_op,
-                ) in self._dist_ops_for_program.items():
-                    if (
-                        op_id == cur_op_id
-                        or op_id == cur_dist_op.serial_op.desc.original_id()
-                    ):
-                        dist_op = cur_dist_op
-                        self._node_id_to_op_id[_node_id(node)] = cur_op_id
+                cur_dist_op = self._dist_ops_for_program.get(op_id, None)
+                if cur_dist_op is not None:
+                    cur_op_id = op_id
+                else:
+                    cur_op_id = self._op_original_id_to_id[op_id]
+                    cur_dist_op = self._dist_ops_for_program.get(
+                        cur_op_id, None
+                    )
+                dist_op = cur_dist_op
+                self._node_id_to_op_id[_node_id(node)] = cur_op_id
                 assert (
```
```diff
@@ -775,15 +800,16 @@ class DistributedContext:
             if node.is_var() and node.var() is not None:
                 dist_tensor = None
                 tensor_id = node.node.original_desc_id()
-                for (
-                    cur_tensor_id,
-                    cur_dist_tensor,
-                ) in self._dist_tensors_for_program.items():
-                    if (
-                        tensor_id == cur_tensor_id
-                        or tensor_id
-                        == cur_dist_tensor.serial_tensor.desc.original_id()
-                    ):
-                        dist_tensor = cur_dist_tensor
+                cur_dist_tensor = self._dist_tensors_for_program.get(
+                    tensor_id, None
+                )
+                if cur_dist_tensor is not None:
+                    cur_tensor_id = tensor_id
+                else:
+                    cur_tensor_id = self._tensor_original_id_to_id[tensor_id]
+                    cur_dist_tensor = self._dist_tensors_for_program.get(
+                        cur_tensor_id, None
+                    )
+                dist_tensor = cur_dist_tensor
                 assert (
                     dist_tensor is not None
```
```diff
@@ -798,14 +824,14 @@ class DistributedContext:
             if node.is_op() and node.op() is not None:
                 dist_op = None
                 op_id = node.node.original_desc_id()
-                for (
-                    cur_op_id,
-                    cur_dist_op,
-                ) in self._dist_ops_for_program.items():
-                    if (
-                        op_id == cur_op_id
-                        or op_id == cur_dist_op.serial_op.desc.original_id()
-                    ):
-                        dist_op = cur_dist_op
+                cur_dist_op = self._dist_ops_for_program.get(op_id, None)
+                if cur_dist_op is not None:
+                    cur_op_id = op_id
+                else:
+                    cur_op_id = self._op_original_id_to_id[op_id]
+                    cur_dist_op = self._dist_ops_for_program.get(
+                        cur_op_id, None
+                    )
+                dist_op = cur_dist_op
                 assert (
                     dist_op is not None
```
```diff
@@ -1014,6 +1040,7 @@ class DistributedContext:
                 "_backup_serial_main_program_stack",
                 "_backup_serial_startup_program_stack",
                 "_pass_context",
+                "_tensor_nodes_with_same_name",
             ]:
                 setattr(result, k, v)
             else:
```
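This last hunk adds `_tensor_nodes_with_same_name` to the attributes that `__deepcopy__` copies by reference rather than deep-copying (the commit message's "Skip the property of dist_context when deepcopying"), since the cached graph-node index does not need to be duplicated per copy. A minimal hedged sketch of that pattern; the class and attribute names below are illustrative, not Paddle's:

```python
# Sketch: share selected attributes instead of deep-copying them.
import copy

class Context:
    _SHALLOW_KEYS = {"_cached_index"}  # copied by reference, like the skip list above

    def __init__(self):
        self._cached_index = {"big": "cache"}
        self._state = {"step": 1}

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in self._SHALLOW_KEYS:
                setattr(result, k, v)                  # shared, not duplicated
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

ctx = Context()
clone = copy.deepcopy(ctx)
assert clone._cached_index is ctx._cached_index        # shallow-shared
assert clone._state is not ctx._state                  # deep-copied
```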