未验证 提交 7a7826b7 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix process_mesh (#46583)

上级 e65cdaee
...@@ -168,7 +168,10 @@ class ProcessMesh(object): ...@@ -168,7 +168,10 @@ class ProcessMesh(object):
else: else:
new_mesh = self._mesh[index] new_mesh = self._mesh[index]
new_dim_names = self._dim_names[1:] new_dim_names = self._dim_names[1:]
if new_mesh.shape:
return ProcessMesh(new_mesh, new_dim_names) return ProcessMesh(new_mesh, new_dim_names)
else:
return ProcessMesh([new_mesh])
def __enter__(self): def __enter__(self):
set_current_process_mesh(self) set_current_process_mesh(self)
......
...@@ -101,6 +101,12 @@ class TestProcessMesh(unittest.TestCase): ...@@ -101,6 +101,12 @@ class TestProcessMesh(unittest.TestCase):
self.assertEqual(sub_process_mesh4.dim_names, ["d0"]) self.assertEqual(sub_process_mesh4.dim_names, ["d0"])
self.assertEqual(sub_process_mesh4.ndim, 1) self.assertEqual(sub_process_mesh4.ndim, 1)
sub_process_mesh5 = sub_process_mesh3[0]
self.assertEqual(sub_process_mesh5.shape, [1])
self.assertEqual(sub_process_mesh5.process_ids, [1])
self.assertEqual(sub_process_mesh5.dim_names, ["d0"])
self.assertEqual(sub_process_mesh5.ndim, 1)
def test_context_manager(self): def test_context_manager(self):
mesh = np.array([1, 2, 3, 4]) mesh = np.array([1, 2, 3, 4])
input = static.data(name="input", input = static.data(name="input",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册