From f0af2708959a9606dd0ec9e0a6f59f84c28c9608 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Mon, 17 Oct 2022 16:12:47 +0800 Subject: [PATCH] [Auto Parallel] Fix the bug of completion (#47056) * [Auto Parallel] Fix the bug for None labels * [Auto Parallel] Fix the completion bug --- python/paddle/distributed/auto_parallel/completion.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index a4bee7a4ad..db387ef2bb 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -363,7 +363,14 @@ class Completer: def _update_dims_mapping_for_special(self): # Set the dims_mapping of a tensor to the dims_mapping inside the op which produces it op_nodes = self._dist_context._serial_ordered_op_nodes + # NOTE: this list may be changed if Paddle changes the existing rules. + related_reader_ops = [ + "create_py_reader", "create_double_buffer_reader", "read" + ] for op_node in op_nodes: + if op_node.op() is not None \ + and op_node.op().type() in related_reader_ops: + continue op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node) for tensor_node in op_node.outputs: if tensor_node.is_var() and tensor_node.var() is not None: @@ -403,6 +410,7 @@ class Completer: reach_fix_point = False else: reach_fix_point = True + # NOTE: this will be removed after changing the reshard rule self._update_dims_mapping_for_special() def _update_process_mesh_by_nearest(self, op_node, nearest_op_node): -- GitLab