_auto_parallel_context.py 14.4 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context of auto parallel"""
import threading
17 18
import mindspore.context as context
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
Z
zhunaipan 已提交
19
from mindspore._c_expression import AutoParallelContext
20
from mindspore._checkparam import args_type_check
Z
zhunaipan 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223


class _AutoParallelContext:
    """
    Python-side wrapper around the AutoParallelContext handle from mindspore._c_expression.

    Every setter/getter first verifies that the backend handle exists and then forwards
    the call to it; a few setters additionally validate their argument in Python.

    Note:
        Create a context through instantiating Context object is not recommended.
        Should use auto_parallel_context() to get the context since Context is singleton.
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        # Runs on every `_AutoParallelContext()` call (the instance itself is shared);
        # it simply re-fetches the backend handle.
        self._context_handle = AutoParallelContext.get_instance()

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            # `with` guarantees the lock is released even if construction raises.
            with cls._instance_lock:
                # Re-check while holding the lock: another thread may have created the
                # instance between the unlocked check above and acquiring the lock.
                if cls._instance is None:
                    cls._instance = object.__new__(cls)
        return cls._instance

    def check_context_handle(self):
        """
        Check context handle.

        Raises:
            ValueError: If the context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")

    def set_device_num(self, device_num):
        """
        Set device num for auto parallel.

        Args:
            device_num (int): The device number.

        Raises:
            ValueError: If the device num is not in [1, 4096].
        """
        self.check_context_handle()
        if device_num < 1 or device_num > 4096:
            raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
        self._context_handle.set_device_num(device_num)

    def get_device_num(self):
        """Get device num."""
        self.check_context_handle()
        return self._context_handle.get_device_num()

    def set_global_rank(self, global_rank):
        """
        Set global rank for auto parallel.

        Args:
            global_rank (int): The rank id of current rank.

        Raises:
            ValueError: If the global rank is not in [0, 4095].
        """
        self.check_context_handle()
        if global_rank < 0 or global_rank > 4095:
            raise ValueError("Global rank must be in [0, 4095], but got {}".format(global_rank))
        self._context_handle.set_global_rank(global_rank)

    def get_global_rank(self):
        """Get current rank id."""
        self.check_context_handle()
        return self._context_handle.get_global_rank()

    def set_mirror_mean(self, mirror_mean):
        """
        Set mirror_mean flag.

        Note:
            If mirror_mean is true, it will insert a div operator after parameter gradients allreduce.

        Args:
            mirror_mean (bool): The mirror_mean flag.
        """
        self.check_context_handle()
        self._context_handle.set_mirror_mean(mirror_mean)

    def get_mirror_mean(self):
        """Get mirror_mean flag."""
        self.check_context_handle()
        return self._context_handle.get_mirror_mean()

    def set_cast_before_mirror(self, cast_before_mirror):
        """
        Set cast_before_mirror.

        Note:
            If cast_before_mirror is true,
            it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.

        Args:
            cast_before_mirror (bool): The cast_before_mirror flag.
        """
        self.check_context_handle()
        self._context_handle.set_cast_before_mirror(cast_before_mirror)

    def get_cast_before_mirror(self):
        """Get cast_before_mirror flag."""
        self.check_context_handle()
        return self._context_handle.get_cast_before_mirror()

    def set_loss_repeated_mean(self, loss_repeated_mean):
        """
        Set loss_repeated_mean flag.

        Note:
            If loss_repeated_mean is true,
            Distributed automatic differentiation will perform a mean operator
            in backward in the case of repeated calculations.

        Args:
            loss_repeated_mean (bool): The loss_repeated_mean flag.
        """
        self.check_context_handle()
        self._context_handle.set_loss_repeated_mean(loss_repeated_mean)

    def get_loss_repeated_mean(self):
        """Get loss_repeated_mean flag."""
        self.check_context_handle()
        return self._context_handle.get_loss_repeated_mean()

    def set_communication_backend(self, communication_backend):
        """
        Set communication backend.

        Args:
            communication_backend (str): The communication backend.
        """
        self.check_context_handle()
        self._context_handle.set_communication_backend(communication_backend)

    def get_communication_backend(self):
        """Get communication backend."""
        self.check_context_handle()
        return self._context_handle.get_communication_backend()

    def set_parallel_mode(self, parallel_mode):
        """
        Set parallel mode for auto parallel.

        Args:
            parallel_mode (str): The parallel mode of auto parallel.

        Raises:
            ValueError: If parallel mode is not supported.
        """
        self.check_context_handle()
        ret = self._context_handle.set_parallel_mode(parallel_mode)
        # The backend signals an unsupported mode by returning False.
        if ret is False:
            raise ValueError("Parallel mode does not support {}".format(parallel_mode))

    def get_parallel_mode(self):
        """Get parallel mode."""
        self.check_context_handle()
        return self._context_handle.get_parallel_mode()

    def set_strategy_search_mode(self, strategy_search_mode):
        """
        Set strategy search mode.

        Args:
            strategy_search_mode: The strategy search mode, forwarded to the backend handle.

        Raises:
            ValueError: If the backend handle rejects the mode (returns False).
        """
        self.check_context_handle()
        ret = self._context_handle.set_strategy_search_mode(strategy_search_mode)
        if ret is False:
            raise ValueError("Strategy search mode does not support {}".format(strategy_search_mode))

    def get_strategy_search_mode(self):
        """Get strategy search mode."""
        self.check_context_handle()
        return self._context_handle.get_strategy_search_mode()

    def set_parameter_broadcast(self, parameter_broadcast):
        """
        Set parameter broadcast.

        Args:
            parameter_broadcast (bool): Parameter broadcast or not.
        """
        self.check_context_handle()
        self._context_handle.set_parameter_broadcast(parameter_broadcast)

    def get_parameter_broadcast(self):
        """Get parameter broadcast flag."""
        self.check_context_handle()
        return self._context_handle.get_parameter_broadcast()

    def get_parameter_broadcast_is_set(self):
        """Get parameter broadcast is set or not."""
        self.check_context_handle()
        return self._context_handle.get_parameter_broadcast_is_set()

    def set_all_reduce_fusion_split_indices(self, indices):
        """
        Set allreduce fusion strategy by parameters indices.

        Args:
            indices (list): Indices list.

        Raises:
            TypeError: If type of indices item is not int.
        """
        self.check_context_handle()
        for index in indices:
            if not isinstance(index, int):
                raise TypeError('indices has invalid value')
        self._context_handle.set_all_reduce_fusion_split_indices(indices)
        # On Ascend the strategy is additionally registered through the fusion helper.
        if context.get_context("device_target") == "Ascend":
            _set_fusion_strategy_by_idx(indices)

    def get_all_reduce_fusion_split_indices(self):
        """Get allreduce fusion split indices."""
        self.check_context_handle()
        return self._context_handle.get_all_reduce_fusion_split_indices()

    def set_all_reduce_fusion_split_sizes(self, sizes):
        """
        Set allreduce fusion strategy by parameters data sizes.

        Args:
            sizes (list): Sizes list.

        Raises:
            TypeError: If type of sizes item is not int.
        """
        self.check_context_handle()
        for size in sizes:
            if not isinstance(size, int):
                raise TypeError('sizes has invalid value')
        self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
        # On Ascend the strategy is additionally registered through the fusion helper.
        if context.get_context("device_target") == "Ascend":
            _set_fusion_strategy_by_size(sizes)

    def get_all_reduce_fusion_split_sizes(self):
        """Get allreduce fusion split sizes."""
        self.check_context_handle()
        return self._context_handle.get_all_reduce_fusion_split_sizes()

    def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
        """
        Set enable/disable all reduce fusion.

        Args:
            enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.

        Raises:
            TypeError: If enable_all_reduce_fusion is not a bool.
        """
        self.check_context_handle()
        if not isinstance(enable_all_reduce_fusion, bool):
            raise TypeError('enable_all_reduce_fusion is invalid type')
        self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)

    def get_enable_all_reduce_fusion(self):
        """Get all reduce fusion flag."""
        self.check_context_handle()
        return self._context_handle.get_enable_all_reduce_fusion()

    def get_device_num_is_set(self):
        """Get device number is set or not."""
        self.check_context_handle()
        return self._context_handle.get_device_num_is_set()

    def get_global_rank_is_set(self):
        """Get global rank is set or not."""
        self.check_context_handle()
        return self._context_handle.get_global_rank_is_set()

    def reset(self):
        """Reset all settings."""
        self.check_context_handle()
        self._context_handle.reset()


# Module-level cache for the shared context; filled in lazily by auto_parallel_context().
_auto_parallel_context = None


def auto_parallel_context():
    """
    Return the module-wide auto parallel context, building it on the first call.

    Returns:
        _AutoParallelContext, the global auto parallel context.
    """
    global _auto_parallel_context
    # Fast path: the context has already been created.
    if _auto_parallel_context is not None:
        return _auto_parallel_context
    _auto_parallel_context = _AutoParallelContext()
    return _auto_parallel_context


# Dispatch table: attribute name -> bound `set_<name>` method of the global context.
_set_auto_parallel_context_func_map = {
    attr_name: getattr(auto_parallel_context(), "set_" + attr_name)
    for attr_name in (
        "device_num",
        "global_rank",
        "mirror_mean",
        "cast_before_mirror",
        "loss_repeated_mean",
        "parallel_mode",
        "parameter_broadcast",
    )
}


# Dispatch table: attribute name -> bound `get_<name>` method of the global context.
_get_auto_parallel_context_func_map = {
    attr_name: getattr(auto_parallel_context(), "get_" + attr_name)
    for attr_name in (
        "device_num",
        "global_rank",
        "mirror_mean",
        "cast_before_mirror",
        "loss_repeated_mean",
        "parallel_mode",
        "parameter_broadcast",
    )
}


@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
                 loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool)
def _set_auto_parallel_context(**kwargs):
    """
    Apply one or more auto parallel context attributes, given as keyword arguments.

    Note:
        Attribute name is required for setting attributes.
        Attributes are applied in the order given; if an unknown keyword is encountered,
        the attributes before it have already been set.

    Args:
        device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
        global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
        mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
        loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
                          calculations. Default: True.
        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
        parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                     "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".

                     - stand_alone: Only one processor working.

                     - data_parallel: Distributing the data across different processors.

                     - hybrid_parallel: Achieving data parallelism and model parallelism manually.

                     - semi_auto_parallel: Achieving data parallelism and model parallelism by
                       setting parallel strategies.

                     - auto_parallel: Achieving parallelism automatically.
        parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                       "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                       broadcast. Default: False.

    Raises:
        ValueError: If input key is not attribute in auto parallel context.
    """
    for key, value in kwargs.items():
        # One lookup per keyword; a missing entry means the keyword is unknown.
        setter = _set_auto_parallel_context_func_map.get(key)
        if setter is None:
            raise ValueError("Set context keyword %s is not recognized!" % key)
        setter(value)


def _get_auto_parallel_context(attr_key):
    """
    Look up a single auto parallel context attribute by name.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        The value produced by the getter registered for `attr_key`.

    Raises:
        ValueError: If input key is not attribute in auto parallel context.
    """
    getter = _get_auto_parallel_context_func_map.get(attr_key)
    # No registered getter means the key is not a known context attribute.
    if getter is None:
        raise ValueError("Get context keyword %s is not recognized!" % attr_key)
    return getter()


def _reset_auto_parallel_context():
    """
    Restore the auto parallel context attributes to their default values:

    - device_num: 1.
    - global_rank: 0.
    - mirror_mean: False.
    - cast_before_mirror: True.
    - parallel_mode: "stand_alone".
    - parameter_broadcast: False.
    """
    # NOTE(review): loss_repeated_mean (documented default True) is not listed above;
    # presumably it is reset as well -- confirm against the backend reset().
    shared_context = auto_parallel_context()
    shared_context.reset()