process_mesh.py 9.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import numpy as np
16
import copy
17
import paddle
18

19 20 21
# Used to store the previous and current process meshes
_g_previous_process_mesh = None
_g_current_process_mesh = None
J
JZ-LIANG 已提交
22 23
# {shape_process_ids : unique_id}
_g_unique_process_mesh_map = {}
24 25


26 27 28
def get_current_process_mesh():
    """Return the value currently held in ``_g_current_process_mesh``.

    Returns:
        The object last installed as the current process mesh, or whatever
        the module-level variable holds if nothing was installed.
    """
    # Reading a module-level name does not require a `global` declaration.
    return _g_current_process_mesh
29

30 31 32 33 34 35 36 37 38 39 40 41

# Stack of the "current" process meshes that were replaced by
# set_current_process_mesh(). Remembering only one previous mesh is not
# enough for nested scopes: after leaving the inner scope the outer scope's
# own predecessor would be lost.
_g_process_mesh_stack = []


def set_current_process_mesh(process_mesh):
    """Make ``process_mesh`` the current process mesh.

    The mesh that was current before the call is saved so that a later
    call to ``reset_current_process_mesh`` can restore it.

    Args:
        process_mesh: the object to install as the current process mesh.
    """
    global _g_previous_process_mesh
    global _g_current_process_mesh
    _g_process_mesh_stack.append(_g_current_process_mesh)
    _g_previous_process_mesh = _g_current_process_mesh
    _g_current_process_mesh = process_mesh


def reset_current_process_mesh():
    """Restore the process mesh that was current before the matching set.

    Each call undoes the most recent ``set_current_process_mesh`` that has
    not been undone yet, so nested set/reset pairs unwind in reverse order.
    When there is nothing left to undo, the current mesh falls back to
    ``_g_previous_process_mesh`` (the legacy behaviour).
    """
    global _g_previous_process_mesh
    global _g_current_process_mesh
    if _g_process_mesh_stack:
        _g_current_process_mesh = _g_process_mesh_stack.pop()
        # Keep the legacy "previous" variable in sync with the stack top.
        _g_previous_process_mesh = (
            _g_process_mesh_stack[-1] if _g_process_mesh_stack else None
        )
    else:
        _g_current_process_mesh = _g_previous_process_mesh
42 43


J
JZ-LIANG 已提交
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
def get_unique_id_for_process_mesh(shape, process_ids):
    """Return the id registered for ``(shape, process_ids)``, creating one if needed.

    The lookup key is a string built from both arguments. A key seen for the
    first time is assigned ``len(map) + 1``, so ids start at 1 and grow with
    the number of distinct keys.

    Args:
        shape: shape of the mesh, formatted into the key.
        process_ids: process ids of the mesh, formatted into the key.

    Returns:
        int: the id stored in ``_g_unique_process_mesh_map`` for this key.
    """
    registry = _g_unique_process_mesh_map
    mesh_key = f"shape {shape}, process_ids {process_ids}"
    try:
        return registry[mesh_key]
    except KeyError:
        new_id = len(registry) + 1
        registry[mesh_key] = new_id
        return new_id


def retrive_unique_id_for_process_mesh(shape, process_ids):
    """Look up the id already registered for ``(shape, process_ids)``.

    Unlike ``get_unique_id_for_process_mesh`` this never registers a new id:
    the key must already be present in ``_g_unique_process_mesh_map``.

    Args:
        shape: shape of the mesh, formatted into the key.
        process_ids: process ids of the mesh, formatted into the key.

    Returns:
        The id stored for the key.

    Raises:
        AssertionError: if the key has not been registered.
    """
    registry = _g_unique_process_mesh_map
    mesh_key = f"shape {shape}, process_ids {process_ids}"
    assert mesh_key in registry
    return registry[mesh_key]


def get_unique_process_mesh_map():
    """Return the module-level key-to-unique-id map.

    The dict object itself is returned, not a copy, so changes made by the
    caller are visible to the other helpers that use the map.
    """
    # Reading a module-level name does not require a `global` declaration.
    return _g_unique_process_mesh_map


