# process_mesh_v2.py
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from paddle.framework import core


class ProcessMesh(core.ProcessMesh):
    r"""
    The class `ProcessMesh` describes the topology of logical processes.

    Args:
        mesh (list|numpy.array): an N-dimensional array describing the topology
            of logical processes. Every element must be a unique non-negative
            integer process id.
        dim_names (list, optional): the i-th element of this list gives the name of the
            i-th dimension. When omitted, the names default to
            ``["d0", "d1", ...]``, one per dimension of `mesh`.

    Returns:
        None

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.distributed as dist
            >>> paddle.enable_static()

            >>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
            >>> assert mesh.shape == [2, 3]
            >>> assert mesh.process_ids == [2, 4, 5, 0, 1, 3]

    """

    def __init__(self, mesh, dim_names=None):
        if not isinstance(mesh, (list, np.ndarray)):
            raise ValueError(
                'The mesh must be an instance of list or np.ndarray.'
            )
        if isinstance(mesh, list):
            mesh = np.array(mesh)

        self._mesh = mesh

        self._shape = list(self._mesh.shape)

        # Row-major flattening: [[2, 4, 5], [0, 1, 3]] -> [2, 4, 5, 0, 1, 3].
        # `tolist()` yields Python scalars, so the `int` check below works
        # for integer numpy dtypes.
        self._process_ids = self._mesh.flatten().tolist()

        assert all(
            isinstance(p, int) for p in self._process_ids
        ), "All elements of the mesh must be integer"
        assert (
            min(self._process_ids) >= 0
        ), 'All elements of the mesh must be >= 0.'

        unique_process_ids = set(self._process_ids)
        assert len(unique_process_ids) == len(
            self._process_ids
        ), 'All elements of the mesh must be unique.'

        if dim_names is not None:
            assert len(dim_names) == len(
                self._shape
            ), "The length of dim_names must be the same as the number of dimensions of the mesh."
            self._dim_names = dim_names
        else:
            # One generated name per dimension: "d0", "d1", ...
            self._dim_names = ["d" + str(i) for i in range(len(self._shape))]

        # Follow the requirement for using pybind11: the C++ base class must
        # be initialized explicitly with the shape, ids and dimension names.
        core.ProcessMesh.__init__(
            self, self._shape, self._process_ids, self._dim_names
        )

    @property
    def mesh(self):
        """The process ids arranged as a numpy array with this mesh's shape."""
        return self._mesh


88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
def compute_compatible_process_mesh(process_meshes):
    """Compute the compatible process mesh given a list of process meshes.

    Meshes are folded pairwise from left to right. Two meshes are compatible
    when they are equal, when they have the same process ids (the one with
    more dimensions is kept), or when one's set of process ids is a subset of
    the other's (the superset is kept). ``None`` entries are ignored.

    Args:
        process_meshes (list): process meshes, possibly containing ``None``.

    Returns:
        ProcessMesh|None: a new ``ProcessMesh`` compatible with every input.
        Returns ``None`` in these cases: the list is empty, it holds only
        ``None``, any pair is incompatible, or the compatible mesh is empty.

    Raises:
        ValueError: if the compatible mesh is not a ``core.ProcessMesh``.
    """
    if not process_meshes:
        return None

    def _compute_compatible_of_two_process_meshes(pm1, pm2):
        # Returns a pair (is_compatible, mesh_to_keep).
        if pm1 is None:
            return True, pm2
        if pm2 is None:
            return True, pm1
        if pm1 == pm2:
            return True, pm1
        if pm1.process_ids == pm2.process_ids:
            if len(pm1.shape) >= len(pm2.shape):
                return True, pm1
            else:
                return True, pm2
        process_set1 = set(pm1.process_ids)
        process_set2 = set(pm2.process_ids)
        if process_set1.issubset(process_set2):
            return True, pm2
        if process_set2.issubset(process_set1):
            return True, pm1
        return False, None

    compatible_result = None
    for process_mesh in process_meshes:
        (
            compatible,
            compatible_result,
        ) = _compute_compatible_of_two_process_meshes(
            compatible_result, process_mesh
        )
        if not compatible:
            return None

    # Every entry was None: there is no mesh to build a result from.
    if compatible_result is None or compatible_result.empty():
        return None
    # ProcessMesh subclasses core.ProcessMesh, so this single branch covers
    # both types; the mesh is rebuilt from its ids and shape.
    if isinstance(compatible_result, core.ProcessMesh):
        mesh = np.array(compatible_result.process_ids).reshape(
            compatible_result.shape
        )
        return ProcessMesh(mesh, compatible_result.dim_names)
    raise ValueError("Unrecognized ProcessMesh.")


def merge_process_mesh(process_meshes):
    """Merge a list of process meshes into a single 1-D process mesh.

    Args:
        process_meshes (list): process meshes to merge. ``None`` entries (or a
            ``None``/empty list) are skipped.

    Returns:
        ProcessMesh|None: a 1-D ``ProcessMesh`` holding the union of all
        process ids in ascending order. Returns ``None`` if there are no ids.
    """
    merged_process_ids = set()
    for process_mesh in process_meshes or ():
        if process_mesh is not None:
            merged_process_ids.update(process_mesh.process_ids)
    if not merged_process_ids:
        return None
    # Sort so that the merged layout is deterministic rather than depending
    # on set iteration order.
    return ProcessMesh(sorted(merged_process_ids))