diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py index 2205484bdb951b9a1452bf8daf2c9ab21ea747c3..afe9f13e3ab43a4f845a1d708588ab03b0e39557 100644 --- a/python/paddle/distributed/auto_parallel/process_mesh.py +++ b/python/paddle/distributed/auto_parallel/process_mesh.py @@ -168,7 +168,10 @@ class ProcessMesh(object): else: new_mesh = self._mesh[index] new_dim_names = self._dim_names[1:] - return ProcessMesh(new_mesh, new_dim_names) + if new_mesh.shape: + return ProcessMesh(new_mesh, new_dim_names) + else: + return ProcessMesh([new_mesh]) def __enter__(self): set_current_process_mesh(self) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py index c9419f8c855afb8b5ef9bbef41fadd608be8c65a..ce38780564b5be4141547f51c5025eeae4d8855f 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py @@ -101,6 +101,12 @@ class TestProcessMesh(unittest.TestCase): self.assertEqual(sub_process_mesh4.dim_names, ["d0"]) self.assertEqual(sub_process_mesh4.ndim, 1) + sub_process_mesh5 = sub_process_mesh3[0] + self.assertEqual(sub_process_mesh5.shape, [1]) + self.assertEqual(sub_process_mesh5.process_ids, [1]) + self.assertEqual(sub_process_mesh5.dim_names, ["d0"]) + self.assertEqual(sub_process_mesh5.ndim, 1) + def test_context_manager(self): mesh = np.array([1, 2, 3, 4]) input = static.data(name="input",