68
class ProcessMesh(object):
    """
    The `ProcessMesh` object describes the topology of the used processes.

    Args:
        mesh (list|numpy.array): an n-dimensional array describing the topology
            of the processes. Every element must be a unique, non-negative
            integer.
        dim_names (list, optional): the i-th element of this list gives the name of the
            i-th dimension of the mesh. When omitted, the names default to
            "d0", "d1", ... (one per dimension).
        shape (list, optional): only read when `mesh` is None; the mesh is then
            built by reshaping `process_ids` to this shape. Kept for
            compatibility, users should not pass it directly.
        process_ids (list, optional): only read when `mesh` is None, see
            `shape`. Kept for compatibility, users should not pass it directly.

    Examples:
        .. code-block:: python

            import paddle

            # NOTE: `auto` is assumed to be the module that exposes
            # ProcessMesh; its import is not shown here.
            mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
            assert mesh.shape == [2, 3]
            assert mesh.process_ids == [2, 4, 5, 0, 1, 3]

    """

    def __init__(self, mesh=None, dim_names=None, shape=None, process_ids=None):
        # Use shape and process_ids just for compatibility
        # Users should not use these directly
        if mesh is None:
            assert shape is not None
            assert process_ids is not None
            mesh = np.array(process_ids).reshape(shape)

        if not isinstance(mesh, (list, np.ndarray)):
            raise ValueError(
                'The mesh must be an instance of list or np.ndarray.'
            )
        if isinstance(mesh, list):
            mesh = np.array(mesh)

        self._mesh = mesh
        self._shape = list(self._mesh.shape)
        # tolist() converts numpy scalars into plain Python numbers, which is
        # what the isinstance(p, int) check below relies on.
        self._process_ids = self._mesh.flatten().tolist()

        # An empty mesh used to fail inside min() with an unrelated message;
        # reject it explicitly (same exception type as before).
        if not self._process_ids:
            raise ValueError('The mesh must contain at least one process.')

        assert all(
            isinstance(p, int) for p in self._process_ids
        ), "All elements of the mesh must be integer"
        assert (
            min(self._process_ids) >= 0
        ), 'All elements of the mesh must be >= 0.'
        unique_process_ids = set(self._process_ids)
        assert len(unique_process_ids) == len(
            self._process_ids
        ), 'All elements of the mesh must be unique.'

        if dim_names is not None:
            assert len(dim_names) == len(
                self._shape
            ), "The length of dims_names must be same as the shape of the mesh."
            self._dim_names = copy.deepcopy(dim_names)
        else:
            self._dim_names = ["d" + str(i) for i in range(len(self._shape))]
        unique_dim_names = set(self._dim_names)
        assert len(unique_dim_names) == len(
            self._dim_names
        ), 'All dim_names {} must be unique.'.format(dim_names)

        # # Store all process meshes
        # from .dist_context import get_default_distributed_context
        # default_dist_cxt = get_default_distributed_context()
        # default_dist_cxt.add_process_mesh(self)

        # Add new processes to process group 0
        from .process_group import get_process_group

        pg0 = get_process_group(0)
        pg0.add_ranks(self.processes)

        # Unique mesh id (depends only on shape and process_ids)
        self._unique_id = get_unique_id_for_process_mesh(
            self._shape, self._process_ids
        )

    @property
    def shape(self):
        """
        Get the shape of this ProcessMesh.
        """
        return self._shape

    @property
    def process_ids(self):
        """
        Get the process ids belonging to this ProcessMesh.
        """
        return self._process_ids

    @property
    def dim_names(self):
        """
        Get the dimension names of this ProcessMesh.
        """
        return self._dim_names

    @property
    def ndim(self):
        """
        Get the number of dimensions of this ProcessMesh.
        """
        return len(self._shape)

    @property
    def mesh(self):
        """
        Get the underlying mesh of ProcessMesh.
        """
        return self._mesh

    @property
    def unique_id(self):
        """
        Get the unique id of ProcessMesh.
        NOTE
        Unique id only take process_ids and shape into account.
        Different ProcessMesh with same process_ids and shape have same unique id.
        """
        return self._unique_id

    @property
    def topology(self):
        """
        Get the shape of this ProcessMesh (same value as `shape`).
        """
        return self._shape

    @property
    def processes(self):
        """
        Get the process ids of this ProcessMesh (same value as `process_ids`).
        """
        return self._process_ids

    def _dim_names_for_index(self, index):
        """
        Compute the dimension names that remain after indexing with `index`.

        Args:
            index (tuple|slice|other): the object passed to `__getitem__`.

        Returns:
            list: for a tuple, the names of the positions indexed by a slice
            followed by the names of every dimension the tuple does not
            reach; for a single slice, all names; otherwise all names except
            the first one.
        """
        if isinstance(index, tuple):
            kept = [
                self._dim_names[i]
                for i, item in enumerate(index)
                if isinstance(item, slice)
            ]
            # Dimensions past the end of the tuple are left untouched by the
            # indexing, so their names must be kept as well. Without this a
            # tuple shorter than ndim produced too few names.
            kept.extend(self._dim_names[len(index):])
            return kept
        if isinstance(index, slice):
            return self._dim_names
        return self._dim_names[1:]

    def __getitem__(self, index):
        """
        Index the underlying mesh and wrap the result in a new ProcessMesh.

        Args:
            index (tuple|slice|other): forwarded unchanged to the underlying
                numpy array.

        Returns:
            ProcessMesh: built from the selected sub-mesh. When the selection
            is a single element, it is wrapped into a one-element list and
            the default dimension name is used.
        """
        new_dim_names = self._dim_names_for_index(index)
        new_mesh = self._mesh[index]
        if new_mesh.shape:
            return ProcessMesh(new_mesh, new_dim_names)
        # Wrap a scalar into a list but without dim_names
        return ProcessMesh([new_mesh])

    def __enter__(self):
        """
        Make this mesh the current process mesh and snapshot the current block.

        The variable names and the number of ops of the default main
        program's current block are recorded so that `__exit__` can tell
        which ones were added inside the `with` body.

        Returns:
            ProcessMesh: this object, so that `with mesh as m` binds the mesh.
        """
        set_current_process_mesh(self)
        default_prog = paddle.fluid.default_main_program()
        cur_block = default_prog.current_block()
        self._old_var_names = list(cur_block.vars.keys())
        self._old_op_size = len(cur_block.ops)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        Annotate everything created inside the `with` body with this mesh.

        Variables and ops that appeared in the current block since
        `__enter__` get this mesh as their "process_mesh" unless they already
        carry one. The current process mesh is restored afterwards, even if
        the annotation step raises. Exceptions from the `with` body are not
        suppressed.
        """
        try:
            from .dist_tensor import DistributedTensor
            from .dist_op import DistributedOperator
            from .dist_context import get_default_distributed_context

            default_prog = paddle.fluid.default_main_program()
            cur_block = default_prog.current_block()
            new_var_names = list(cur_block.vars.keys())
            new_op_size = len(cur_block.ops)
            default_dist_ctx = get_default_distributed_context()

            # Set lookup instead of scanning the list once per variable.
            old_var_names = set(self._old_var_names)
            for name in new_var_names:
                if name in old_var_names:
                    continue
                tensor = cur_block.vars[name]
                dist_tensor = default_dist_ctx.get_dist_tensor_for_program(
                    tensor
                )
                if dist_tensor is None:
                    dist_tensor = DistributedTensor(
                        tensor, {"process_mesh": self}
                    )
                    dist_tensor.dist_attr.mark_annotated("process_mesh")
                    default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
                elif dist_tensor.dist_attr.process_mesh is None:
                    dist_tensor.dist_attr.process_mesh = self
                    dist_tensor.dist_attr.mark_annotated("process_mesh")

            # Ops are appended to the block, so the new ones are exactly
            # those at positions [old size, new size).
            for idx in range(self._old_op_size, new_op_size):
                op = cur_block.ops[idx]
                dist_op = default_dist_ctx.get_dist_op_for_program(op)
                if dist_op is None:
                    dist_op = DistributedOperator(op, {"process_mesh": self})
                    dist_op.dist_attr.mark_annotated("process_mesh")
                    default_dist_ctx.add_dist_op_for_program(dist_op)
                elif dist_op.dist_attr.process_mesh is None:
                    dist_op.dist_attr.process_mesh = self
                    dist_op.dist_attr.mark_annotated("process_mesh")
        finally:
            reset_current_process_mesh()

    def __eq__(self, other):
        """
        Two meshes are equal when both shape and process ids are equal.

        Dimension names are not compared. Any object that is not a
        ProcessMesh compares unequal.
        """
        if not isinstance(other, ProcessMesh):
            return False
        return (
            self.shape == other.shape
            and self.process_ids == other.process_ids
        )

    def __ne__(self, other):
        return not self.__eq__(other)

    def __str__(self):
        # The "dim_nams" spelling is runtime output and is kept unchanged.
        return "shape {}, process_ids {}, dim_nams {}".format(
            self.shape, self.process_ids, self.dim_names
        )