Unverified commit 61996aec, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Speedup the completion process (#51035)

* [Auto Parallel] Speedup the completion process

* [Auto Parallel] Skip the property of dist_context when deepcopying

* [Auto Parallel] Remove the unnecessary print
Parent 419597de
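
The headline change replaces two scans over all graph nodes (searching for same-named tensor nodes in the parent sub-graph) with a table built once per pass, _tensor_nodes_with_same_name, keyed by graph id and tensor name. Below is a minimal, self-contained sketch of that lookup-table pattern; TensorNode, build_tensor_index, and pairs_with_parent_graph are illustrative names, not Paddle APIs.

from collections import defaultdict
from dataclasses import dataclass

# Hypothetical stand-in for an IR tensor node; the real nodes expose
# node.node.graph_id() and node.var().name() instead of plain attributes.
@dataclass(frozen=True)
class TensorNode:
    graph_id: int
    name: str

def build_tensor_index(ordered_nodes):
    # graph_id -> tensor name -> list of (position, node), built in one pass.
    index = defaultdict(dict)
    for pos, node in enumerate(ordered_nodes):
        index[node.graph_id].setdefault(node.name, []).append((pos, node))
    return index

def pairs_with_parent_graph(node, index):
    # Same-named tensors in the parent sub-graph come from a dict lookup,
    # not from scanning every node before and after the current position.
    parents = index[node.graph_id - 1].get(node.name, [])
    return [(parent, node) for _, parent in sorted(parents, key=lambda x: x[0])]

nodes = [TensorNode(0, "x"), TensorNode(0, "y"), TensorNode(1, "x")]
index = build_tensor_index(nodes)
print(pairs_with_parent_graph(TensorNode(1, "x"), index))
# [(TensorNode(graph_id=0, name='x'), TensorNode(graph_id=1, name='x'))]
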
@@ -901,28 +901,20 @@ class Completer:
self._array_nodes[array_var_name].append(node.outputs[0])
if node.is_var() and node.var() is not None:
if node.node.graph_id() != 0:
for before_node in reversed(all_nodes[:idx]):
if (
before_node.is_var()
and before_node.var() is not None
and before_node.node.graph_id()
== node.node.graph_id() - 1
and before_node.var().name() == node.var().name()
):
self._node_pairs_between_graphs.append(
(before_node, node)
)
for after_node in all_nodes[idx + 1 :]:
if (
after_node.is_var()
and after_node.var() is not None
and after_node.node.graph_id()
== node.node.graph_id() - 1
and after_node.var().name() == node.var().name()
):
parent_nodes = (
self._dist_context._tensor_nodes_with_same_name[
node.node.graph_id() - 1
].get(node.var().name(), None)
)
if parent_nodes is not None:
sorted_parent_nodes = sorted(
parent_nodes, key=lambda x: x[0]
)
for _, parent_node in sorted_parent_nodes:
self._node_pairs_between_graphs.append(
(after_node, node)
(parent_node, node)
)
self._has_prepared = True
def complete_forward_annotation(self, serial_main_program=None):
......
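
The DistributedContext hunks that follow apply the same idea to de-duplication: the removed _contains helper re-scanned a growing list for every membership test, while the new code keys a visited dict by node id and answers membership in constant time. A rough sketch of the two patterns, assuming a node_id callable in place of Paddle's _node_id helper:

# Old pattern: every membership test re-scans the list (O(n) per call).
def contains(nodes, target, node_id):
    return any(node_id(n) == node_id(target) for n in nodes)

# New pattern: record each node id once, then answer membership in O(1).
def dedup_in_order(nodes, node_id):
    visited = set()  # the diff uses a dict of booleans keyed by _node_id
    ordered = []
    for n in nodes:
        if node_id(n) not in visited:
            visited.add(node_id(n))
            ordered.append(n)
    return ordered

print(dedup_in_order(["a", "b", "a", "c"], node_id=lambda n: n))  # ['a', 'b', 'c']
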
@@ -634,21 +634,17 @@ class DistributedContext:
)
def _order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if _node_id(node) == _node_id(target_node):
return True
return False
serial_ordered_tensor_nodes = []
serial_ordered_op_nodes = []
all_nodes = []
visited = {}
for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for node in graph.all_nodes():
all_nodes.append(node)
for node in all_nodes:
if node.is_var() and node.var() is not None:
serial_ordered_tensor_nodes.append(node)
visited[_node_id(node)] = False
if node.is_op() and node.op() is not None:
serial_ordered_op_nodes.append(node)
serial_ordered_tensor_nodes.sort(
@@ -670,10 +666,12 @@ class DistributedContext:
if (
tensor_node.is_var()
and tensor_node.var() is not None
and not _contains(new_serial_ordered_nodes, tensor_node)
and not visited[_node_id(tensor_node)]
):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
visited[_node_id(tensor_node)] = True
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
new_serial_ordered_nodes.extend(tensor_nodes)
new_serial_ordered_nodes.append(op_node)
@@ -683,10 +681,11 @@ class DistributedContext:
if (
tensor_node.is_var()
and tensor_node.var() is not None
and not _contains(new_serial_ordered_nodes, tensor_node)
and not visited[_node_id(tensor_node)]
):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
visited[_node_id(tensor_node)] = True
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
new_serial_ordered_nodes.extend(tensor_nodes)
new_serial_ordered_tensor_nodes.sort(
@@ -701,9 +700,28 @@ class DistributedContext:
assert len(self._serial_ordered_nodes) == len(
self._serial_ordered_tensor_nodes
) + len(self._serial_ordered_op_nodes)
# graph_id -> tensor_name -> node_lists
self._tensor_nodes_with_same_name = defaultdict(dict)
for idx, node in enumerate(self._serial_ordered_nodes):
if node.is_var() and node.var() is not None:
graph_id = node.node.graph_id()
tensor_name = node.var().name()
if (
self._tensor_nodes_with_same_name[graph_id].get(
tensor_name, None
)
is None
):
self._tensor_nodes_with_same_name[graph_id][
tensor_name
] = []
self._tensor_nodes_with_same_name[graph_id][tensor_name].append(
(idx, node)
)
self._serial_orphan_tensor_nodes = []
for tensor_node in serial_ordered_tensor_nodes:
if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
if not visited[_node_id(tensor_node)]:
self._serial_orphan_tensor_nodes.append(tensor_node)
if len(self._serial_ordered_nodes) != num_nodes_before:
print(
@@ -713,23 +731,30 @@ class DistributedContext:
def _init_dist_attr_for_graph(self):
# Convert program to graph and initialize the distributed attributes
self._order_nodes_by_program_order()
self._tensor_original_id_to_id = {}
self._op_original_id_to_id = {}
for tensor_id, tensor in self._dist_tensors_for_program.items():
original_id = tensor.serial_tensor.desc.original_id()
self._tensor_original_id_to_id[original_id] = tensor_id
for op_id, op in self._dist_ops_for_program.items():
original_id = op.serial_op.desc.original_id()
self._op_original_id_to_id[original_id] = op_id
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
dist_tensor = None
tensor_id = node.node.original_desc_id()
for (
cur_tensor_id,
cur_dist_tensor,
) in self._dist_tensors_for_program.items():
if (
tensor_id == cur_tensor_id
or tensor_id
== cur_dist_tensor.serial_tensor.desc.original_id()
):
dist_tensor = cur_dist_tensor
self._node_id_to_tensor_id[
_node_id(node)
] = cur_tensor_id
cur_dist_tensor = self._dist_tensors_for_program.get(
tensor_id, None
)
if cur_dist_tensor is not None:
cur_tensor_id = tensor_id
else:
cur_tensor_id = self._tensor_original_id_to_id[tensor_id]
cur_dist_tensor = self._dist_tensors_for_program.get(
cur_tensor_id, None
)
dist_tensor = cur_dist_tensor
self._node_id_to_tensor_id[_node_id(node)] = cur_tensor_id
assert (
dist_tensor is not None
), "Tensor must have a distributed tensor after the initialization for program."
@@ -743,16 +768,16 @@ class DistributedContext:
if node.is_op() and node.op() is not None:
dist_op = None
op_id = node.node.original_desc_id()
for (
cur_op_id,
cur_dist_op,
) in self._dist_ops_for_program.items():
if (
op_id == cur_op_id
or op_id == cur_dist_op.serial_op.desc.original_id()
):
dist_op = cur_dist_op
self._node_id_to_op_id[_node_id(node)] = cur_op_id
cur_dist_op = self._dist_ops_for_program.get(op_id, None)
if cur_dist_op is not None:
cur_op_id = op_id
else:
cur_op_id = self._op_original_id_to_id[op_id]
cur_dist_op = self._dist_ops_for_program.get(
cur_op_id, None
)
dist_op = cur_dist_op
self._node_id_to_op_id[_node_id(node)] = cur_op_id
assert (
dist_op is not None
), "Operator must have a distributed operator after the initialization for program."
@@ -775,16 +800,17 @@ class DistributedContext:
if node.is_var() and node.var() is not None:
dist_tensor = None
tensor_id = node.node.original_desc_id()
for (
cur_tensor_id,
cur_dist_tensor,
) in self._dist_tensors_for_program.items():
if (
tensor_id == cur_tensor_id
or tensor_id
== cur_dist_tensor.serial_tensor.desc.original_id()
):
dist_tensor = cur_dist_tensor
cur_dist_tensor = self._dist_tensors_for_program.get(
tensor_id, None
)
if cur_dist_tensor is not None:
cur_tensor_id = tensor_id
else:
cur_tensor_id = self._tensor_original_id_to_id[tensor_id]
cur_dist_tensor = self._dist_tensors_for_program.get(
cur_tensor_id, None
)
dist_tensor = cur_dist_tensor
assert (
dist_tensor is not None
), "Tensor must have a distributed tensor after the initialization for program."
@@ -798,15 +824,15 @@ class DistributedContext:
if node.is_op() and node.op() is not None:
dist_op = None
op_id = node.node.original_desc_id()
for (
cur_op_id,
cur_dist_op,
) in self._dist_ops_for_program.items():
if (
op_id == cur_op_id
or op_id == cur_dist_op.serial_op.desc.original_id()
):
dist_op = cur_dist_op
cur_dist_op = self._dist_ops_for_program.get(op_id, None)
if cur_dist_op is not None:
cur_op_id = op_id
else:
cur_op_id = self._op_original_id_to_id[op_id]
cur_dist_op = self._dist_ops_for_program.get(
cur_op_id, None
)
dist_op = cur_dist_op
assert (
dist_op is not None
), "Operator must have a distributed operator after the initialization for program."
@@ -1014,6 +1040,7 @@ class DistributedContext:
"_backup_serial_main_program_stack",
"_backup_serial_startup_program_stack",
"_pass_context",
"_tensor_nodes_with_same_name",
]:
setattr(result, k, v)
else:
......
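
In the hunks above, _init_dist_attr_for_graph and a second code path previously compared each graph node's desc id against every entry of _dist_tensors_for_program / _dist_ops_for_program so that nodes carrying an original desc id would still resolve. The new code builds _tensor_original_id_to_id and _op_original_id_to_id once and falls back through them, so each node needs at most two dict lookups. A minimal sketch of that fallback, using plain dicts and an illustrative DistObject in place of the Paddle containers:

from dataclasses import dataclass

# Hypothetical stand-in for a dist tensor/op; Paddle reads the original id
# through serial_tensor.desc.original_id() / serial_op.desc.original_id().
@dataclass
class DistObject:
    original_id: int
    payload: str

def build_original_id_map(dist_objects):
    # Built once: original desc id -> current id in the program-side container.
    return {obj.original_id: obj_id for obj_id, obj in dist_objects.items()}

def resolve(desc_id, dist_objects, original_id_to_id):
    # Resolve a graph node's desc id with at most two dict lookups instead of
    # comparing it against every entry of the container.
    obj = dist_objects.get(desc_id)
    if obj is not None:
        return desc_id, obj
    cur_id = original_id_to_id[desc_id]
    return cur_id, dist_objects[cur_id]

# Node 7 carries the original id of the object now stored under id 42.
objects = {42: DistObject(original_id=7, payload="tensor")}
id_map = build_original_id_map(objects)
print(resolve(7, objects, id_map))   # (42, DistObject(original_id=7, payload='tensor'))
print(resolve(42, objects, id_map))  # (42, DistObject(original_id=7, payload='tensor'))
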