From 8e9c719d7dd5fb95102b9cd5a91f2a0e15b1df2a Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Thu, 29 Sep 2022 11:16:24 +0800
Subject: [PATCH] [AutoParallel] fix reshard when train with eval (#46605)

* [AutoParallel] fix reshard when train with eval

* fix mppp
---
 python/paddle/distributed/auto_parallel/reshard.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py
index 52d5c607bbc..8437042a67c 100644
--- a/python/paddle/distributed/auto_parallel/reshard.py
+++ b/python/paddle/distributed/auto_parallel/reshard.py
@@ -1738,8 +1738,18 @@ class Resharder:
                 if len(set(process_mesh.processes)) == len(processes):
                     global_process_mesh_idx = idx
                     break
+
             if global_process_mesh_idx is not None:
-                self.dist_context.process_meshes.pop(idx)
+                is_removed = False
+                global_mesh = self.dist_context.process_meshes[idx]
+                for i, mesh in enumerate(self.dist_context.process_meshes):
+                    if i == idx:
+                        continue
+                    if set(mesh.processes) < set(global_mesh.processes):
+                        is_removed = True
+
+                if is_removed:
+                    self.dist_context.process_meshes.pop(idx)
 
     def _change_subblock_op_input_and_output(self, block_idx, block):
         if "var_reshard_mapping" in Resharder.while_block_info[block_idx]:
-- 
GitLab
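
Note on the hunk above: before the patch, the Resharder unconditionally popped
the mesh that covers all ranks. When a training program and an evaluation
program share one and the same mesh (train with eval), every mesh spans the
same ranks, so popping the "global" mesh left nothing to reshard against. The
patch removes it only if some other mesh is a proper subset of it. Below is a
minimal standalone sketch of that behavior; it is not Paddle's actual API, and
ProcessMesh and remove_global_mesh here are hypothetical stand-ins written
only to illustrate the proper-subset check.

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class ProcessMesh:
        # Hypothetical stand-in: only the rank list matters for this sketch.
        processes: List[int] = field(default_factory=list)


    def remove_global_mesh(meshes: List[ProcessMesh]) -> List[ProcessMesh]:
        """Drop the mesh covering all ranks, but only if a smaller mesh exists."""
        if len(meshes) <= 1:
            return meshes
        all_ranks = {p for mesh in meshes for p in mesh.processes}
        global_idx = None
        for idx, mesh in enumerate(meshes):
            if set(mesh.processes) == all_ranks:
                global_idx = idx
                break
        if global_idx is not None:
            global_ranks = set(meshes[global_idx].processes)
            # Remove only when another mesh is a proper subset of the global
            # one; if every mesh spans the same ranks (e.g. train with eval on
            # a single mesh), keep it, since popping it would leave no usable
            # mesh for resharding.
            if any(set(m.processes) < global_ranks
                   for i, m in enumerate(meshes) if i != global_idx):
                meshes.pop(global_idx)
        return meshes


    # Mesh [0, 1, 2, 3] is global and [0, 1] is a proper subset: removed.
    print(remove_global_mesh([ProcessMesh([0, 1, 2, 3]), ProcessMesh([0, 1])]))
    # Both meshes cover the same ranks: the "global" mesh is kept.
    print(remove_global_mesh([ProcessMesh([0, 1]), ProcessMesh([1, 0])]))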