# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from paddle.distributed.auto_parallel.static.process_mesh_v2 import (
    ProcessMesh,
    compute_compatible_process_mesh,
    merge_process_mesh,
)


class TestProcessMesh(unittest.TestCase):
    """Tests for ``ProcessMesh`` and the mesh helper functions.

    Covers the attributes and queries of a single mesh, plus
    ``compute_compatible_process_mesh`` and ``merge_process_mesh`` on
    lists that mix a named 2x3 mesh with ``None`` or with other meshes.
    """

    def test_process_mesh(self):
        """Check attributes, queries and comparisons of a 2x3 mesh."""
        process_mesh = ProcessMesh(
            [[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
        )
        process_mesh2 = ProcessMesh([[0, 1], [2, 3]])

        self.assertEqual(process_mesh.shape, [2, 3])
        self.assertEqual(process_mesh.process_ids, [0, 1, 2, 3, 4, 5])
        self.assertEqual(process_mesh.dim_names, ["x", "y"])
        self.assertEqual(process_mesh.size, 6)
        self.assertEqual(process_mesh.ndim, 2)

        # dim_size is queried by index (including a negative one) and by
        # dimension name.
        for dim, expected_size in ((0, 2), (-1, 3), ("x", 2), ("y", 3)):
            with self.subTest(dim=dim):
                self.assertEqual(process_mesh.dim_size(dim), expected_size)

        self.assertEqual(process_mesh.empty(), False)
        self.assertEqual(process_mesh.contains(0), True)
        self.assertEqual(process_mesh.contains(6), False)

        self.assertEqual(process_mesh, process_mesh)
        self.assertNotEqual(process_mesh, process_mesh2)
        self.assertEqual(str(process_mesh), str(process_mesh))

    def test_compute_compatible_process_mesh(self):
        """Check the compatible mesh computed for several input lists.

        In every case the result is expected to compare equal to the
        named 2x3 reference mesh.
        """
        process_mesh1 = ProcessMesh(
            [[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
        )

        # A None entry may appear on either side of the real mesh.
        for case, meshes in enumerate(
            ([process_mesh1, None], [None, process_mesh1])
        ):
            with self.subTest(none_case=case):
                compatible_process_mesh = compute_compatible_process_mesh(
                    meshes
                )
                self.assertEqual(compatible_process_mesh, process_mesh1)

        # Same ids and shape, but built without explicit dim_names: the
        # result must compare equal to both inputs.
        process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
        compatible_process_mesh = compute_compatible_process_mesh(
            [process_mesh1, process_mesh2]
        )
        self.assertEqual(compatible_process_mesh, process_mesh1)
        self.assertEqual(compatible_process_mesh, process_mesh2)

        # A 1x6 mesh over the same ids, then a 1x3 mesh over a subset.
        for ids in ([[0, 1, 2, 3, 4, 5]], [[0, 1, 2]]):
            with self.subTest(ids=ids):
                process_mesh2 = ProcessMesh(ids)
                compatible_process_mesh = compute_compatible_process_mesh(
                    [process_mesh1, process_mesh2]
                )
                self.assertEqual(compatible_process_mesh, process_mesh1)

    def test_merge_process_mesh(self):
        """Check that merging yields a flat mesh over all process ids."""
        process_mesh1 = ProcessMesh(
            [[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
        )

        # A None entry may appear on either side of the real mesh.
        for case, meshes in enumerate(
            ([process_mesh1, None], [None, process_mesh1])
        ):
            with self.subTest(none_case=case):
                merged_process_mesh = merge_process_mesh(meshes)
                self.assertEqual(
                    merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])
                )

        # Second mesh: identical ids, a subset of the ids, then new ids.
        merge_cases = (
            ([[0, 1, 2], [3, 4, 5]], [0, 1, 2, 3, 4, 5]),
            ([[0, 1, 2]], [0, 1, 2, 3, 4, 5]),
            ([[6, 7]], [0, 1, 2, 3, 4, 5, 6, 7]),
        )
        for ids, expected_ids in merge_cases:
            with self.subTest(ids=ids):
                process_mesh2 = ProcessMesh(ids)
                merged_process_mesh = merge_process_mesh(
                    [process_mesh1, process_mesh2]
                )
                self.assertEqual(
                    merged_process_mesh, ProcessMesh(expected_ids)
                )


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()