diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 52d5c607bbc57e5ae65e3a4593585c0294d2b74f..8437042a67cbd2e2cbf90bad785131bce825c982 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -1738,8 +1738,18 @@ class Resharder: if len(set(process_mesh.processes)) == len(processes): global_process_mesh_idx = idx break + if global_process_mesh_idx is not None: - self.dist_context.process_meshes.pop(idx) + is_removed = False + global_mesh = self.dist_context.process_meshes[idx] + for i, mesh in enumerate(self.dist_context.process_meshes): + if i == idx: + continue + if set(mesh.processes) < set(global_mesh.processes): + is_removed = True + + if is_removed: + self.dist_context.process_meshes.pop(idx) def _change_subblock_op_input_and_output(self, block_idx, block): if "var_reshard_mapping" in Resharder.while_block_info[block_idx]: