未验证 提交 8e9c719d 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix reshard when train with eval (#46605)

* [AutoParallel] fix reshard when train with eval

* fix mppp
上级 40ab6faf
......@@ -1738,8 +1738,18 @@ class Resharder:
if len(set(process_mesh.processes)) == len(processes):
global_process_mesh_idx = idx
break
if global_process_mesh_idx is not None:
self.dist_context.process_meshes.pop(idx)
is_removed = False
global_mesh = self.dist_context.process_meshes[idx]
for i, mesh in enumerate(self.dist_context.process_meshes):
if i == idx:
continue
if set(mesh.processes) < set(global_mesh.processes):
is_removed = True
if is_removed:
self.dist_context.process_meshes.pop(idx)
def _change_subblock_op_input_and_output(self, block_idx, block):
if "var_reshard_mapping" in Resharder.while_block_info[block_idx]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册