From 7a7826b7487ae34a7394596c425ec9e98d592ebd Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Wed, 28 Sep 2022 20:40:07 +0800 Subject: [PATCH] [AutoParallel] fix process_mesh (#46583) --- python/paddle/distributed/auto_parallel/process_mesh.py | 5 ++++- .../tests/unittests/auto_parallel/test_process_mesh.py | 6 ++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py index 2205484bdb..afe9f13e3a 100644 --- a/python/paddle/distributed/auto_parallel/process_mesh.py +++ b/python/paddle/distributed/auto_parallel/process_mesh.py @@ -168,7 +168,10 @@ class ProcessMesh(object): else: new_mesh = self._mesh[index] new_dim_names = self._dim_names[1:] - return ProcessMesh(new_mesh, new_dim_names) + if new_mesh.shape: + return ProcessMesh(new_mesh, new_dim_names) + else: + return ProcessMesh([new_mesh]) def __enter__(self): set_current_process_mesh(self) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py index c9419f8c85..ce38780564 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py @@ -101,6 +101,12 @@ class TestProcessMesh(unittest.TestCase): self.assertEqual(sub_process_mesh4.dim_names, ["d0"]) self.assertEqual(sub_process_mesh4.ndim, 1) + sub_process_mesh5 = sub_process_mesh3[0] + self.assertEqual(sub_process_mesh5.shape, [1]) + self.assertEqual(sub_process_mesh5.process_ids, [1]) + self.assertEqual(sub_process_mesh5.dim_names, ["d0"]) + self.assertEqual(sub_process_mesh5.ndim, 1) + def test_context_manager(self): mesh = np.array([1, 2, 3, 4]) input = static.data(name="input", -- GitLab